diff --git a/internal/pkg/agent/application/upgrade/step_unpack.go b/internal/pkg/agent/application/upgrade/step_unpack.go
index 78a56a51c20..eb5e22bd61f 100644
--- a/internal/pkg/agent/application/upgrade/step_unpack.go
+++ b/internal/pkg/agent/application/upgrade/step_unpack.go
@@ -33,16 +33,43 @@ type UnpackResult struct {
 	VersionedHome string `json:"versioned-home" yaml:"versioned-home"`
 }
 
+type copyFunc func(dst io.Writer, src io.Reader) (written int64, err error)
+type mkdirAllFunc func(name string, perm fs.FileMode) error
+type openFileFunc func(name string, flag int, perm fs.FileMode) (*os.File, error)
+type unarchiveFunc func(log *logger.Logger, archivePath, dataDir string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error)
+
+type unpacker struct {
+	log *logger.Logger
+	// Abstractions for testability
+	unzip unarchiveFunc
+	untar unarchiveFunc
+	// stdlib abstractions for testability
+	copy     copyFunc
+	mkdirAll mkdirAllFunc
+	openFile openFileFunc
+}
+
+func newUnpacker(log *logger.Logger) *unpacker {
+	return &unpacker{
+		log:      log,
+		unzip:    unzip,
+		untar:    untar,
+		copy:     io.Copy,
+		mkdirAll: os.MkdirAll,
+		openFile: os.OpenFile,
+	}
+}
+
 // unpack unpacks archive correctly, skips root (symlink, config...) unpacks data/*
-func (u *Upgrader) unpack(version, archivePath, dataDir string) (UnpackResult, error) {
+func (u *unpacker) unpack(version, archivePath, dataDir string) (UnpackResult, error) {
 	// unpack must occur in directory that holds the installation directory
 	// or the extraction will be double nested
 	var unpackRes UnpackResult
 	var err error
 	if runtime.GOOS == windows {
-		unpackRes, err = unzip(u.log, archivePath, dataDir)
+		unpackRes, err = u.unzip(u.log, archivePath, dataDir, u.copy, u.mkdirAll, u.openFile)
 	} else {
-		unpackRes, err = untar(u.log, archivePath, dataDir)
+		unpackRes, err = u.untar(u.log, archivePath, dataDir, u.copy, u.mkdirAll, u.openFile)
 	}
 
 	if err != nil {
@@ -59,7 +86,7 @@ type packageMetadata struct {
 	hash     string
 }
 
-func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, error) {
+func (u *unpacker) getPackageMetadata(archivePath string) (packageMetadata, error) {
 	ext := filepath.Ext(archivePath)
 	if ext == ".gz" {
 		// if we got gzip extension we need another extension before last
@@ -76,7 +103,8 @@ func (u *Upgrader) getPackageMetadata(archivePath string) (packageMetadata, erro
 	}
 }
 
-func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error) {
+// injecting copy, mkdirAll and openFile for testability
+func unzip(log *logger.Logger, archivePath, dataDir string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
 	var hash, rootDir string
 	r, err := zip.OpenReader(archivePath)
 	if err != nil {
@@ -136,8 +164,10 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 			// check if the directory already exists
 			_, err = os.Stat(dstPath)
 			if errors.Is(err, fs.ErrNotExist) {
-				// the directory does not exist, create it and any non-existing parent directory with the same permissions
-				if err := os.MkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
+				// the directory does not exist, create it and any non-existing
+				// parent directory with the same permissions.
+				// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
+				if err := mkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
 					return fmt.Errorf("creating directory %q: %w", dstPath, err)
 				}
 			} else if err != nil {
@@ -150,13 +180,23 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 				}
 			}
 
-			_ = os.MkdirAll(dstPath, f.Mode()&0770)
+			// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
+			err = mkdirAll(dstPath, f.Mode()&0770)
+			if err != nil {
+				return fmt.Errorf("creating directory %q: %w", dstPath, err)
+			}
 		} else {
 			log.Debugw("Unpacking file", "archive", "zip", "file.path", dstPath)
 			// create non-existing containing folders with 0770 permissions right now, we'll fix the permission of each
-			// directory as we come across them while processing the other package entries
-			_ = os.MkdirAll(filepath.Dir(dstPath), 0770)
-			f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
+			// directory as we come across them while processing the other
+			// package entries
+			// Using mkdirAll instead of os.MkdirAll so that we can mock it in tests.
+			err = mkdirAll(filepath.Dir(dstPath), 0770)
+			if err != nil {
+				return fmt.Errorf("creating directory %q: %w", dstPath, err)
+			}
+			// Using openFile instead of os.OpenFile so that we can mock it in tests.
+			f, err := openFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
 			if err != nil {
 				return err
 			}
@@ -166,7 +206,9 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 				}
 			}()
 
-			if _, err = io.Copy(f, rc); err != nil { //nolint:gosec // legacy
+			// Using copy instead of io.Copy so that we can
+			// mock it in tests.
+			if _, err = copy(f, rc); err != nil {
 				return err
 			}
 		}
@@ -240,8 +282,8 @@ func getPackageMetadataFromZipReader(r *zip.ReadCloser, fileNamePrefix string) (
 	return ret, nil
 }
 
-func untar(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error) {
-
+// injecting copy, mkdirAll and openFile for testability
+func untar(log *logger.Logger, archivePath, dataDir string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
 	var versionedHome string
 	var rootDir string
 	var hash string
@@ -330,17 +372,23 @@ func untar(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 			log.Debugw("Unpacking file", "archive", "tar", "file.path", abs)
 			// create non-existing containing folders with 0750 permissions right now, we'll fix the permission of each
 			// directory as we come across them while processing the other package entries
-			if err = os.MkdirAll(filepath.Dir(abs), 0750); err != nil {
-				return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
+			// Using mkdirAll instead of os.MkdirAll so that we can
+			// mock it in tests.
+			if err = mkdirAll(filepath.Dir(abs), 0750); err != nil {
+				return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
 			}
 
 			// remove any world permissions from the file
-			wf, err := os.OpenFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
+			// Using openFile instead of os.OpenFile so that we can
+			// mock it in tests.
+			wf, err := openFile(abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode.Perm()&0770)
 			if err != nil {
-				return UnpackResult{}, errors.New(err, "TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
+				return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
 			}
 
-			_, err = io.Copy(wf, tr) //nolint:gosec // legacy
+			// Using copy instead of io.Copy so that we can
+			// mock it in tests.
+			_, err = copy(wf, tr)
 			if closeErr := wf.Close(); closeErr != nil && err == nil {
 				err = closeErr
 			}
@@ -352,9 +400,12 @@ func untar(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 			// check if the directory already exists
 			_, err = os.Stat(abs)
 			if errors.Is(err, fs.ErrNotExist) {
-				// the directory does not exist, create it and any non-existing parent directory with the same permissions
-				if err := os.MkdirAll(abs, mode.Perm()&0770); err != nil {
-					return UnpackResult{}, errors.New(err, "TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
+				// the directory does not exist, create it and any non-existing
+				// parent directory with the same permissions.
+				// Using mkdirAll instead of os.MkdirAll so that we can
+				// mock it in tests.
+				if err := mkdirAll(abs, mode.Perm()&0770); err != nil {
+					return UnpackResult{}, goerrors.Join(err, errors.New("TarInstaller: creating directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
 				}
 			} else if err != nil {
 				return UnpackResult{}, errors.New(err, "TarInstaller: stat() directory for file "+abs, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
@@ -362,7 +413,7 @@ func untar(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
 				// directory already exists, set the appropriate permissions
 				err = os.Chmod(abs, mode.Perm()&0770)
 				if err != nil {
-					return UnpackResult{}, errors.New(err, fmt.Sprintf("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs), errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs))
+					return UnpackResult{}, goerrors.Join(err, errors.New(fmt.Sprintf("TarInstaller: setting permissions %O for directory %q", mode.Perm()&0770, abs), errors.TypeFilesystem, errors.M(errors.MetaKeyPath, abs)))
 				}
 			}
 		default:
diff --git a/internal/pkg/agent/application/upgrade/step_unpack_test.go b/internal/pkg/agent/application/upgrade/step_unpack_test.go
index a7d1b1ae630..11b2cbfc32b 100644
--- a/internal/pkg/agent/application/upgrade/step_unpack_test.go
+++ b/internal/pkg/agent/application/upgrade/step_unpack_test.go
@@ -8,6 +8,7 @@ import (
 	"archive/tar"
 	"archive/zip"
 	"compress/gzip"
+	"errors"
 	"fmt"
 	"io"
 	"io/fs"
@@ -23,6 +24,7 @@ import (
 	"github.com/stretchr/testify/require"
 
 	v1 "github.com/elastic/elastic-agent/pkg/api/v1"
+	"github.com/elastic/elastic-agent/pkg/core/logger"
 	"github.com/elastic/elastic-agent/pkg/core/logger/loggertest"
 )
 
@@ -130,6 +132,7 @@ type createArchiveFunc func(t *testing.T, archiveFiles []files) (string, error)
 type checkExtractedPath func(t *testing.T, testDataDir string)
 
 func TestUpgrader_unpackTarGz(t *testing.T) {
+	testError := errors.New("test error")
 	type args struct {
 		version          string
 		archiveGenerator createArchiveFunc
 		archiveFiles     []files
 	}
 	tests := []struct {
-		name       string
-		args       args
-		want       UnpackResult
-		wantErr    assert.ErrorAssertionFunc
-		checkFiles checkExtractedPath
+		name          string
+		args          args
+		want          UnpackResult
+		expectedError error
+		checkFiles    checkExtractedPath
+		copy          copyFunc
+		mkdirAll      mkdirAllFunc
+		openFile      openFileFunc
 	}{
 		{
 			name: "file before containing folder",
 			args: args{
 				version:      "1.2.3",
 				archiveFiles: archiveFilesOutOfOrder,
 				archiveGenerator: func(t *testing.T, i []files) (string, error) {
 					return createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", i)
 				},
 			},
 			want: UnpackResult{
 				Hash:          "abcdef",
 				VersionedHome: filepath.Join("data", "elastic-agent-abcdef"),
 			},
-			wantErr: assert.NoError,
+			expectedError: nil,
 			checkFiles: func(t *testing.T, testDataDir string) {
 				versionedHome := filepath.Join(testDataDir, "elastic-agent-abcdef")
 				checkExtractedFilesOutOfOrder(t, versionedHome)
 			},
+			copy:     io.Copy,
+			mkdirAll: os.MkdirAll,
+			openFile: os.OpenFile,
 		},
@@ -175,8 +184,59 @@ func TestUpgrader_unpackTarGz(t *testing.T) {
 		{
 			name: "package with manifest file",
 			args: args{
 				version:      "1.2.3",
 				archiveFiles: append(archiveFilesWithManifestNoSymlink, agentArchiveSymLink),
 				archiveGenerator: func(t *testing.T, i []files) (string, error) {
 					return createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", i)
 				},
 			},
 			want: UnpackResult{
 				Hash:          "abcdef",
 				VersionedHome: filepath.Join("data", "elastic-agent-1.2.3-SNAPSHOT-abcdef"),
 			},
-			wantErr:    assert.NoError,
-			checkFiles: checkExtractedFilesWithManifest,
+			expectedError: nil,
+			checkFiles:    checkExtractedFilesWithManifest,
+			copy:          io.Copy,
+			mkdirAll:      os.MkdirAll,
+			openFile:      os.OpenFile,
+		},
+		{
+			name: "copying file fails",
+			args: args{
+				version:      "1.2.3",
+				archiveFiles: append(archiveFilesWithManifestNoSymlink, agentArchiveSymLink),
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", i)
+				},
+			},
+			expectedError: testError,
+			copy: func(dst io.Writer, src io.Reader) (written int64, err error) {
+				return 0, testError
+			},
+			mkdirAll: os.MkdirAll,
+			openFile: os.OpenFile,
+		},
+		{
+			name: "opening file fails",
+			args: args{
+				version:      "1.2.3",
+				archiveFiles: append(archiveFilesWithManifestNoSymlink, agentArchiveSymLink),
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", i)
+				},
+			},
+			expectedError: testError,
+			openFile: func(name string, flag int, perm os.FileMode) (*os.File, error) {
+				return nil, testError
+			},
+			mkdirAll: os.MkdirAll,
+			copy:     io.Copy,
+		},
+		{
+			name: "creating directory fails",
+			args: args{
+				version:      "1.2.3",
+				archiveFiles: append(archiveFilesWithManifestNoSymlink, agentArchiveSymLink),
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", i)
+				},
+			},
+			expectedError: testError,
+			mkdirAll: func(name string, perm os.FileMode) error {
+				return testError
+			},
+			openFile: os.OpenFile,
+			copy:     io.Copy,
 		},
 	}
 	for _, tt := range tests {
@@ -190,10 +250,12 @@ func TestUpgrader_unpackTarGz(t *testing.T) {
 			archiveFile, err := tt.args.archiveGenerator(t, tt.args.archiveFiles)
 			require.NoError(t, err, "creation of test archive file failed")
 
-			got, err := untar(log, archiveFile, testDataDir)
-			if !tt.wantErr(t, err, fmt.Sprintf("untar(%v, %v, %v)", tt.args.version, archiveFile, testDataDir)) {
+			got, err := untar(log, archiveFile, testDataDir, tt.copy, tt.mkdirAll, tt.openFile)
+			if tt.expectedError != nil {
+				assert.ErrorIsf(t, err, tt.expectedError, "untar(%v, %v, %v)", tt.args.version, archiveFile, testDataDir)
 				return
 			}
+			assert.NoErrorf(t, err, "untar(%v, %v, %v)", tt.args.version, archiveFile, testDataDir)
 			assert.Equalf(t, tt.want, got, "untar(%v, %v, %v)", tt.args.version, archiveFile, testDataDir)
 			if tt.checkFiles != nil {
 				tt.checkFiles(t, testDataDir)
@@ -203,17 +265,21 @@
 	}
 }
 
 func TestUpgrader_unpackZip(t *testing.T) {
+	testError := errors.New("test error")
 	type args struct {
 		archiveGenerator createArchiveFunc
 		archiveFiles     []files
 	}
 	tests := []struct {
-		name       string
-		args       args
-		want       UnpackResult
-		wantErr    assert.ErrorAssertionFunc
-		checkFiles checkExtractedPath
+		name          string
+		args          args
+		want          UnpackResult
+		expectedError error
+		checkFiles    checkExtractedPath
+		copy          copyFunc
+		mkdirAll      mkdirAllFunc
+		openFile      openFileFunc
 	}{
 		{
 			name: "file before containing folder",
 			args: args{
 				archiveGenerator: func(t *testing.T, i []files) (string, error) {
 					return createZipArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.zip", i)
 				},
 				archiveFiles: archiveFilesOutOfOrder,
 			},
 			want: UnpackResult{
 				Hash:          "abcdef",
 				VersionedHome: filepath.Join("data", "elastic-agent-abcdef"),
 			},
-			wantErr: assert.NoError,
+			expectedError: nil,
 			checkFiles: func(t *testing.T, testDataDir string) {
 				versionedHome := filepath.Join(testDataDir, "elastic-agent-abcdef")
 				checkExtractedFilesOutOfOrder(t, versionedHome)
 			},
+			copy:     io.Copy,
+			mkdirAll: os.MkdirAll,
+			openFile: os.OpenFile,
 		},
@@ -245,8 +314,56 @@ func TestUpgrader_unpackZip(t *testing.T) {
 		{
 			name: "package with manifest file",
 			want: UnpackResult{
 				Hash:          "abcdef",
 				VersionedHome: filepath.Join("data", "elastic-agent-1.2.3-SNAPSHOT-abcdef"),
 			},
-			wantErr:    assert.NoError,
-			checkFiles: checkExtractedFilesWithManifest,
+			expectedError: nil,
+			checkFiles:    checkExtractedFilesWithManifest,
+			copy:          io.Copy,
+			mkdirAll:      os.MkdirAll,
+			openFile:      os.OpenFile,
+		},
+		{
+			name: "copying file fails",
+			args: args{
+				archiveFiles: archiveFilesWithManifestNoSymlink,
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createZipArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.zip", i)
+				},
+			},
+			expectedError: testError,
+			copy: func(dst io.Writer, src io.Reader) (written int64, err error) {
+				return 0, testError
+			},
+			mkdirAll: os.MkdirAll,
+			openFile: os.OpenFile,
+		},
+		{
+			name: "opening file fails",
+			args: args{
+				archiveFiles: archiveFilesWithManifestNoSymlink,
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createZipArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.zip", i)
+				},
+			},
+			expectedError: testError,
+			openFile: func(name string, flag int, perm os.FileMode) (*os.File, error) {
+				return nil, testError
+			},
+			mkdirAll: os.MkdirAll,
+			copy:     io.Copy,
+		},
+		{
+			name: "creating directory fails",
+			args: args{
+				archiveFiles: archiveFilesWithManifestNoSymlink,
+				archiveGenerator: func(t *testing.T, i []files) (string, error) {
+					return createZipArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.zip", i)
+				},
+			},
+			expectedError: testError,
+			mkdirAll: func(name string, perm os.FileMode) error {
+				return testError
+			},
+			openFile: os.OpenFile,
+			copy:     io.Copy,
 		},
 	}
 	for _, tt := range tests {
@@ -261,10 +378,12 @@ func TestUpgrader_unpackZip(t *testing.T) {
 			archiveFile, err := tt.args.archiveGenerator(t, tt.args.archiveFiles)
 			require.NoError(t, err, "creation of test archive file failed")
 
-			got, err := unzip(log, archiveFile, testDataDir)
-			if !tt.wantErr(t, err, fmt.Sprintf("unzip(%v, %v)", archiveFile, testDataDir)) {
+			got, err := unzip(log, archiveFile, testDataDir, tt.copy, tt.mkdirAll, tt.openFile)
+			if tt.expectedError != nil {
+				assert.ErrorIs(t, err, tt.expectedError, "error mismatch")
 				return
 			}
+			assert.NoErrorf(t, err, "unzip(%v, %v)", archiveFile, testDataDir)
 			assert.Equalf(t, tt.want, got, "unzip(%v, %v)", archiveFile, testDataDir)
 			if tt.checkFiles != nil {
 				tt.checkFiles(t, testDataDir)
@@ -443,3 +562,49 @@ func TestGetFileNamePrefix(t *testing.T) {
 	}
 }
+
+func TestUnpack(t *testing.T) {
+	log, _ := loggertest.New("TestUnpack")
+
+	unarchiveSetup := func(unpackResult UnpackResult, err error) unarchiveFunc {
+		return func(log *logger.Logger, archivePath, dataDir string, copy copyFunc, mkdirAll mkdirAllFunc, openFile openFileFunc) (UnpackResult, error) {
+			return unpackResult, err
+		}
+	}
+
+	type testCase struct {
+		expectedUnpackResult UnpackResult
+		expectedErr          error
+		unarchiveFunc        unarchiveFunc
+	}
+
+	testCases := map[string]testCase{
+		"when unarchiving succeeds it should return the unpack result": {
+			expectedUnpackResult: UnpackResult{
+				Hash:          "abcdef",
+				VersionedHome: filepath.Join("data", "elastic-agent-abcdef"),
+			},
+			expectedErr: nil,
+			unarchiveFunc: unarchiveSetup(UnpackResult{
+				Hash:          "abcdef",
+				VersionedHome: filepath.Join("data", "elastic-agent-abcdef"),
+			}, nil),
+		},
+		"when unarchiving fails it should return an error": {
+			expectedUnpackResult: UnpackResult{},
+			expectedErr:          errors.New("unarchiving failed"),
+			unarchiveFunc:        unarchiveSetup(UnpackResult{}, errors.New("unarchiving failed")),
+		},
+	}
+
+	for name, test := range testCases {
+		t.Run(name, func(t *testing.T) {
+			unpacker := newUnpacker(log)
+			unpacker.untar = test.unarchiveFunc
+			unpacker.unzip = test.unarchiveFunc
+			unpackResult, unpackErr := unpacker.unpack("mockVersion", "mockArchivePath", "mockDataDir")
+			assert.Equal(t, test.expectedUnpackResult, unpackResult)
+			assert.Equal(t, test.expectedErr, unpackErr)
+		})
+	}
+}
diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go
index d933d069ddc..c8587094c62 100644
--- a/internal/pkg/agent/application/upgrade/upgrade.go
+++ b/internal/pkg/agent/application/upgrade/upgrade.go
@@ -73,6 +73,10 @@ type artifactDownloadHandler interface {
 	downloadArtifact(ctx context.Context, parsedVersion *agtversion.ParsedSemVer, sourceURI string, upgradeDetails *details.Details, skipVerifyOverride, skipDefaultPgp bool, pgpBytes ...string) (_ string, err error)
 	withFleetServerURI(fleetServerURI string)
 }
+type unpackHandler interface {
+	unpack(version, archivePath, dataDir string) (UnpackResult, error)
+	getPackageMetadata(archivePath string) (packageMetadata, error)
+}
 
 // Upgrader performs an upgrade
 type Upgrader struct {
@@ -85,7 +89,9 @@ type Upgrader struct {
 
 	// The following are abstractions for testability
 	artifactDownloader   artifactDownloadHandler
+	unpacker             unpackHandler
 	isDiskSpaceErrorFunc func(err error) bool
+	extractAgentVersion  func(metadata packageMetadata, upgradeVersion string) agentVersion
 }
 
 // IsUpgradeable when agent is installed and running as a service or flag was provided.
@@ -104,7 +110,9 @@ func NewUpgrader(log *logger.Logger, settings *artifact.Config, agentInfo info.A
 		upgradeable:          IsUpgradeable(),
 		markerWatcher:        newMarkerFileWatcher(markerFilePath(paths.Data()), log),
 		artifactDownloader:   newArtifactDownloader(settings, log),
+		unpacker:             newUnpacker(log),
 		isDiskSpaceErrorFunc: upgradeErrors.IsDiskSpaceError,
+		extractAgentVersion:  extractAgentVersion,
 	}, nil
 }
 
@@ -274,20 +282,19 @@ func (u *Upgrader) Upgrade(ctx context.Context, version string, sourceURI string
 
 	det.SetState(details.StateExtracting)
 
-	metadata, err := u.getPackageMetadata(archivePath)
+	metadata, err := u.unpacker.getPackageMetadata(archivePath)
 	if err != nil {
 		return nil, fmt.Errorf("reading metadata for elastic agent version %s package %q: %w", version, archivePath, err)
 	}
 
-	newVersion := extractAgentVersion(metadata, version)
+	newVersion := u.extractAgentVersion(metadata, version)
 	if err := checkUpgrade(u.log, currentVersion, newVersion, metadata); err != nil {
 		return nil, fmt.Errorf("cannot upgrade the agent: %w", err)
 	}
 
 	u.log.Infow("Unpacking agent package", "version", newVersion)
 
-	// Nice to have: add check that no archive files end up in the current versioned home
-	unpackRes, err := u.unpack(version, archivePath, paths.Data())
+	unpackRes, err := u.unpacker.unpack(version, archivePath, paths.Data())
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/pkg/agent/application/upgrade/upgrade_test.go b/internal/pkg/agent/application/upgrade/upgrade_test.go
index 4af120d211b..3c55bc52f95 100644
--- a/internal/pkg/agent/application/upgrade/upgrade_test.go
+++ b/internal/pkg/agent/application/upgrade/upgrade_test.go
@@ -1311,38 +1311,94 @@ func (m *mockArtifactDownloader) withFleetServerURI(fleetServerURI string) {
 	m.fleetServerURI = fleetServerURI
 }
 
+type mockUnpacker struct {
+	returnPackageMetadata      packageMetadata
+	returnPackageMetadataError error
+	returnUnpackResult         UnpackResult
+	returnUnpackError          error
+}
+
+func (m *mockUnpacker) getPackageMetadata(archivePath string) (packageMetadata, error) {
+	return m.returnPackageMetadata, m.returnPackageMetadataError
+}
+
+func (m *mockUnpacker) unpack(version, archivePath, dataDir string) (UnpackResult, error) {
+	return m.returnUnpackResult, m.returnUnpackError
+}
+
 func TestUpgradeErrorHandling(t *testing.T) {
 	log, _ := loggertest.New("test")
 	testError := errors.New("test error")
+	type upgraderMocker func(upgrader *Upgrader)
 
 	type testCase struct {
 		isDiskSpaceErrorResult bool
 		expectedError          error
+		upgraderMocker         upgraderMocker
 	}
 
 	testCases := map[string]testCase{
 		"should return error if downloadArtifact fails": {
 			isDiskSpaceErrorResult: false,
 			expectedError:          testError,
+			upgraderMocker: func(upgrader *Upgrader) {
+				upgrader.artifactDownloader = &mockArtifactDownloader{
+					returnError: testError,
+				}
+			},
+		},
+		"should return error if getPackageMetadata fails": {
+			isDiskSpaceErrorResult: false,
+			expectedError:          testError,
+			upgraderMocker: func(upgrader *Upgrader) {
+				upgrader.artifactDownloader = &mockArtifactDownloader{}
+				upgrader.unpacker = &mockUnpacker{
+					returnPackageMetadataError: testError,
+				}
+			},
+		},
+		"should return error if unpack fails": {
+			isDiskSpaceErrorResult: false,
+			expectedError:          testError,
+			upgraderMocker: func(upgrader *Upgrader) {
+				upgrader.artifactDownloader = &mockArtifactDownloader{}
+				upgrader.extractAgentVersion = func(metadata packageMetadata, upgradeVersion string) agentVersion {
+					return agentVersion{
+						version:  upgradeVersion,
+						snapshot: false,
+						hash:     metadata.hash,
+					}
+				}
+				upgrader.unpacker = &mockUnpacker{
+					returnPackageMetadata: packageMetadata{
+						manifest: &v1.PackageManifest{},
+						hash:     "hash",
+					},
+					returnUnpackError: testError,
+				}
+			},
 		},
 		"should add disk space error to the error chain if downloadArtifact fails with disk space error": {
 			isDiskSpaceErrorResult: true,
 			expectedError:          upgradeErrors.ErrInsufficientDiskSpace,
+			upgraderMocker: func(upgrader *Upgrader) {
+				upgrader.artifactDownloader = &mockArtifactDownloader{
+					returnError: upgradeErrors.ErrInsufficientDiskSpace,
+				}
+			},
 		},
 	}
 
 	mockAgentInfo := info.NewAgent(t)
 	mockAgentInfo.On("Version").Return("9.0.0")
 
-	upgrader, err := NewUpgrader(log, &artifact.Config{}, mockAgentInfo)
-	require.NoError(t, err)
-
-	upgrader.artifactDownloader = &mockArtifactDownloader{
-		returnError: testError,
-	}
-
 	for name, tc := range testCases {
 		t.Run(name, func(t *testing.T) {
+			upgrader, err := NewUpgrader(log, &artifact.Config{}, mockAgentInfo)
+			require.NoError(t, err)
+
+			tc.upgraderMocker(upgrader)
+
 			upgrader.isDiskSpaceErrorFunc = func(err error) bool {
 				return tc.isDiskSpaceErrorResult
 			}
@@ -1362,7 +1418,6 @@ func (m *mockSender) Send(ctx context.Context, method, path string, params url.V
 
 func (m *mockSender) URI() string {
 	return "mockURI"
 }
-
 func TestSetClient(t *testing.T) {
 	log, _ := loggertest.New("test")
 	upgrader := &Upgrader{
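
For illustration only (not part of the change set above): the injected copy/mkdirAll/openFile hooks let a test force a specific filesystem failure without patching anything globally. A minimal sketch, assuming it lives in the same upgrade test package and reuses the existing helpers createTarArchive and archiveFilesWithManifestNoSymlink from step_unpack_test.go; the test name is hypothetical.

// Illustrative sketch, not part of the diff. Assumes the same test package
// and imports as step_unpack_test.go (errors, io, os, testing, assert,
// require, loggertest).
func TestUntar_OpenFileFailure_Sketch(t *testing.T) {
	log, _ := loggertest.New("sketch")
	openErr := errors.New("boom")

	// Real io.Copy and os.MkdirAll, but an openFile stub that always fails,
	// which forces untar down its file-creation error path.
	failingOpenFile := func(name string, flag int, perm os.FileMode) (*os.File, error) {
		return nil, openErr
	}

	archive, err := createTarArchive(t, "elastic-agent-1.2.3-SNAPSHOT-someos-x86_64.tar.gz", archiveFilesWithManifestNoSymlink)
	require.NoError(t, err, "creation of test archive file failed")

	// Because untar wraps the failure with goerrors.Join, errors.Is still
	// finds the injected error in the chain.
	_, err = untar(log, archive, t.TempDir(), io.Copy, os.MkdirAll, failingOpenFile)
	assert.ErrorIs(t, err, openErr)
}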