/* Copyright 2020 The KubeSphere Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package ssh import ( "bufio" "context" "fmt" kubekeyapiv1alpha1 "github.com/kubesphere/kubekey/api/v1alpha1" "io/ioutil" "net" "os" "strconv" "strings" "sync" "time" "github.com/pkg/errors" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) const socketEnvPrefix = "env:" var ( _ Connection = &connection{} ) type Connection interface { Exec(cmd string, host *kubekeyapiv1alpha1.HostCfg) (stdout string, err error) Scp(src, dst string) error } type Cfg struct { Username string Password string Address string Port int PrivateKey string KeyFile string AgentSocket string Timeout time.Duration Bastion string BastionPort int BastionUser string } type connection struct { mu sync.Mutex sftpclient *sftp.Client sshclient *ssh.Client ctx context.Context cancel context.CancelFunc } func validateOptions(cfg Cfg) (Cfg, error) { if len(cfg.Username) == 0 { return cfg, errors.New("No username specified for SSH connection") } if len(cfg.Address) == 0 { return cfg, errors.New("No address specified for SSH connection") } if len(cfg.Password) == 0 && len(cfg.PrivateKey) == 0 && len(cfg.KeyFile) == 0 && len(cfg.AgentSocket) == 0 { return cfg, errors.New("Must specify at least one of password, private key, keyfile or agent socket") } if len(cfg.KeyFile) > 0 { content, err := ioutil.ReadFile(cfg.KeyFile) if err != nil { return cfg, errors.Wrapf(err, "Failed to read keyfile %q", cfg.KeyFile) } cfg.PrivateKey = string(content) cfg.KeyFile = "" } if cfg.Port <= 0 { cfg.Port = 22 } if cfg.BastionPort <= 0 { cfg.BastionPort = 22 } if cfg.BastionUser == "" { cfg.BastionUser = cfg.Username } if cfg.Timeout == 0 { cfg.Timeout = 60 * time.Second } return cfg, nil } func NewConnection(cfg Cfg) (Connection, error) { cfg, err := validateOptions(cfg) if err != nil { return nil, errors.Wrap(err, "Failed to validate ssh connection parameters") } authMethods := make([]ssh.AuthMethod, 0) if len(cfg.Password) > 0 { authMethods = append(authMethods, ssh.Password(cfg.Password)) } if len(cfg.PrivateKey) > 0 { signer, parseErr := ssh.ParsePrivateKey([]byte(cfg.PrivateKey)) if parseErr != nil { return nil, errors.Wrap(parseErr, "The given SSH key could not be parsed") } authMethods = append(authMethods, ssh.PublicKeys(signer)) } if len(cfg.AgentSocket) > 0 { addr := cfg.AgentSocket if strings.HasPrefix(cfg.AgentSocket, socketEnvPrefix) { envName := strings.TrimPrefix(cfg.AgentSocket, socketEnvPrefix) if envAddr := os.Getenv(envName); len(envAddr) > 0 { addr = envAddr } } socket, dialErr := net.Dial("unix", addr) if dialErr != nil { return nil, errors.Wrapf(dialErr, "could not open socket %q", addr) } agentClient := agent.NewClient(socket) signers, signersErr := agentClient.Signers() if signersErr != nil { _ = socket.Close() return nil, errors.Wrap(signersErr, "error when creating signer for SSH agent") } authMethods = append(authMethods, ssh.PublicKeys(signers...)) } sshConfig := &ssh.ClientConfig{ User: cfg.Username, Timeout: cfg.Timeout, Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } targetHost := cfg.Address targetPort := strconv.Itoa(cfg.Port) if cfg.Bastion != "" { targetHost = cfg.Bastion targetPort = strconv.Itoa(cfg.BastionPort) sshConfig.User = cfg.BastionUser } endpoint := net.JoinHostPort(targetHost, targetPort) client, err := ssh.Dial("tcp", endpoint, sshConfig) if err != nil { return nil, errors.Wrapf(err, "could not establish connection to %s", endpoint) } ctx, cancelFn := context.WithCancel(context.Background()) sshConn := &connection{ ctx: ctx, cancel: cancelFn, } if cfg.Bastion == "" { sshConn.sshclient = client return sshConn, nil } endpointBehindBastion := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port)) conn, err := client.Dial("tcp", endpointBehindBastion) if err != nil { return nil, errors.Wrapf(err, "could not establish connection to %s", endpointBehindBastion) } sshConfig.User = cfg.Username ncc, chans, reqs, err := ssh.NewClientConn(conn, endpointBehindBastion, sshConfig) if err != nil { return nil, errors.Wrapf(err, "could not establish connection to %s", endpointBehindBastion) } sshConn.sshclient = ssh.NewClient(ncc, chans, reqs) return sshConn, nil } func (c *connection) Exec(cmd string, host *kubekeyapiv1alpha1.HostCfg) (string, error) { sess, err := c.session() if err != nil { return "", errors.Wrap(err, "Failed to get SSH session") } defer sess.Close() modes := ssh.TerminalModes{ ssh.ECHO: 0, // disable echoing ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud } err = sess.RequestPty("xterm", 100, 50, modes) if err != nil { return "", err } stdin, _ := sess.StdinPipe() out, _ := sess.StdoutPipe() var output []byte err = sess.Start(strings.TrimSpace(cmd)) if err != nil { return "", err } var ( line = "" r = bufio.NewReader(out) ) for { b, err := r.ReadByte() if err != nil { break } output = append(output, b) if b == byte('\n') { line = "" continue } line += string(b) if (strings.HasPrefix(line, "[sudo] password for ") || strings.HasPrefix(line, "Password")) && strings.HasSuffix(line, ": ") { _, err = stdin.Write([]byte(host.Password + "\n")) if err != nil { break } } } err = sess.Wait() outStr := strings.TrimPrefix(string(output), fmt.Sprintf("[sudo] password for %s:", host.User)) return strings.TrimSpace(outStr), errors.Wrapf(err, "Failed to exec command: %s \n%s", cmd, strings.TrimSpace(outStr)) } func (c *connection) session() (*ssh.Session, error) { c.mu.Lock() defer c.mu.Unlock() if c.sshclient == nil { return nil, errors.New("connection closed") } return c.sshclient.NewSession() }