diff options
Diffstat (limited to 'src/crypto/x509/root_unix_test.go')
-rw-r--r-- | src/crypto/x509/root_unix_test.go | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/src/crypto/x509/root_unix_test.go b/src/crypto/x509/root_unix_test.go new file mode 100644 index 0000000..7197a0d --- /dev/null +++ b/src/crypto/x509/root_unix_test.go @@ -0,0 +1,228 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris + +package x509 + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +const ( + testDir = "testdata" + testDirCN = "test-dir" + testFile = "test-file.crt" + testFileCN = "test-file" + testMissing = "missing" +) + +func TestEnvVars(t *testing.T) { + testCases := []struct { + name string + fileEnv string + dirEnv string + files []string + dirs []string + cns []string + }{ + { + // Environment variables override the default locations preventing fall through. + name: "override-defaults", + fileEnv: testMissing, + dirEnv: testMissing, + files: []string{testFile}, + dirs: []string{testDir}, + cns: nil, + }, + { + // File environment overrides default file locations. + name: "file", + fileEnv: testFile, + dirEnv: "", + files: nil, + dirs: nil, + cns: []string{testFileCN}, + }, + { + // Directory environment overrides default directory locations. + name: "dir", + fileEnv: "", + dirEnv: testDir, + files: nil, + dirs: nil, + cns: []string{testDirCN}, + }, + { + // File & directory environment overrides both default locations. + name: "file+dir", + fileEnv: testFile, + dirEnv: testDir, + files: nil, + dirs: nil, + cns: []string{testFileCN, testDirCN}, + }, + { + // Environment variable empty / unset uses default locations. + name: "empty-fall-through", + fileEnv: "", + dirEnv: "", + files: []string{testFile}, + dirs: []string{testDir}, + cns: []string{testFileCN, testDirCN}, + }, + } + + // Save old settings so we can restore before the test ends. + origCertFiles, origCertDirectories := certFiles, certDirectories + origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv) + defer func() { + certFiles = origCertFiles + certDirectories = origCertDirectories + os.Setenv(certFileEnv, origFile) + os.Setenv(certDirEnv, origDir) + }() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := os.Setenv(certFileEnv, tc.fileEnv); err != nil { + t.Fatalf("setenv %q failed: %v", certFileEnv, err) + } + if err := os.Setenv(certDirEnv, tc.dirEnv); err != nil { + t.Fatalf("setenv %q failed: %v", certDirEnv, err) + } + + certFiles, certDirectories = tc.files, tc.dirs + + r, err := loadSystemRoots() + if err != nil { + t.Fatal("unexpected failure:", err) + } + + if r == nil { + t.Fatal("nil roots") + } + + // Verify that the returned certs match, otherwise report where the mismatch is. + for i, cn := range tc.cns { + if i >= r.len() { + t.Errorf("missing cert %v @ %v", cn, i) + } else if r.mustCert(t, i).Subject.CommonName != cn { + fmt.Printf("%#v\n", r.mustCert(t, 0).Subject) + t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn) + } + } + if r.len() > len(tc.cns) { + t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns)) + } + }) + } +} + +// Ensure that "SSL_CERT_DIR" when used as the environment +// variable delimited by colons, allows loadSystemRoots to +// load all the roots from the respective directories. +// See https://golang.org/issue/35325. +func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) { + origFile, origDir := os.Getenv(certFileEnv), os.Getenv(certDirEnv) + origCertFiles := certFiles[:] + + // To prevent any other certs from being loaded in + // through "SSL_CERT_FILE" or from known "certFiles", + // clear them all, and they'll be reverting on defer. + certFiles = certFiles[:0] + os.Setenv(certFileEnv, "") + + defer func() { + certFiles = origCertFiles[:] + os.Setenv(certDirEnv, origDir) + os.Setenv(certFileEnv, origFile) + }() + + tmpDir := t.TempDir() + + rootPEMs := []string{ + geoTrustRoot, + googleLeaf, + startComRoot, + } + + var certDirs []string + for i, certPEM := range rootPEMs { + certDir := filepath.Join(tmpDir, fmt.Sprintf("cert-%d", i)) + if err := os.MkdirAll(certDir, 0755); err != nil { + t.Fatalf("Failed to create certificate dir: %v", err) + } + certOutFile := filepath.Join(certDir, "cert.crt") + if err := os.WriteFile(certOutFile, []byte(certPEM), 0655); err != nil { + t.Fatalf("Failed to write certificate to file: %v", err) + } + certDirs = append(certDirs, certDir) + } + + // Sanity check: the number of certDirs should be equal to the number of roots. + if g, w := len(certDirs), len(rootPEMs); g != w { + t.Fatalf("Failed sanity check: len(certsDir)=%d is not equal to len(rootsPEMS)=%d", g, w) + } + + // Now finally concatenate them with a colon. + colonConcatCertDirs := strings.Join(certDirs, ":") + os.Setenv(certDirEnv, colonConcatCertDirs) + gotPool, err := loadSystemRoots() + if err != nil { + t.Fatalf("Failed to load system roots: %v", err) + } + subjects := gotPool.Subjects() + // We expect exactly len(rootPEMs) subjects back. + if g, w := len(subjects), len(rootPEMs); g != w { + t.Fatalf("Invalid number of subjects: got %d want %d", g, w) + } + + wantPool := NewCertPool() + for _, certPEM := range rootPEMs { + wantPool.AppendCertsFromPEM([]byte(certPEM)) + } + strCertPool := func(p *CertPool) string { + return string(bytes.Join(p.Subjects(), []byte("\n"))) + } + + if !certPoolEqual(gotPool, wantPool) { + g, w := strCertPool(gotPool), strCertPool(wantPool) + t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w) + } +} + +func TestReadUniqueDirectoryEntries(t *testing.T) { + tmp := t.TempDir() + temp := func(base string) string { return filepath.Join(tmp, base) } + if f, err := os.Create(temp("file")); err != nil { + t.Fatal(err) + } else { + f.Close() + } + if err := os.Symlink("target-in", temp("link-in")); err != nil { + t.Fatal(err) + } + if err := os.Symlink("../target-out", temp("link-out")); err != nil { + t.Fatal(err) + } + got, err := readUniqueDirectoryEntries(tmp) + if err != nil { + t.Fatal(err) + } + gotNames := []string{} + for _, fi := range got { + gotNames = append(gotNames, fi.Name()) + } + wantNames := []string{"file", "link-out"} + if !reflect.DeepEqual(gotNames, wantNames) { + t.Errorf("got %q; want %q", gotNames, wantNames) + } +} |