Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 74 additions & 22 deletions internal/pkg/agent/application/upgrade/step_unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,43 @@ type UnpackResult struct {
VersionedHome string `json:"versioned-home" yaml:"versioned-home"`
}

type copyFunc func(dst io.Writer, src io.Reader) (written int64, err error)
type mkdirAllFunc func(name string, perm fs.FileMode) error
type openFileFunc func(name string, flag int, perm fs.FileMode) (*os.File, error)
type unarchiveFunc func(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error)

type unpacker struct {
log *logger.Logger
// Abstractsions for testability
unzip unarchiveFunc
untar unarchiveFunc
// stdlib abstractions for testability
copy copyFunc
mkdirAll mkdirAllFunc
openFile openFileFunc
}

func newUnpacker(log *logger.Logger) *unpacker {
return &unpacker{
log: log,
unzip: unzip,
untar: untar,
copy: io.Copy,
mkdirAll: os.MkdirAll,
openFile: os.OpenFile,
}
}

// unpack unpacks archive correctly, skips root (symlink, config...) unpacks data/*
func (u *Upgrader) unpack(version, archivePath, dataDir string, flavor string) (UnpackResult, error) {
func (u *unpacker) unpack(version, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// unpack must occur in directory that holds the installation directory
// or the extraction will be double nested
var unpackRes UnpackResult
var err error
if runtime.GOOS == windows {
unpackRes, err = unzip(u.log, archivePath, dataDir, flavor)
unpackRes, err = u.unzip(u.log, archivePath, dataDir, flavor, u.copy, u.mkdirAll, u.openFile)
} else {
unpackRes, err = untar(u.log, archivePath, dataDir, flavor)
unpackRes, err = u.untar(u.log, archivePath, dataDir, flavor, u.copy, u.mkdirAll, u.openFile)
}

if err != nil {
Expand All @@ -61,7 +88,7 @@ type packageMetadata struct {
hash string
}

func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, error) {
func (u *unpacker) getPackageMetadata(archivePath string) (packageMetadata, error) {
ext := filepath.Ext(archivePath)
if ext == ".gz" {
// if we got gzip extension we need another extension before last
Expand All @@ -78,7 +105,8 @@ func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, erro
}
}

func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// injecting copy, mkdirAll and openFile for testability
func unzip(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
var hash, rootDir string
r, err := zip.OpenReader(archivePath)
if err != nil {
Expand Down Expand Up @@ -148,8 +176,10 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
// check if the directory already exists
_, err = os.Stat(dstPath)
if errors.Is(err, fs.ErrNotExist) {
// the directory does not exist, create it and any non-existing parent directory with the same permissions
if err := os.MkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
// the directory does not exist, create it and any non-existing
// parent directory with the same permissions.
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
if err := mkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
} else if err != nil {
Expand All @@ -162,13 +192,23 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
}
}

_ = os.MkdirAll(dstPath, f.Mode()&0770)
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
err = mkdirAll(dstPath, f.Mode()&0770)
if err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
} else {
log.Debugw("Unpacking file", "archive", "zip", "file.path", dstPath)
// create non-existing containing folders with 0770 permissions right now, we'll fix the permission of each
// directory as we come across them while processing the other package entries
_ = os.MkdirAll(filepath.Dir(dstPath), 0770)
f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
// directory as we come across them while processing the other
// package entries
// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
err = mkdirAll(filepath.Dir(dstPath), 0770)
if err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
// Using openFile instead of os.OpenFile so that we can mock it in tests.
f, err := openFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
if err != nil {
return err
}
Expand All @@ -178,7 +218,9 @@ func unzip(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
}
}()

if _, err = io.Copy(f, rc); err != nil { //nolint:gosec // legacy
// Using copy instead of io.Copy so that we can
// mock it in tests.
if _, err = copy(f, rc); err != nil {
return err
}
}
Expand Down Expand Up @@ -313,7 +355,8 @@ func getPackageMetadataFromZipReader(r *zip.ReadCloser, fileNamePrefix string) (
return ret, nil
}

func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (UnpackResult, error) {
// injecting copy, mkdirAll and openFile for testability
func untar(log *logger.Logger, archivePath, dataDir string, flavor string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
var versionedHome string
var rootDir string
var hash string
Expand Down Expand Up @@ -413,17 +456,23 @@ func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
log.Debugw("Unpacking file", "archive", "tar", "file.path", abs)
// create non-existing containing folders with 0750 permissions right now, we'll fix the permission of each
// directory as we come across them while processing the other package entries
if err = os.MkdirAll(filepath.Dir(abs), 0750); err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
// Using mkdirAll instead of os.MkdirAll so that we can
// mock it in tests.
if err = mkdirAll(filepath.Dir(abs), 0750); err != nil {
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}

// remove any world permissions from the file
wf, err := os.OpenFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
// Using openFile instead of os.OpenFile so that we can
// mock it in tests.
wf, err := openFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
if err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}

_, err = io.Copy(wf, tr) //nolint:gosec // legacy
// Using copy instead of io.Copy so that we can
// mock it in tests.
_, err = copy(wf, tr)
if closeErr := wf.Close(); closeErr != nil && err == nil {
err = closeErr
}
Expand All @@ -435,17 +484,20 @@ func untar(log *logger.Logger, archivePath, dataDir string, flavor string) (Unpa
// check if the directory already exists
_, err = os.Stat(abs)
if errors.Is(err, fs.ErrNotExist) {
// the directory does not exist, create it and any non-existing parent directory with the same permissions
if err := os.MkdirAll(abs, mode.Perm()&0770); err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
// the directory does not exist, create it and any non-existing
// parent directory with the same permissions.
// Using mkdirAll instead of os.MkdirAll so that we can
// mock it in tests.
if err := mkdirAll(abs, mode.Perm()&0770); err != nil {
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}
} else if err != nil {
return UnpackResult{}, errors.New(err, "TarInstaller: stat() directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
} else {
// directory already exists, set the appropriate permissions
err = os.Chmod(abs, mode.Perm()&0770)
if err != nil {
return UnpackResult{}, errors.New(err, fmt.Sprintf("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs), errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
}
}
default:
Expand Down
Loading
Loading