kubekey/pkg/util/ssh/ssh.go
2020-05-26 22:25:05 +08:00

343 lines
8.2 KiB
Go

/*
Copyright 2020 The KubeSphere Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ssh
import (
"bufio"
"context"
"fmt"
"github.com/kubesphere/kubekey/pkg/apis/kubekey/v1alpha1"
"io"
"io/ioutil"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
const socketEnvPrefix = "env:"
var (
_ Connection = &connection{}
_ Tunneler = &connection{}
)
// Connection represents an established connection to an SSH server.
type Connection interface {
Exec(cmd string, host *v1alpha1.HostCfg) (stdout string, exitCode int, err error)
File(filename string, flags int) (io.ReadWriteCloser, error)
Scp(src, dst string) error
io.Closer
}
// Tunneler interface creates net.Conn originating from the remote ssh host to
// target `addr`
type Tunneler interface {
// `network` can be tcp, tcp4, tcp6, unix
TunnelTo(ctx context.Context, network, addr string) (net.Conn, error)
}
// SSHCfg represents all the possible options for connecting to
// a remote server via SSH.
type SSHCfg struct {
Username string
Password string
Address string
Port int
PrivateKey string
KeyFile string
AgentSocket string
Timeout time.Duration
Bastion string
BastionPort int
BastionUser string
}
func validateOptions(cfg SSHCfg) (SSHCfg, error) {
if len(cfg.Username) == 0 {
return cfg, errors.New("No username specified for SSH connection")
}
if len(cfg.Address) == 0 {
return cfg, errors.New("No address specified for SSH connection")
}
if len(cfg.Password) == 0 && len(cfg.PrivateKey) == 0 && len(cfg.KeyFile) == 0 && len(cfg.AgentSocket) == 0 {
return cfg, errors.New("Must specify at least one of password, private key, keyfile or agent socket")
}
if len(cfg.KeyFile) > 0 {
content, err := ioutil.ReadFile(cfg.KeyFile)
if err != nil {
return cfg, errors.Wrapf(err, "Failed to read keyfile %q", cfg.KeyFile)
}
cfg.PrivateKey = string(content)
cfg.KeyFile = ""
}
if cfg.Port <= 0 {
cfg.Port = 22
}
if cfg.BastionPort <= 0 {
cfg.BastionPort = 22
}
if cfg.BastionUser == "" {
cfg.BastionUser = cfg.Username
}
if cfg.Timeout == 0 {
cfg.Timeout = 60 * time.Second
}
return cfg, nil
}
type connection struct {
mu sync.Mutex
sftpclient *sftp.Client
sshclient *ssh.Client
ctx context.Context
cancel context.CancelFunc
}
// NewConnection attempts to create a new SSH connection to the host
// specified via the given options.
func NewConnection(cfg SSHCfg) (Connection, error) {
cfg, err := validateOptions(cfg)
if err != nil {
return nil, errors.Wrap(err, "Failed to validate ssh connection options")
}
authMethods := make([]ssh.AuthMethod, 0)
if len(cfg.Password) > 0 {
authMethods = append(authMethods, ssh.Password(cfg.Password))
}
if len(cfg.PrivateKey) > 0 {
signer, parseErr := ssh.ParsePrivateKey([]byte(cfg.PrivateKey))
if parseErr != nil {
return nil, errors.Wrap(parseErr, "The given SSH key could not be parsed (note that password-protected keys are not supported)")
}
authMethods = append(authMethods, ssh.PublicKeys(signer))
}
if len(cfg.AgentSocket) > 0 {
addr := cfg.AgentSocket
if strings.HasPrefix(cfg.AgentSocket, socketEnvPrefix) {
envName := strings.TrimPrefix(cfg.AgentSocket, socketEnvPrefix)
if envAddr := os.Getenv(envName); len(envAddr) > 0 {
addr = envAddr
}
}
socket, dialErr := net.Dial("unix", addr)
if dialErr != nil {
return nil, errors.Wrapf(dialErr, "could not open socket %q", addr)
}
agentClient := agent.NewClient(socket)
signers, signersErr := agentClient.Signers()
if signersErr != nil {
socket.Close()
return nil, errors.Wrap(signersErr, "error when creating signer for SSH agent")
}
authMethods = append(authMethods, ssh.PublicKeys(signers...))
}
sshConfig := &ssh.ClientConfig{
User: cfg.Username,
Timeout: cfg.Timeout,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec
}
targetHost := cfg.Address
targetPort := strconv.Itoa(cfg.Port)
if cfg.Bastion != "" {
targetHost = cfg.Bastion
targetPort = strconv.Itoa(cfg.BastionPort)
sshConfig.User = cfg.BastionUser
}
// do not use fmt.Sprintf() to allow proper IPv6 handling if hostname is an IP address
endpoint := net.JoinHostPort(targetHost, targetPort)
client, err := ssh.Dial("tcp", endpoint, sshConfig)
if err != nil {
return nil, errors.Wrapf(err, "could not establish connection to %s", endpoint)
}
ctx, cancelFn := context.WithCancel(context.Background())
sshConn := &connection{
ctx: ctx,
cancel: cancelFn,
}
if cfg.Bastion == "" {
sshConn.sshclient = client
// connection established
return sshConn, nil
}
// continue to setup if we are running over bastion
endpointBehindBastion := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
// Dial a connection to the service host, from the bastion
conn, err := client.Dial("tcp", endpointBehindBastion)
if err != nil {
return nil, errors.Wrapf(err, "could not establish connection to %s", endpointBehindBastion)
}
sshConfig.User = cfg.Username
ncc, chans, reqs, err := ssh.NewClientConn(conn, endpointBehindBastion, sshConfig)
if err != nil {
return nil, errors.Wrapf(err, "could not establish connection to %s", endpointBehindBastion)
}
sshConn.sshclient = ssh.NewClient(ncc, chans, reqs)
return sshConn, nil
}
// File return remote file (as an io.ReadWriteCloser).
//
// mode is os package file modes: https://golang.org/pkg/os/#pkg-constants
// returned file optionally implement
func (c *connection) File(filename string, flags int) (io.ReadWriteCloser, error) {
sftpClient, err := c.sftp()
if err != nil {
return nil, errors.Wrap(err, "Failed to open SFTP")
}
return sftpClient.OpenFile(filename, flags)
}
func (c *connection) TunnelTo(_ context.Context, network, addr string) (net.Conn, error) {
netconn, err := c.sshclient.Dial(network, addr)
if err == nil {
go func() {
<-c.ctx.Done()
netconn.Close()
}()
}
return netconn, err
}
func (c *connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.sshclient == nil {
return nil
}
c.cancel()
defer func() { c.sshclient = nil }()
defer func() { c.sftpclient = nil }()
return c.sshclient.Close()
}
func (c *connection) Exec(cmd string, host *v1alpha1.HostCfg) (string, int, error) {
sess, err := c.session()
if err != nil {
return "", 0, errors.Wrap(err, "Failed to get SSH session")
}
defer sess.Close()
modes := ssh.TerminalModes{
ssh.ECHO: 0, // disable echoing
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
err = sess.RequestPty("xterm", 100, 50, modes)
if err != nil {
return "", 0, err
}
stdin, _ := sess.StdinPipe()
out, _ := sess.StdoutPipe()
var output []byte
err = sess.Start(strings.TrimSpace(cmd))
if err != nil {
return "", 0, err
}
var (
line = ""
r = bufio.NewReader(out)
)
for {
b, err := r.ReadByte()
if err != nil {
break
}
output = append(output, b)
if b == byte('\n') {
line = ""
continue
}
line += string(b)
if (strings.HasPrefix(line, "[sudo] password for ") || strings.HasPrefix(line, "Password")) && strings.HasSuffix(line, ": ") {
_, err = stdin.Write([]byte(host.Password + "\n"))
if err != nil {
break
}
}
}
exitCode := 0
err = sess.Wait()
if err != nil {
exitCode = 1
}
outStr := strings.TrimPrefix(string(output), fmt.Sprintf("[sudo] password for %s:", host.User))
return strings.TrimSpace(outStr), exitCode, errors.Wrapf(err, "Failed to exec command: %s", cmd)
}
func (c *connection) session() (*ssh.Session, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.sshclient == nil {
return nil, errors.New("connection closed")
}
return c.sshclient.NewSession()
}