Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (

sshutil "github.com/aucloud/go-sshutil"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

type RemoteCmd struct {
Expand Down Expand Up @@ -107,6 +108,38 @@ func NewRemoteKeyAuthRunner(ctx context.Context, user, host, key string) (*Remot
return &Remote{client}, nil
}

func NewRemoteAgentAuthRunner(ctx context.Context, user, host, agentSocket string) (*Remote, error) {

if _, err := os.Stat(agentSocket); os.IsNotExist(err) {
return nil, fmt.Errorf("agent socket %s does not exist: %w", agentSocket, err)
}
agentConn, err := net.Dial("unix", agentSocket)
if err != nil {
return nil, fmt.Errorf("failed to open SSH agent socket %s: %v", agentSocket, err)
}
agentClient := agent.NewClient(agentConn)
config := &ssh.ClientConfig{
User: user,
// FIXME: This is insecure. We should verify RSA fingerprints of hosts...
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Auth: []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)},
}
addr, err := ResolveHostname(host)
if err != nil {
return nil, fmt.Errorf("failed to resolve hostname %s: %w", host, err)
}
client, err := sshutil.NewClient(
ctx,
sshutil.ConstantAddrResolver{addr},
config,
sshutil.DefaultConnectBackoff(),
)
if err != nil {
return nil, fmt.Errorf("failed to establish an SSH connection to %s: %w", host, err)
}
return &Remote{client}, nil
}

func NewRemoteKeyAuthRunnerViaJumphost(ctx context.Context, user, host, jumphost, key string) (*Remote, error) {
if _, err := os.Stat(key); os.IsNotExist(err) {
return nil, fmt.Errorf("error reading private ssh key %s: %w", key, err)
Expand Down
22 changes: 17 additions & 5 deletions runcmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,19 @@ var (

/* FIXME: Mock an SSH server
func TestKeyAuth(t *testing.T) {
rRunner, err := NewRemoteKeyAuthRunner(user, host, key)
rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key)
if err != nil {
t.Error(err)
}
if err := testRun(rRunner); err != nil {
t.Error(err)
}
}
*/

/* FIXME: Mock an SSH server
func TestKeyAuth(t *testing.T) {
rRunner, err := NewRemoteAgentAuthRunner(context.TODO(), os.Getenv("USER"), "localhost", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
t.Error(err)
}
Expand All @@ -58,7 +70,7 @@ func TestPassAuth(t *testing.T) {
os.Exit(1)
}
}()
rRunner, err := NewRemotePassAuthRunner(user, host, pass)
rRunner, err := NewRemotePassAuthRunner(context.TODO(), user, host, pass)
if err != nil {
t.Error(err)
}
Expand All @@ -80,7 +92,7 @@ func TestLocalRun(t *testing.T) {

/* FIXME: Mock anSSH server
func TestRemoteRun(t *testing.T) {
rRunner, err := NewRemoteKeyAuthRunner(user, host, key)
rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key)
if err != nil {
t.Error(err)
}
Expand All @@ -102,7 +114,7 @@ func TestLocalStartWait(t *testing.T) {

/* FIXME: Mock an SSH server
func TestRemoteStartWait(t *testing.T) {
rRunner, err := NewRemoteKeyAuthRunner(user, host, key)
rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -242,7 +254,7 @@ func testPipe(localToRemote bool) error {
if err != nil {
return err
}
rRunner, err := NewRemoteKeyAuthRunner(user, host, key)
rRunner, err := NewRemoteKeyAuthRunner(context.TODO(), user, host, key)
if err != nil {
return err
}
Expand Down