diff options
Diffstat (limited to 'go-selinux/selinux_linux.go')
-rw-r--r-- | go-selinux/selinux_linux.go | 1295 |
1 files changed, 1295 insertions, 0 deletions
diff --git a/go-selinux/selinux_linux.go b/go-selinux/selinux_linux.go new file mode 100644 index 0000000..f1e9597 --- /dev/null +++ b/go-selinux/selinux_linux.go @@ -0,0 +1,1295 @@ +package selinux + +import ( + "bufio" + "bytes" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "io/fs" + "math/big" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "sync" + + "github.com/opencontainers/selinux/pkg/pwalkdir" + "golang.org/x/sys/unix" +) + +const ( + minSensLen = 2 + contextFile = "/usr/share/containers/selinux/contexts" + selinuxDir = "/etc/selinux/" + selinuxUsersDir = "contexts/users" + defaultContexts = "contexts/default_contexts" + selinuxConfig = selinuxDir + "config" + selinuxfsMount = "/sys/fs/selinux" + selinuxTypeTag = "SELINUXTYPE" + selinuxTag = "SELINUX" + xattrNameSelinux = "security.selinux" +) + +type selinuxState struct { + mcsList map[string]bool + selinuxfs string + selinuxfsOnce sync.Once + enabledSet bool + enabled bool + sync.Mutex +} + +type level struct { + cats *big.Int + sens uint +} + +type mlsRange struct { + low *level + high *level +} + +type defaultSECtx struct { + userRdr io.Reader + verifier func(string) error + defaultRdr io.Reader + user, level, scon string +} + +type levelItem byte + +const ( + sensitivity levelItem = 's' + category levelItem = 'c' +) + +var ( + readOnlyFileLabel string + state = selinuxState{ + mcsList: make(map[string]bool), + } + + // for attrPath() + attrPathOnce sync.Once + haveThreadSelf bool + + // for policyRoot() + policyRootOnce sync.Once + policyRootVal string + + // for label() + loadLabelsOnce sync.Once + labels map[string]string +) + +func policyRoot() string { + policyRootOnce.Do(func() { + policyRootVal = filepath.Join(selinuxDir, readConfig(selinuxTypeTag)) + }) + + return policyRootVal +} + +func (s *selinuxState) setEnable(enabled bool) bool { + s.Lock() + defer s.Unlock() + s.enabledSet = true + s.enabled = enabled + return s.enabled +} + +func (s *selinuxState) getEnabled() bool { + s.Lock() + enabled := s.enabled + enabledSet := s.enabledSet + s.Unlock() + if enabledSet { + return enabled + } + + enabled = false + if fs := getSelinuxMountPoint(); fs != "" { + if con, _ := CurrentLabel(); con != "kernel" { + enabled = true + } + } + return s.setEnable(enabled) +} + +// setDisabled disables SELinux support for the package +func setDisabled() { + state.setEnable(false) +} + +func verifySELinuxfsMount(mnt string) bool { + var buf unix.Statfs_t + for { + err := unix.Statfs(mnt, &buf) + if err == nil { + break + } + if err == unix.EAGAIN || err == unix.EINTR { //nolint:errorlint // unix errors are bare + continue + } + return false + } + + if uint32(buf.Type) != uint32(unix.SELINUX_MAGIC) { + return false + } + if (buf.Flags & unix.ST_RDONLY) != 0 { + return false + } + + return true +} + +func findSELinuxfs() string { + // fast path: check the default mount first + if verifySELinuxfsMount(selinuxfsMount) { + return selinuxfsMount + } + + // check if selinuxfs is available before going the slow path + fs, err := os.ReadFile("/proc/filesystems") + if err != nil { + return "" + } + if !bytes.Contains(fs, []byte("\tselinuxfs\n")) { + return "" + } + + // slow path: try to find among the mounts + f, err := os.Open("/proc/self/mountinfo") + if err != nil { + return "" + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for { + mnt := findSELinuxfsMount(scanner) + if mnt == "" { // error or not found + return "" + } + if verifySELinuxfsMount(mnt) { + return mnt + } + } +} + +// findSELinuxfsMount returns a next selinuxfs mount point found, +// if there is one, or an empty string in case of EOF or error. +func findSELinuxfsMount(s *bufio.Scanner) string { + for s.Scan() { + txt := s.Bytes() + // The first field after - is fs type. + // Safe as spaces in mountpoints are encoded as \040 + if !bytes.Contains(txt, []byte(" - selinuxfs ")) { + continue + } + const mPos = 5 // mount point is 5th field + fields := bytes.SplitN(txt, []byte(" "), mPos+1) + if len(fields) < mPos+1 { + continue + } + return string(fields[mPos-1]) + } + + return "" +} + +func (s *selinuxState) getSELinuxfs() string { + s.selinuxfsOnce.Do(func() { + s.selinuxfs = findSELinuxfs() + }) + + return s.selinuxfs +} + +// getSelinuxMountPoint returns the path to the mountpoint of an selinuxfs +// filesystem or an empty string if no mountpoint is found. Selinuxfs is +// a proc-like pseudo-filesystem that exposes the SELinux policy API to +// processes. The existence of an selinuxfs mount is used to determine +// whether SELinux is currently enabled or not. +func getSelinuxMountPoint() string { + return state.getSELinuxfs() +} + +// getEnabled returns whether SELinux is currently enabled. +func getEnabled() bool { + return state.getEnabled() +} + +func readConfig(target string) string { + in, err := os.Open(selinuxConfig) + if err != nil { + return "" + } + defer in.Close() + + scanner := bufio.NewScanner(in) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + if len(line) == 0 { + // Skip blank lines + continue + } + if line[0] == ';' || line[0] == '#' { + // Skip comments + continue + } + fields := bytes.SplitN(line, []byte{'='}, 2) + if len(fields) != 2 { + continue + } + if bytes.Equal(fields[0], []byte(target)) { + return string(bytes.Trim(fields[1], `"`)) + } + } + return "" +} + +func isProcHandle(fh *os.File) error { + var buf unix.Statfs_t + + for { + err := unix.Fstatfs(int(fh.Fd()), &buf) + if err == nil { + break + } + if err != unix.EINTR { //nolint:errorlint // unix errors are bare + return &os.PathError{Op: "fstatfs", Path: fh.Name(), Err: err} + } + } + if buf.Type != unix.PROC_SUPER_MAGIC { + return fmt.Errorf("file %q is not on procfs", fh.Name()) + } + + return nil +} + +func readCon(fpath string) (string, error) { + if fpath == "" { + return "", ErrEmptyPath + } + + in, err := os.Open(fpath) + if err != nil { + return "", err + } + defer in.Close() + + if err := isProcHandle(in); err != nil { + return "", err + } + return readConFd(in) +} + +func readConFd(in *os.File) (string, error) { + data, err := io.ReadAll(in) + if err != nil { + return "", err + } + return string(bytes.TrimSuffix(data, []byte{0})), nil +} + +// classIndex returns the int index for an object class in the loaded policy, +// or -1 and an error +func classIndex(class string) (int, error) { + permpath := fmt.Sprintf("class/%s/index", class) + indexpath := filepath.Join(getSelinuxMountPoint(), permpath) + + indexB, err := os.ReadFile(indexpath) + if err != nil { + return -1, err + } + index, err := strconv.Atoi(string(indexB)) + if err != nil { + return -1, err + } + + return index, nil +} + +// lSetFileLabel sets the SELinux label for this path, not following symlinks, +// or returns an error. +func lSetFileLabel(fpath string, label string) error { + if fpath == "" { + return ErrEmptyPath + } + for { + err := unix.Lsetxattr(fpath, xattrNameSelinux, []byte(label), 0) + if err == nil { + break + } + if err != unix.EINTR { //nolint:errorlint // unix errors are bare + return &os.PathError{Op: "lsetxattr", Path: fpath, Err: err} + } + } + + return nil +} + +// setFileLabel sets the SELinux label for this path, following symlinks, +// or returns an error. +func setFileLabel(fpath string, label string) error { + if fpath == "" { + return ErrEmptyPath + } + for { + err := unix.Setxattr(fpath, xattrNameSelinux, []byte(label), 0) + if err == nil { + break + } + if err != unix.EINTR { //nolint:errorlint // unix errors are bare + return &os.PathError{Op: "setxattr", Path: fpath, Err: err} + } + } + + return nil +} + +// fileLabel returns the SELinux label for this path, following symlinks, +// or returns an error. +func fileLabel(fpath string) (string, error) { + if fpath == "" { + return "", ErrEmptyPath + } + + label, err := getxattr(fpath, xattrNameSelinux) + if err != nil { + return "", &os.PathError{Op: "getxattr", Path: fpath, Err: err} + } + // Trim the NUL byte at the end of the byte buffer, if present. + if len(label) > 0 && label[len(label)-1] == '\x00' { + label = label[:len(label)-1] + } + return string(label), nil +} + +// lFileLabel returns the SELinux label for this path, not following symlinks, +// or returns an error. +func lFileLabel(fpath string) (string, error) { + if fpath == "" { + return "", ErrEmptyPath + } + + label, err := lgetxattr(fpath, xattrNameSelinux) + if err != nil { + return "", &os.PathError{Op: "lgetxattr", Path: fpath, Err: err} + } + // Trim the NUL byte at the end of the byte buffer, if present. + if len(label) > 0 && label[len(label)-1] == '\x00' { + label = label[:len(label)-1] + } + return string(label), nil +} + +func setFSCreateLabel(label string) error { + return writeCon(attrPath("fscreate"), label) +} + +// fsCreateLabel returns the default label the kernel which the kernel is using +// for file system objects created by this task. "" indicates default. +func fsCreateLabel() (string, error) { + return readCon(attrPath("fscreate")) +} + +// currentLabel returns the SELinux label of the current process thread, or an error. +func currentLabel() (string, error) { + return readCon(attrPath("current")) +} + +// pidLabel returns the SELinux label of the given pid, or an error. +func pidLabel(pid int) (string, error) { + return readCon(fmt.Sprintf("/proc/%d/attr/current", pid)) +} + +// ExecLabel returns the SELinux label that the kernel will use for any programs +// that are executed by the current process thread, or an error. +func execLabel() (string, error) { + return readCon(attrPath("exec")) +} + +func writeCon(fpath, val string) error { + if fpath == "" { + return ErrEmptyPath + } + if val == "" { + if !getEnabled() { + return nil + } + } + + out, err := os.OpenFile(fpath, os.O_WRONLY, 0) + if err != nil { + return err + } + defer out.Close() + + if err := isProcHandle(out); err != nil { + return err + } + + if val != "" { + _, err = out.Write([]byte(val)) + } else { + _, err = out.Write(nil) + } + if err != nil { + return err + } + return nil +} + +func attrPath(attr string) string { + // Linux >= 3.17 provides this + const threadSelfPrefix = "/proc/thread-self/attr" + + attrPathOnce.Do(func() { + st, err := os.Stat(threadSelfPrefix) + if err == nil && st.Mode().IsDir() { + haveThreadSelf = true + } + }) + + if haveThreadSelf { + return filepath.Join(threadSelfPrefix, attr) + } + + return filepath.Join("/proc/self/task", strconv.Itoa(unix.Gettid()), "attr", attr) +} + +// canonicalizeContext takes a context string and writes it to the kernel +// the function then returns the context that the kernel will use. Use this +// function to check if two contexts are equivalent +func canonicalizeContext(val string) (string, error) { + return readWriteCon(filepath.Join(getSelinuxMountPoint(), "context"), val) +} + +// computeCreateContext requests the type transition from source to target for +// class from the kernel. +func computeCreateContext(source string, target string, class string) (string, error) { + classidx, err := classIndex(class) + if err != nil { + return "", err + } + + return readWriteCon(filepath.Join(getSelinuxMountPoint(), "create"), fmt.Sprintf("%s %s %d", source, target, classidx)) +} + +// catsToBitset stores categories in a bitset. +func catsToBitset(cats string) (*big.Int, error) { + bitset := new(big.Int) + + catlist := strings.Split(cats, ",") + for _, r := range catlist { + ranges := strings.SplitN(r, ".", 2) + if len(ranges) > 1 { + catstart, err := parseLevelItem(ranges[0], category) + if err != nil { + return nil, err + } + catend, err := parseLevelItem(ranges[1], category) + if err != nil { + return nil, err + } + for i := catstart; i <= catend; i++ { + bitset.SetBit(bitset, int(i), 1) + } + } else { + cat, err := parseLevelItem(ranges[0], category) + if err != nil { + return nil, err + } + bitset.SetBit(bitset, int(cat), 1) + } + } + + return bitset, nil +} + +// parseLevelItem parses and verifies that a sensitivity or category are valid +func parseLevelItem(s string, sep levelItem) (uint, error) { + if len(s) < minSensLen || levelItem(s[0]) != sep { + return 0, ErrLevelSyntax + } + val, err := strconv.ParseUint(s[1:], 10, 32) + if err != nil { + return 0, err + } + + return uint(val), nil +} + +// parseLevel fills a level from a string that contains +// a sensitivity and categories +func (l *level) parseLevel(levelStr string) error { + lvl := strings.SplitN(levelStr, ":", 2) + sens, err := parseLevelItem(lvl[0], sensitivity) + if err != nil { + return fmt.Errorf("failed to parse sensitivity: %w", err) + } + l.sens = sens + if len(lvl) > 1 { + cats, err := catsToBitset(lvl[1]) + if err != nil { + return fmt.Errorf("failed to parse categories: %w", err) + } + l.cats = cats + } + + return nil +} + +// rangeStrToMLSRange marshals a string representation of a range. +func rangeStrToMLSRange(rangeStr string) (*mlsRange, error) { + r := &mlsRange{} + l := strings.SplitN(rangeStr, "-", 2) + + switch len(l) { + // rangeStr that has a low and a high level, e.g. s4:c0.c1023-s6:c0.c1023 + case 2: + r.high = &level{} + if err := r.high.parseLevel(l[1]); err != nil { + return nil, fmt.Errorf("failed to parse high level %q: %w", l[1], err) + } + fallthrough + // rangeStr that is single level, e.g. s6:c0,c3,c5,c30.c1023 + case 1: + r.low = &level{} + if err := r.low.parseLevel(l[0]); err != nil { + return nil, fmt.Errorf("failed to parse low level %q: %w", l[0], err) + } + } + + if r.high == nil { + r.high = r.low + } + + return r, nil +} + +// bitsetToStr takes a category bitset and returns it in the +// canonical selinux syntax +func bitsetToStr(c *big.Int) string { + var str string + + length := 0 + for i := int(c.TrailingZeroBits()); i < c.BitLen(); i++ { + if c.Bit(i) == 0 { + continue + } + if length == 0 { + if str != "" { + str += "," + } + str += "c" + strconv.Itoa(i) + } + if c.Bit(i+1) == 1 { + length++ + continue + } + if length == 1 { + str += ",c" + strconv.Itoa(i) + } else if length > 1 { + str += ".c" + strconv.Itoa(i) + } + length = 0 + } + + return str +} + +func (l *level) equal(l2 *level) bool { + if l2 == nil || l == nil { + return l == l2 + } + if l2.sens != l.sens { + return false + } + if l2.cats == nil || l.cats == nil { + return l2.cats == l.cats + } + return l.cats.Cmp(l2.cats) == 0 +} + +// String returns an mlsRange as a string. +func (m mlsRange) String() string { + low := "s" + strconv.Itoa(int(m.low.sens)) + if m.low.cats != nil && m.low.cats.BitLen() > 0 { + low += ":" + bitsetToStr(m.low.cats) + } + + if m.low.equal(m.high) { + return low + } + + high := "s" + strconv.Itoa(int(m.high.sens)) + if m.high.cats != nil && m.high.cats.BitLen() > 0 { + high += ":" + bitsetToStr(m.high.cats) + } + + return low + "-" + high +} + +func max(a, b uint) uint { + if a > b { + return a + } + return b +} + +func min(a, b uint) uint { + if a < b { + return a + } + return b +} + +// calculateGlbLub computes the glb (greatest lower bound) and lub (least upper bound) +// of a source and target range. +// The glblub is calculated as the greater of the low sensitivities and +// the lower of the high sensitivities and the and of each category bitset. +func calculateGlbLub(sourceRange, targetRange string) (string, error) { + s, err := rangeStrToMLSRange(sourceRange) + if err != nil { + return "", err + } + t, err := rangeStrToMLSRange(targetRange) + if err != nil { + return "", err + } + + if s.high.sens < t.low.sens || t.high.sens < s.low.sens { + /* these ranges have no common sensitivities */ + return "", ErrIncomparable + } + + outrange := &mlsRange{low: &level{}, high: &level{}} + + /* take the greatest of the low */ + outrange.low.sens = max(s.low.sens, t.low.sens) + + /* take the least of the high */ + outrange.high.sens = min(s.high.sens, t.high.sens) + + /* find the intersecting categories */ + if s.low.cats != nil && t.low.cats != nil { + outrange.low.cats = new(big.Int) + outrange.low.cats.And(s.low.cats, t.low.cats) + } + if s.high.cats != nil && t.high.cats != nil { + outrange.high.cats = new(big.Int) + outrange.high.cats.And(s.high.cats, t.high.cats) + } + + return outrange.String(), nil +} + +func readWriteCon(fpath string, val string) (string, error) { + if fpath == "" { + return "", ErrEmptyPath + } + f, err := os.OpenFile(fpath, os.O_RDWR, 0) + if err != nil { + return "", err + } + defer f.Close() + + _, err = f.Write([]byte(val)) + if err != nil { + return "", err + } + + return readConFd(f) +} + +// peerLabel retrieves the label of the client on the other side of a socket +func peerLabel(fd uintptr) (string, error) { + l, err := unix.GetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_PEERSEC) + if err != nil { + return "", &os.PathError{Op: "getsockopt", Path: "fd " + strconv.Itoa(int(fd)), Err: err} + } + return l, nil +} + +// setKeyLabel takes a process label and tells the kernel to assign the +// label to the next kernel keyring that gets created +func setKeyLabel(label string) error { + err := writeCon("/proc/self/attr/keycreate", label) + if errors.Is(err, os.ErrNotExist) { + return nil + } + if label == "" && errors.Is(err, os.ErrPermission) { + return nil + } + return err +} + +// get returns the Context as a string +func (c Context) get() string { + if l := c["level"]; l != "" { + return c["user"] + ":" + c["role"] + ":" + c["type"] + ":" + l + } + return c["user"] + ":" + c["role"] + ":" + c["type"] +} + +// newContext creates a new Context struct from the specified label +func newContext(label string) (Context, error) { + c := make(Context) + + if len(label) != 0 { + con := strings.SplitN(label, ":", 4) + if len(con) < 3 { + return c, ErrInvalidLabel + } + c["user"] = con[0] + c["role"] = con[1] + c["type"] = con[2] + if len(con) > 3 { + c["level"] = con[3] + } + } + return c, nil +} + +// clearLabels clears all reserved labels +func clearLabels() { + state.Lock() + state.mcsList = make(map[string]bool) + state.Unlock() +} + +// reserveLabel reserves the MLS/MCS level component of the specified label +func reserveLabel(label string) { + if len(label) != 0 { + con := strings.SplitN(label, ":", 4) + if len(con) > 3 { + _ = mcsAdd(con[3]) + } + } +} + +func selinuxEnforcePath() string { + return filepath.Join(getSelinuxMountPoint(), "enforce") +} + +// isMLSEnabled checks if MLS is enabled. +func isMLSEnabled() bool { + enabledB, err := os.ReadFile(filepath.Join(getSelinuxMountPoint(), "mls")) + if err != nil { + return false + } + return bytes.Equal(enabledB, []byte{'1'}) +} + +// enforceMode returns the current SELinux mode Enforcing, Permissive, Disabled +func enforceMode() int { + var enforce int + + enforceB, err := os.ReadFile(selinuxEnforcePath()) + if err != nil { + return -1 + } + enforce, err = strconv.Atoi(string(enforceB)) + if err != nil { + return -1 + } + return enforce +} + +// setEnforceMode sets the current SELinux mode Enforcing, Permissive. +// Disabled is not valid, since this needs to be set at boot time. +func setEnforceMode(mode int) error { + //nolint:gosec // ignore G306: permissions to be 0600 or less. + return os.WriteFile(selinuxEnforcePath(), []byte(strconv.Itoa(mode)), 0o644) +} + +// defaultEnforceMode returns the systems default SELinux mode Enforcing, +// Permissive or Disabled. Note this is just the default at boot time. +// EnforceMode tells you the systems current mode. +func defaultEnforceMode() int { + switch readConfig(selinuxTag) { + case "enforcing": + return Enforcing + case "permissive": + return Permissive + } + return Disabled +} + +func mcsAdd(mcs string) error { + if mcs == "" { + return nil + } + state.Lock() + defer state.Unlock() + if state.mcsList[mcs] { + return ErrMCSAlreadyExists + } + state.mcsList[mcs] = true + return nil +} + +func mcsDelete(mcs string) { + if mcs == "" { + return + } + state.Lock() + defer state.Unlock() + state.mcsList[mcs] = false +} + +func intToMcs(id int, catRange uint32) string { + var ( + SETSIZE = int(catRange) + TIER = SETSIZE + ORD = id + ) + + if id < 1 || id > 523776 { + return "" + } + + for ORD > TIER { + ORD -= TIER + TIER-- + } + TIER = SETSIZE - TIER + ORD += TIER + return fmt.Sprintf("s0:c%d,c%d", TIER, ORD) +} + +func uniqMcs(catRange uint32) string { + var ( + n uint32 + c1, c2 uint32 + mcs string + ) + + for { + _ = binary.Read(rand.Reader, binary.LittleEndian, &n) + c1 = n % catRange + _ = binary.Read(rand.Reader, binary.LittleEndian, &n) + c2 = n % catRange + if c1 == c2 { + continue + } else if c1 > c2 { + c1, c2 = c2, c1 + } + mcs = fmt.Sprintf("s0:c%d,c%d", c1, c2) + if err := mcsAdd(mcs); err != nil { + continue + } + break + } + return mcs +} + +// releaseLabel un-reserves the MLS/MCS Level field of the specified label, +// allowing it to be used by another process. +func releaseLabel(label string) { + if len(label) != 0 { + con := strings.SplitN(label, ":", 4) + if len(con) > 3 { + mcsDelete(con[3]) + } + } +} + +// roFileLabel returns the specified SELinux readonly file label +func roFileLabel() string { + return readOnlyFileLabel +} + +func openContextFile() (*os.File, error) { + if f, err := os.Open(contextFile); err == nil { + return f, nil + } + return os.Open(filepath.Join(policyRoot(), "contexts", "lxc_contexts")) +} + +func loadLabels() { + labels = make(map[string]string) + in, err := openContextFile() + if err != nil { + return + } + defer in.Close() + + scanner := bufio.NewScanner(in) + + for scanner.Scan() { + line := bytes.TrimSpace(scanner.Bytes()) + if len(line) == 0 { + // Skip blank lines + continue + } + if line[0] == ';' || line[0] == '#' { + // Skip comments + continue + } + fields := bytes.SplitN(line, []byte{'='}, 2) + if len(fields) != 2 { + continue + } + key, val := bytes.TrimSpace(fields[0]), bytes.TrimSpace(fields[1]) + labels[string(key)] = string(bytes.Trim(val, `"`)) + } + + con, _ := NewContext(labels["file"]) + con["level"] = fmt.Sprintf("s0:c%d,c%d", maxCategory-2, maxCategory-1) + privContainerMountLabel = con.get() + reserveLabel(privContainerMountLabel) +} + +func label(key string) string { + loadLabelsOnce.Do(func() { + loadLabels() + }) + return labels[key] +} + +// kvmContainerLabels returns the default processLabel and mountLabel to be used +// for kvm containers by the calling process. +func kvmContainerLabels() (string, string) { + processLabel := label("kvm_process") + if processLabel == "" { + processLabel = label("process") + } + + return addMcs(processLabel, label("file")) +} + +// initContainerLabels returns the default processLabel and file labels to be +// used for containers running an init system like systemd by the calling process. +func initContainerLabels() (string, string) { + processLabel := label("init_process") + if processLabel == "" { + processLabel = label("process") + } + + return addMcs(processLabel, label("file")) +} + +// containerLabels returns an allocated processLabel and fileLabel to be used for +// container labeling by the calling process. +func containerLabels() (processLabel string, fileLabel string) { + if !getEnabled() { + return "", "" + } + + processLabel = label("process") + fileLabel = label("file") + readOnlyFileLabel = label("ro_file") + + if processLabel == "" || fileLabel == "" { + return "", fileLabel + } + + if readOnlyFileLabel == "" { + readOnlyFileLabel = fileLabel + } + + return addMcs(processLabel, fileLabel) +} + +func addMcs(processLabel, fileLabel string) (string, string) { + scon, _ := NewContext(processLabel) + if scon["level"] != "" { + mcs := uniqMcs(CategoryRange) + scon["level"] = mcs + processLabel = scon.Get() + scon, _ = NewContext(fileLabel) + scon["level"] = mcs + fileLabel = scon.Get() + } + return processLabel, fileLabel +} + +// securityCheckContext validates that the SELinux label is understood by the kernel +func securityCheckContext(val string) error { + //nolint:gosec // ignore G306: permissions to be 0600 or less. + return os.WriteFile(filepath.Join(getSelinuxMountPoint(), "context"), []byte(val), 0o644) +} + +// copyLevel returns a label with the MLS/MCS level from src label replaced on +// the dest label. +func copyLevel(src, dest string) (string, error) { + if src == "" { + return "", nil + } + if err := SecurityCheckContext(src); err != nil { + return "", err + } + if err := SecurityCheckContext(dest); err != nil { + return "", err + } + scon, err := NewContext(src) + if err != nil { + return "", err + } + tcon, err := NewContext(dest) + if err != nil { + return "", err + } + mcsDelete(tcon["level"]) + _ = mcsAdd(scon["level"]) + tcon["level"] = scon["level"] + return tcon.Get(), nil +} + +// chcon changes the fpath file object to the SELinux label. +// If fpath is a directory and recurse is true, then chcon walks the +// directory tree setting the label. +func chcon(fpath string, label string, recurse bool) error { + if fpath == "" { + return ErrEmptyPath + } + if label == "" { + return nil + } + + excludePaths := map[string]bool{ + "/": true, + "/bin": true, + "/boot": true, + "/dev": true, + "/etc": true, + "/etc/passwd": true, + "/etc/pki": true, + "/etc/shadow": true, + "/home": true, + "/lib": true, + "/lib64": true, + "/media": true, + "/opt": true, + "/proc": true, + "/root": true, + "/run": true, + "/sbin": true, + "/srv": true, + "/sys": true, + "/tmp": true, + "/usr": true, + "/var": true, + "/var/lib": true, + "/var/log": true, + } + + if home := os.Getenv("HOME"); home != "" { + excludePaths[home] = true + } + + if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" { + if usr, err := user.Lookup(sudoUser); err == nil { + excludePaths[usr.HomeDir] = true + } + } + + if fpath != "/" { + fpath = strings.TrimSuffix(fpath, "/") + } + if excludePaths[fpath] { + return fmt.Errorf("SELinux relabeling of %s is not allowed", fpath) + } + + if !recurse { + err := lSetFileLabel(fpath, label) + if err != nil { + // Check if file doesn't exist, must have been removed + if errors.Is(err, os.ErrNotExist) { + return nil + } + // Check if current label is correct on disk + flabel, nerr := lFileLabel(fpath) + if nerr == nil && flabel == label { + return nil + } + // Check if file doesn't exist, must have been removed + if errors.Is(nerr, os.ErrNotExist) { + return nil + } + return err + } + return nil + } + + return rchcon(fpath, label) +} + +func rchcon(fpath, label string) error { //revive:disable:cognitive-complexity + fastMode := false + // If the current label matches the new label, assume + // other labels are correct. + if cLabel, err := lFileLabel(fpath); err == nil && cLabel == label { + fastMode = true + } + return pwalkdir.Walk(fpath, func(p string, _ fs.DirEntry, _ error) error { + if fastMode { + if cLabel, err := lFileLabel(fpath); err == nil && cLabel == label { + return nil + } + } + err := lSetFileLabel(p, label) + // Walk a file tree can race with removal, so ignore ENOENT. + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + }) +} + +// dupSecOpt takes an SELinux process label and returns security options that +// can be used to set the SELinux Type and Level for future container processes. +func dupSecOpt(src string) ([]string, error) { + if src == "" { + return nil, nil + } + con, err := NewContext(src) + if err != nil { + return nil, err + } + if con["user"] == "" || + con["role"] == "" || + con["type"] == "" { + return nil, nil + } + dup := []string{ + "user:" + con["user"], + "role:" + con["role"], + "type:" + con["type"], + } + + if con["level"] != "" { + dup = append(dup, "level:"+con["level"]) + } + + return dup, nil +} + +// findUserInContext scans the reader for a valid SELinux context +// match that is verified with the verifier. Invalid contexts are +// skipped. It returns a matched context or an empty string if no +// match is found. If a scanner error occurs, it is returned. +func findUserInContext(context Context, r io.Reader, verifier func(string) error) (string, error) { + fromRole := context["role"] + fromType := context["type"] + scanner := bufio.NewScanner(r) + + for scanner.Scan() { + fromConns := strings.Fields(scanner.Text()) + if len(fromConns) == 0 { + // Skip blank lines + continue + } + + line := fromConns[0] + + if line[0] == ';' || line[0] == '#' { + // Skip comments + continue + } + + // user context files contexts are formatted as + // role_r:type_t:s0 where the user is missing. + lineArr := strings.SplitN(line, ":", 4) + // skip context with typo, or role and type do not match + if len(lineArr) != 3 || + lineArr[0] != fromRole || + lineArr[1] != fromType { + continue + } + + for _, cc := range fromConns[1:] { + toConns := strings.SplitN(cc, ":", 4) + if len(toConns) != 3 { + continue + } + + context["role"] = toConns[0] + context["type"] = toConns[1] + + outConn := context.get() + if err := verifier(outConn); err != nil { + continue + } + + return outConn, nil + } + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to scan for context: %w", err) + } + + return "", nil +} + +func getDefaultContextFromReaders(c *defaultSECtx) (string, error) { + if c.verifier == nil { + return "", ErrVerifierNil + } + + context, err := newContext(c.scon) + if err != nil { + return "", fmt.Errorf("failed to create label for %s: %w", c.scon, err) + } + + // set so the verifier validates the matched context with the provided user and level. + context["user"] = c.user + context["level"] = c.level + + conn, err := findUserInContext(context, c.userRdr, c.verifier) + if err != nil { + return "", err + } + + if conn != "" { + return conn, nil + } + + conn, err = findUserInContext(context, c.defaultRdr, c.verifier) + if err != nil { + return "", err + } + + if conn != "" { + return conn, nil + } + + return "", fmt.Errorf("context %q not found: %w", c.scon, ErrContextMissing) +} + +func getDefaultContextWithLevel(user, level, scon string) (string, error) { + userPath := filepath.Join(policyRoot(), selinuxUsersDir, user) + fu, err := os.Open(userPath) + if err != nil { + return "", err + } + defer fu.Close() + + defaultPath := filepath.Join(policyRoot(), defaultContexts) + fd, err := os.Open(defaultPath) + if err != nil { + return "", err + } + defer fd.Close() + + c := defaultSECtx{ + user: user, + level: level, + scon: scon, + userRdr: fu, + defaultRdr: fd, + verifier: securityCheckContext, + } + + return getDefaultContextFromReaders(&c) +} |