diff --git a/pkg/downloader/downloader.go b/pkg/downloader/downloader.go index 41d48e39a88..723c7ea702a 100644 --- a/pkg/downloader/downloader.go +++ b/pkg/downloader/downloader.go @@ -1,12 +1,15 @@ package downloader import ( + "bytes" "crypto/sha256" "errors" "fmt" "io" "net/http" "os" + "os/exec" + "path" "path/filepath" "strings" "time" @@ -38,6 +41,7 @@ type Result struct { type options struct { cacheDir string // default: empty (disables caching) + decompress bool // default: false (keep compression) description string // default: url expectedDigest digest.Digest } @@ -73,6 +77,14 @@ func WithDescription(description string) Opt { } } +// WithDecompress decompress the download from the cache. +func WithDecompress(decompress bool) Opt { + return func(o *options) error { + o.decompress = decompress + return nil + } +} + // WithExpectedDigest is used to validate the downloaded file against the expected digest. // // The digest is not verified in the following cases: @@ -142,8 +154,9 @@ func Download(local, remote string, opts ...Opt) (*Result, error) { } } + ext := path.Ext(remote) if IsLocal(remote) { - if err := copyLocal(localPath, remote, o.description, o.expectedDigest); err != nil { + if err := copyLocal(localPath, remote, ext, o.decompress, o.description, o.expectedDigest); err != nil { return nil, err } res := &Result{ @@ -183,11 +196,11 @@ func Download(local, remote string, opts ...Opt) (*Result, error) { if o.expectedDigest.String() != shadDigestS { return nil, fmt.Errorf("expected digest %q does not match the cached digest %q", o.expectedDigest.String(), shadDigestS) } - if err := copyLocal(localPath, shadData, "", ""); err != nil { + if err := copyLocal(localPath, shadData, ext, o.decompress, "", ""); err != nil { return nil, err } } else { - if err := copyLocal(localPath, shadData, o.description, o.expectedDigest); err != nil { + if err := copyLocal(localPath, shadData, ext, o.decompress, o.description, o.expectedDigest); err != nil { return nil, err } } @@ -212,7 +225,7 @@ func Download(local, remote string, opts ...Opt) (*Result, error) { return nil, err } // no need to pass the digest to copyLocal(), as we already verified the digest - if err := copyLocal(localPath, shadData, "", ""); err != nil { + if err := copyLocal(localPath, shadData, ext, o.decompress, "", ""); err != nil { return nil, err } if shadDigest != "" && o.expectedDigest != "" { @@ -253,7 +266,7 @@ func canonicalLocalPath(s string) (string, error) { return localpathutil.Expand(s) } -func copyLocal(dst, src string, description string, expectedDigest digest.Digest) error { +func copyLocal(dst, src, ext string, decompress bool, description string, expectedDigest digest.Digest) error { srcPath, err := canonicalLocalPath(src) if err != nil { return err @@ -274,9 +287,60 @@ func copyLocal(dst, src string, description string, expectedDigest digest.Digest if description != "" { // TODO: progress bar for copy } + if _, ok := Decompressor(ext); ok && decompress { + return decompressLocal(dstPath, srcPath, ext) + } return fs.CopyFile(dstPath, srcPath) } +func Decompressor(ext string) ([]string, bool) { + var program string + switch ext { + case ".gz": + program = "gzip" + case ".bz2": + program = "bzip2" + case ".xz": + program = "xz" + case ".zst": + program = "zstd" + default: + return nil, false + } + // -d --decompress + return []string{program, "-d"}, true +} + +func decompressLocal(dst, src, ext string) error { + command, found := Decompressor(ext) + if !found { + return fmt.Errorf("decompressLocal: unknown extension %s", ext) + } + logrus.Infof("decompressing %s with %v", ext, command) + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer out.Close() + buf := new(bytes.Buffer) + cmd := exec.Command(command[0], command[1:]...) + cmd.Stdin = in + cmd.Stdout = out + cmd.Stderr = buf + err = cmd.Run() + if err != nil { + if ee, ok := err.(*exec.ExitError); ok { + ee.Stderr = buf.Bytes() + } + } + return err +} + func validateLocalFileDigest(localPath string, expectedDigest digest.Digest) error { if localPath == "" { return fmt.Errorf("validateLocalFileDigest: got empty localPath") diff --git a/pkg/downloader/downloader_test.go b/pkg/downloader/downloader_test.go index 93785861126..1de5cbb89ab 100644 --- a/pkg/downloader/downloader_test.go +++ b/pkg/downloader/downloader_test.go @@ -3,6 +3,7 @@ package downloader import ( "io/ioutil" "os" + "os/exec" "path/filepath" "runtime" "testing" @@ -130,3 +131,29 @@ func TestDownloadLocal(t *testing.T) { }) } + +func TestDownloadCompressed(t *testing.T) { + + if runtime.GOOS == "windows" { + // FIXME: `assertion failed: error is not nil: exec: "gzip": executable file not found in %PATH%` + t.Skip("Skipping on windows") + } + + t.Run("gzip", func(t *testing.T) { + localPath := filepath.Join(t.TempDir(), t.Name()) + localFile := filepath.Join(t.TempDir(), "test-file") + testDownloadCompressedContents := []byte("TestDownloadCompressed") + ioutil.WriteFile(localFile, testDownloadCompressedContents, 0644) + assert.NilError(t, exec.Command("gzip", localFile).Run()) + localFile += ".gz" + testLocalFileURL := "file://" + localFile + + r, err := Download(localPath, testLocalFileURL, WithDecompress(true)) + assert.NilError(t, err) + assert.Equal(t, StatusDownloaded, r.Status) + + got, err := os.ReadFile(localPath) + assert.NilError(t, err) + assert.Equal(t, string(got), string(testDownloadCompressedContents)) + }) +} diff --git a/pkg/fileutils/download.go b/pkg/fileutils/download.go index 76aaa09b1b0..a973dc02de8 100644 --- a/pkg/fileutils/download.go +++ b/pkg/fileutils/download.go @@ -10,7 +10,7 @@ import ( ) // DownloadFile downloads a file to the cache, optionally copying it to the destination. Returns path in cache. -func DownloadFile(dest string, f limayaml.File, description string, expectedArch limayaml.Arch) (string, error) { +func DownloadFile(dest string, f limayaml.File, decompress bool, description string, expectedArch limayaml.Arch) (string, error) { if f.Arch != expectedArch { return "", fmt.Errorf("unsupported arch: %q", f.Arch) } @@ -18,6 +18,7 @@ func DownloadFile(dest string, f limayaml.File, description string, expectedArch logrus.WithFields(fields).Infof("Attempting to download %s", description) res, err := downloader.Download(dest, f.Location, downloader.WithCache(), + downloader.WithDecompress(decompress), downloader.WithDescription(fmt.Sprintf("%s (%s)", description, path.Base(f.Location))), downloader.WithExpectedDigest(f.Digest), ) diff --git a/pkg/qemu/qemu.go b/pkg/qemu/qemu.go index 57d2a903176..c6192bb4d54 100644 --- a/pkg/qemu/qemu.go +++ b/pkg/qemu/qemu.go @@ -50,12 +50,12 @@ func EnsureDisk(cfg Config) error { var ensuredBaseDisk bool errs := make([]error, len(cfg.LimaYAML.Images)) for i, f := range cfg.LimaYAML.Images { - if _, err := fileutils.DownloadFile(baseDisk, f.File, "the image", *cfg.LimaYAML.Arch); err != nil { + if _, err := fileutils.DownloadFile(baseDisk, f.File, true, "the image", *cfg.LimaYAML.Arch); err != nil { errs[i] = err continue } if f.Kernel != nil { - if _, err := fileutils.DownloadFile(kernel, f.Kernel.File, "the kernel", *cfg.LimaYAML.Arch); err != nil { + if _, err := fileutils.DownloadFile(kernel, f.Kernel.File, false, "the kernel", *cfg.LimaYAML.Arch); err != nil { errs[i] = err continue } @@ -67,7 +67,7 @@ func EnsureDisk(cfg Config) error { } } if f.Initrd != nil { - if _, err := fileutils.DownloadFile(initrd, *f.Initrd, "the initrd", *cfg.LimaYAML.Arch); err != nil { + if _, err := fileutils.DownloadFile(initrd, *f.Initrd, false, "the initrd", *cfg.LimaYAML.Arch); err != nil { errs[i] = err continue } diff --git a/pkg/start/start.go b/pkg/start/start.go index 8e504ea1ac8..70e8c377057 100644 --- a/pkg/start/start.go +++ b/pkg/start/start.go @@ -39,7 +39,7 @@ func ensureNerdctlArchiveCache(y *limayaml.LimaYAML) (string, error) { errs := make([]error, len(y.Containerd.Archives)) for i, f := range y.Containerd.Archives { - path, err := fileutils.DownloadFile("", f, "the nerdctl archive", *y.Arch) + path, err := fileutils.DownloadFile("", f, false, "the nerdctl archive", *y.Arch) if err != nil { errs[i] = err continue diff --git a/pkg/vz/disk.go b/pkg/vz/disk.go index 58bbc92a2fd..1513c0558a0 100644 --- a/pkg/vz/disk.go +++ b/pkg/vz/disk.go @@ -29,7 +29,7 @@ func EnsureDisk(driver *driver.BaseDriver) error { var ensuredBaseDisk bool errs := make([]error, len(driver.Yaml.Images)) for i, f := range driver.Yaml.Images { - if _, err := fileutils.DownloadFile(baseDisk, f.File, "the image", *driver.Yaml.Arch); err != nil { + if _, err := fileutils.DownloadFile(baseDisk, f.File, true, "the image", *driver.Yaml.Arch); err != nil { errs[i] = err continue }