From bc2b0947bf2bd76cc6cdc256cf351474a73f43da Mon Sep 17 00:00:00 2001 From: gknw Date: Thu, 18 Feb 2021 21:29:43 +0500 Subject: [PATCH] Run SSH-commands with context --- sshwrapper.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/sshwrapper.go b/sshwrapper.go index 2f2d210..d6e3329 100644 --- a/sshwrapper.go +++ b/sshwrapper.go @@ -1,6 +1,7 @@ package sshwrapper import ( + "context" "fmt" "io" "net" @@ -152,7 +153,7 @@ func (s *SSHConn) CombinedOutput(cmd string, in io.Reader) ([]byte, error) { // Run runs cmd on the remote host. // // See https://godoc.org/golang.org/x/crypto/ssh#Session.Run for details. -func (s *SSHConn) Run(cmd string, in io.Reader, outWriter, errWriter io.Writer) error { +func (s *SSHConn) RunContext(ctx context.Context, cmd string, in io.Reader, outWriter, errWriter io.Writer) error { session, err := s.client.NewSession() if err != nil { return err @@ -172,10 +173,29 @@ func (s *SSHConn) Run(cmd string, in io.Reader, outWriter, errWriter io.Writer) session.Stdout = outWriter session.Stderr = errWriter session.Stdin = in + + exit := make(chan struct{}, 1) + defer close(exit) + + go func() { + select { + case <-ctx.Done(): + if ctx.Err() != nil { + session.Signal(ssh.SIGINT) + session.Close() + } + case <-exit: + } + }() + err = session.Run(cmd) return err } +func (s *SSHConn) Run(cmd string, in io.Reader, outWriter, errWriter io.Writer) error { + return s.RunContext(context.Background(), cmd, in, outWriter, errWriter) +} + // SetEnvs specifies the environment that will be applied // to any command executed by Output/CombinedOutput/Run. func (s *SSHConn) SetEnvs(e map[string]string) {