Fix the permission problem of sftp for non-root users.

Signed-off-by: 24sama <jacksama@foxmail.com>
This commit is contained in:
24sama 2022-08-26 10:43:48 +08:00
parent c0a84d60b0
commit ae1e83ffd8
7 changed files with 98 additions and 75 deletions

View File

@ -357,4 +357,4 @@ $(SETUP_ENVTEST): # Build setup-envtest from tools folder.
$(GOLANGCI_LINT): ../../.github/workflows/golangci-lint.yml # Download golangci-lint using hack script into tools folder.
hack/ensure-golangci-lint.sh \
-b $(TOOLS_BIN_DIR) \
$(shell cat .github/workflows/golangci-lint.yml | grep [[:space:]]version | sed 's/.*version: //')
$(shell cat ../../.github/workflows/golangci-lint.yml | grep [[:space:]]version | sed 's/.*version: //')

View File

@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package main
package main
import (

View File

@ -24,7 +24,6 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
@ -44,6 +43,7 @@ import (
const (
DefaultSSHPort = 22
DefaultTimeout = 15
ROOT = "root"
)
// Client is a wrapper around the SSH client that provides a few helper.
@ -69,7 +69,7 @@ func NewClient(host string, auth infrav1.Auth, log *logr.Logger) Interface {
log = &l
}
if auth.User == "" {
auth.User = "root"
auth.User = ROOT
}
port := DefaultSSHPort
@ -123,48 +123,14 @@ func (c *Client) Connect() error {
// ConnectSftpClient connects to the host sftp client using the provided ssh information.
func (c *Client) ConnectSftpClient(opts ...sftp.ClientOption) error {
sess1, err := c.sshClient.NewSession()
if err != nil {
return err
}
defer sess1.Close()
var (
sftpClient *sftp.Client
err error
)
cmd := `grep -oP "Subsystem\s+sftp\s+\K.*" /etc/ssh/sshd_config`
buff, err := sess1.Output(cmd)
if err != nil {
return fmt.Errorf("cmd output errored %v", err)
}
sess2, err := c.sshClient.NewSession()
if err != nil {
return err
}
sftpServerPath := strings.ReplaceAll(string(buff), "\r", "")
if match, _ := regexp.MatchString(`^sudo `, sftpServerPath); !match {
sftpServerPath = "sudo" + " " + sftpServerPath
}
ok, err := sess2.SendRequest("exec", true, ssh.Marshal(struct{ Command string }{sftpServerPath}))
if err == nil && !ok {
return errors.New("ssh: exec request failed")
}
pw, err := sess2.StdinPipe()
if err != nil {
return err
}
pr, err := sess2.StdoutPipe()
if err != nil {
return err
}
sftpClient, err := sftp.NewClientPipe(pr, pw, opts...)
if err != nil {
return err
}
sftpClient, err = sftp.NewClient(c.sshClient, opts...)
c.sftpClient = sftpClient
return nil
return err
}
// Close closes the underlying ssh and sftp connection.
@ -360,34 +326,45 @@ func (c *Client) SudoCmdf(cmd string, a ...any) (string, error) {
// Copy copies a file to the remote host.
func (c *Client) Copy(src, dst string) error {
if c.user == ROOT {
return c.copy(src, dst)
}
return c.sudoCopy(src, dst)
}
func (c *Client) sudoCopy(src, dst string) error {
// scp to tmp dir
remoteTmp := filepath.Join("/tmp/kubekey", dst)
if err := c.copy(src, remoteTmp); err != nil {
return err
}
baseRemotePath := filepath.Dir(dst)
if err := c.mkdirAll(baseRemotePath, ""); err != nil {
return err
}
if _, err := c.SudoCmdf("mv -f %s %s", remoteTmp, dst); err != nil {
return errors.Wrapf(err, "[%s] mv -f %s %s failed", c.host, remoteTmp, dst)
}
if _, err := c.SudoCmd("rm -rf /tmp/kubekey*"); err != nil {
return errors.Wrapf(err, "[%s] rm -rf /tmp/kubekey* failed", c.host)
}
return nil
}
func (c *Client) copy(src, dst string) error {
baseRemoteFilePath := filepath.Dir(dst)
_ = c.mkdirAll(baseRemoteFilePath, "777")
if err := c.Connect(); err != nil {
return errors.Wrapf(err, "[%s] connect ssh client failed", c.host)
}
if err := c.ConnectSftpClient(); err != nil {
return errors.Wrapf(err, "[%s] connect sftp client failed", c.host)
}
defer c.sshClient.Close()
defer c.sftpClient.Close()
src = filepath.Clean(src)
f, err := os.Stat(src)
if err != nil {
return errors.Wrapf(err, "[%s] get file stat failed", c.host)
}
if f.IsDir() {
return errors.Wrapf(err, "[%s] the source %s is not a file", c.host, src)
}
baseRemoteFilePath := filepath.Dir(dst)
_, err = c.sftpClient.ReadDir(baseRemoteFilePath)
if err != nil {
if err = c.sftpClient.MkdirAll(baseRemoteFilePath); err != nil {
return err
}
}
if err := c.copyLocalFileToRemote(src, dst); err != nil {
return errors.Wrapf(err, "[%s] copy file failed", c.host)
}
@ -399,7 +376,8 @@ func (c *Client) copyLocalFileToRemote(src, dst string) error {
var (
srcMd5, dstMd5 string
)
srcMd5 = c.fs.MD5Sum(src)
cleanSrc := filepath.Clean(src)
srcMd5 = c.fs.MD5Sum(cleanSrc)
if exist, err := c.remoteFileExist(dst); err != nil {
return err
} else if exist {
@ -410,28 +388,35 @@ func (c *Client) copyLocalFileToRemote(src, dst string) error {
}
}
srcFile, err := os.Open(filepath.Clean(src))
srcFile, err := os.Open(cleanSrc)
if err != nil {
return err
return errors.Wrapf(err, "open local file %s failed", cleanSrc)
}
defer srcFile.Close()
// the dst file mod will be 0666
dstFile, err := c.sftpClient.Create(dst)
if err != nil {
return err
}
fileStat, err := srcFile.Stat()
if err != nil {
return fmt.Errorf("get file stat failed %v", err)
}
if fileStat.IsDir() {
return fmt.Errorf("the source %s is not a file", cleanSrc)
}
// the dst file mod will be 0666
dstFile, err := c.sftpClient.Create(dst)
if err != nil {
return errors.Wrapf(err, "[%s] create remote file %s failed", c.host, dst)
}
if err := dstFile.Chmod(fileStat.Mode()); err != nil {
return fmt.Errorf("chmod remote file failed %v", err)
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return err
return errors.Wrapf(err, "[%s] io copy file %s to remote %s failed", c.host, cleanSrc, dst)
}
dstMd5 = c.remoteMd5Sum(dst)
if srcMd5 != dstMd5 {
return fmt.Errorf("validate md5sum failed %s != %s", srcMd5, dstMd5)
@ -441,6 +426,32 @@ func (c *Client) copyLocalFileToRemote(src, dst string) error {
// Fetch fetches a file from the remote host.
func (c *Client) Fetch(local, remote string) error {
if c.user == ROOT {
return c.fetch(local, remote)
}
return c.sudoFetch(local, remote)
}
func (c *Client) sudoFetch(local, remote string) error {
remoteTmp := filepath.Join("/tmp/kubekey", filepath.Base(remote))
baseRemotePath := filepath.Dir(remoteTmp)
if err := c.mkdirAll(baseRemotePath, "777"); err != nil {
return err
}
if _, err := c.SudoCmdf("cp %s %s", remote, remoteTmp); err != nil {
return errors.Wrapf(err, "[%s] cp %s %s failed", c.host, remote, remoteTmp)
}
if err := c.fetch(local, remoteTmp); err != nil {
return err
}
if _, err := c.SudoCmd("rm -rf /tmp/kubekey*"); err != nil {
return errors.Wrapf(err, "[%s] rm -rf /tmp/kubekey* failed", c.host)
}
return nil
}
func (c *Client) fetch(local, remote string) error {
if err := c.Connect(); err != nil {
return errors.Wrapf(err, "[%s] connect ssh client failed", c.host)
}
@ -525,6 +536,16 @@ func (c *Client) remoteFileExist(remote string) (bool, error) {
return count != 0, nil
}
func (c *Client) mkdirAll(path, mode string) error {
if mode == "" {
mode = "775"
}
if _, err := c.SudoCmdf("mkdir -p -m %s %s", mode, path); err != nil {
return errors.Wrapf(err, "[%s] mkdir -p -m %s %s failed", c.host, mode, path)
}
return nil
}
// Ping checks if the remote host is reachable.
func (c *Client) Ping() error {
if err := c.Connect(); err != nil {

View File

@ -95,7 +95,7 @@ func (s *Service) CreateDirectory() error {
// ResetTmpDirectory resets the temporary "/tmp/kubekey" directory.
func (s *Service) ResetTmpDirectory() error {
dirService := s.getDirectoryService(directory.TmpDir, os.FileMode(filesystem.FileMode0755))
dirService := s.getDirectoryService(directory.TmpDir, os.FileMode(filesystem.FileMode0777))
if err := dirService.Remove(); err != nil {
return err
}

View File

@ -113,7 +113,7 @@ func (b *Binary) CompareChecksum() error {
}
if sum != b.checksum.Value() {
return errors.New(fmt.Sprintf("SHA256 no match. file: %s sha256: %s not equal checksum: %s", b.Name(), sum, b.checksum.Value()))
return errors.Errorf("SHA256 no match. file: %s sha256: %s not equal checksum: %s", b.Name(), sum, b.checksum.Value())
}
return nil
}

View File

@ -38,7 +38,7 @@ type actionFactory struct {
}
// NewActionFactory returns a new action factory.
func NewActionFactory(sshClient ssh.Interface) *actionFactory {
func NewActionFactory(sshClient ssh.Interface) *actionFactory { //nolint: golint
return &actionFactory{
sshClient: sshClient,
}

View File

@ -22,6 +22,8 @@ const (
)
const (
// FileMode0777 represents the file mode 0755
FileMode0777 = 0777
// FileMode0755 represents the file mode 0755
FileMode0755 = 0755
// FileMode0644 represents the file mode 0644