diff --git a/cmd/sup/main.go b/cmd/sup/main.go index e1f35ee..a4c0caa 100644 --- a/cmd/sup/main.go +++ b/cmd/sup/main.go @@ -137,6 +137,12 @@ func parseArgs(conf *sup.Supfile) (*sup.Network, []*sup.Command, error) { network.Env.Set(env[:i], env[i+1:]) } + bastion, err := network.ParseBastion() + if err != nil { + return nil, nil, err + } + network.Bastion = bastion + hosts, err := network.ParseInventory() if err != nil { return nil, nil, err diff --git a/supfile.go b/supfile.go index 2cf88b5..001570a 100644 --- a/supfile.go +++ b/supfile.go @@ -364,3 +364,37 @@ func (n Network) ParseInventory() ([]string, error) { } return hosts, nil } + +// ParseBastion returns the bastion if it is a valid connection string +// or runs the bastion command, if provided, and returns the first +// line of the command's output. +func (n Network) ParseBastion() (string, error) { + if n.Bastion == "" { + return "", nil + } + + // check if its a connection string + testConn := &SSHClient{} + if err := testConn.Connect(n.Bastion); err == nil { + return n.Bastion, nil + } + + cmd := exec.Command("/bin/sh", "-c", n.Bastion) + cmd.Env = os.Environ() + cmd.Env = append(cmd.Env, n.Env.Slice()...) + cmd.Stderr = os.Stderr + output, err := cmd.Output() + if err != nil { + return "", err + } + + buf := bytes.NewBuffer(output) + + bastion, err := buf.ReadString('\n') + if err != nil { + return "", err + } + + bastion = strings.TrimSpace(bastion) + return bastion, nil +}