summaryrefslogtreecommitdiffstats
path: root/misc/reboot/overlaydir_test.go
blob: 71faf0936ba19af6e961f158deb8ffd488b5fb28 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package reboot_test

import (
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
)

// overlayDir makes a minimal-overhead copy of srcRoot in which new files may be added.
//
// Every directory under srcRoot is recreated (never linked) under dstRoot, so
// files added to the overlay do not show up in srcRoot. Non-directory entries
// are hard-linked into dstRoot when os.Link succeeds and copied byte-for-byte
// otherwise. Entries named "testdata" are not copied: a directory with that
// name is skipped entirely, and a non-directory with that name is skipped on
// its own without affecting its siblings. dstRoot is created if needed.
//
// TODO: Once we no longer need to support the misc module in GOPATH mode,
// factor this function out into a package to reduce duplication.
func overlayDir(dstRoot, srcRoot string) error {
	dstRoot = filepath.Clean(dstRoot)
	if err := os.MkdirAll(dstRoot, 0777); err != nil {
		return err
	}

	srcRoot, err := filepath.Abs(srcRoot)
	if err != nil {
		return err
	}

	return filepath.WalkDir(srcRoot, func(srcPath string, entry fs.DirEntry, err error) error {
		if err != nil || srcPath == srcRoot {
			return err
		}
		if filepath.Base(srcPath) == "testdata" {
			// We're just building, so no need to copy those.
			// Only return fs.SkipDir for directories: when returned for a
			// non-directory, WalkDir skips all remaining entries of the
			// containing directory, which would drop unrelated siblings.
			if entry.IsDir() {
				return fs.SkipDir
			}
			return nil
		}

		// Path of srcPath relative to srcRoot, with leading separators removed.
		suffix := strings.TrimPrefix(srcPath, srcRoot)
		for len(suffix) > 0 && suffix[0] == filepath.Separator {
			suffix = suffix[1:]
		}
		dstPath := filepath.Join(dstRoot, suffix)

		// entry.Info can fail (for example if the entry disappeared after its
		// directory was read); info would then be nil, so report the error
		// instead of dereferencing it.
		info, err := entry.Info()
		if err != nil {
			return err
		}
		perm := info.Mode() & os.ModePerm
		if info.Mode()&os.ModeSymlink != 0 {
			// For a symlink, the target's mode decides whether a directory is
			// created below and which permission bits are used.
			info, err = os.Stat(srcPath)
			if err != nil {
				return err
			}
			perm = info.Mode() & os.ModePerm
		}

		// Always make copies of directories.
		// If we add a file in the overlay, we don't want to add it in the original.
		if info.IsDir() {
			// Add owner-write so new files can be created inside the overlay dir.
			return os.MkdirAll(dstPath, perm|0200)
		}

		// If we can use a hard link, do that instead of copying bytes.
		// Go builds don't like symlinks in some cases, such as go:embed.
		if err := os.Link(srcPath, dstPath); err == nil {
			return nil
		}

		// Otherwise, copy the bytes.
		src, err := os.Open(srcPath)
		if err != nil {
			return err
		}
		defer src.Close()

		dst, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, perm)
		if err != nil {
			return err
		}

		_, err = io.Copy(dst, src)
		// Report a Close failure only if the copy itself succeeded.
		if closeErr := dst.Close(); err == nil {
			err = closeErr
		}
		return err
	})
}