summaryrefslogtreecommitdiffstats
path: root/dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go')
-rw-r--r--dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go421
1 files changed, 421 insertions, 0 deletions
diff --git a/dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go b/dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go
new file mode 100644
index 0000000..64d68cf
--- /dev/null
+++ b/dependencies/pkg/mod/github.com/lib/pq@v1.10.7/ssl_test.go
@@ -0,0 +1,421 @@
+package pq
+
+// This file contains SSL tests
+
+import (
+ "bytes"
+ _ "crypto/sha256"
+ "crypto/tls"
+ "crypto/x509"
+ "database/sql"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func maybeSkipSSLTests(t *testing.T) {
+ // Require some special variables for testing certificates
+ if os.Getenv("PQSSLCERTTEST_PATH") == "" {
+ t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests")
+ }
+
+ value := os.Getenv("PQGOSSLTESTS")
+ if value == "" || value == "0" {
+ t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests")
+ } else if value != "1" {
+ t.Fatalf("unexpected value %q for PQGOSSLTESTS", value)
+ }
+}
+
+func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) {
+ db, err := openTestConnConninfo(conninfo)
+ if err != nil {
+ // should never fail
+ t.Fatal(err)
+ }
+ // Do something with the connection to see whether it's working or not.
+ tx, err := db.Begin()
+ if err == nil {
+ return db, tx.Rollback()
+ }
+ _ = db.Close()
+ return nil, err
+}
+
+func checkSSLSetup(t *testing.T, conninfo string) {
+ _, err := openSSLConn(t, conninfo)
+ if pge, ok := err.(*Error); ok {
+ if pge.Code.Name() != "invalid_authorization_specification" {
+ t.Fatalf("unexpected error code '%s'", pge.Code.Name())
+ }
+ } else {
+ t.Fatalf("expected %T, got %v", (*Error)(nil), err)
+ }
+}
+
+// Connect over SSL and run a simple query to test the basics
+func TestSSLConnection(t *testing.T) {
+ maybeSkipSSLTests(t)
+ // Environment sanity check: should fail without SSL
+ checkSSLSetup(t, "sslmode=disable user=pqgossltest")
+
+ db, err := openSSLConn(t, "sslmode=require user=pqgossltest")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows, err := db.Query("SELECT 1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows.Close()
+}
+
+// Test sslmode=verify-full
+func TestSSLVerifyFull(t *testing.T) {
+ maybeSkipSSLTests(t)
+ // Environment sanity check: should fail without SSL
+ checkSSLSetup(t, "sslmode=disable user=pqgossltest")
+
+ // Not OK according to the system CA
+ _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ _, ok := err.(x509.UnknownAuthorityError)
+ if !ok {
+ _, ok := err.(x509.HostnameError)
+ if !ok {
+ t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err)
+ }
+ }
+
+ rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
+ rootCert := "sslrootcert=" + rootCertPath + " "
+ // No match on Common Name
+ _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ _, ok = err.(x509.HostnameError)
+ if !ok {
+ t.Fatalf("expected x509.HostnameError, got %#+v", err)
+ }
+ // OK
+ _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Test sslmode=require sslrootcert=rootCertPath
+func TestSSLRequireWithRootCert(t *testing.T) {
+ maybeSkipSSLTests(t)
+ // Environment sanity check: should fail without SSL
+ checkSSLSetup(t, "sslmode=disable user=pqgossltest")
+
+ bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt")
+ bogusRootCert := "sslrootcert=" + bogusRootCertPath + " "
+
+ // Not OK according to the bogus CA
+ _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest")
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ _, ok := err.(x509.UnknownAuthorityError)
+ if !ok {
+ t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err)
+ }
+
+ nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt")
+ nonExistentCert := "sslrootcert=" + nonExistentCertPath + " "
+
+ // No match on Common Name, but that's OK because we're not validating anything.
+ _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
+ rootCert := "sslrootcert=" + rootCertPath + " "
+
+ // No match on Common Name, but that's OK because we're not validating the CN.
+ _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest")
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Everything OK
+ _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Test sslmode=verify-ca
+func TestSSLVerifyCA(t *testing.T) {
+ maybeSkipSSLTests(t)
+ // Environment sanity check: should fail without SSL
+ checkSSLSetup(t, "sslmode=disable user=pqgossltest")
+
+ // Not OK according to the system CA
+ {
+ _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest")
+ if _, ok := err.(x509.UnknownAuthorityError); !ok {
+ t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
+ }
+ }
+
+ // Still not OK according to the system CA; empty sslrootcert is treated as unspecified.
+ {
+ _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''")
+ if _, ok := err.(x509.UnknownAuthorityError); !ok {
+ t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err)
+ }
+ }
+
+ rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt")
+ rootCert := "sslrootcert=" + rootCertPath + " "
+ // No match on Common Name, but that's OK
+ if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil {
+ t.Fatal(err)
+ }
+ // Everything OK
+ if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Authenticate over SSL using client certificates
+func TestSSLClientCertificates(t *testing.T) {
+ maybeSkipSSLTests(t)
+ // Environment sanity check: should fail without SSL
+ checkSSLSetup(t, "sslmode=disable user=pqgossltest")
+
+ const baseinfo = "sslmode=require user=pqgosslcert"
+
+ // Certificate not specified, should fail
+ {
+ _, err := openSSLConn(t, baseinfo)
+ if pge, ok := err.(*Error); ok {
+ if pge.Code.Name() != "invalid_authorization_specification" {
+ t.Fatalf("unexpected error code '%s'", pge.Code.Name())
+ }
+ } else {
+ t.Fatalf("expected %T, got %v", (*Error)(nil), err)
+ }
+ }
+
+ // Empty certificate specified, should fail
+ {
+ _, err := openSSLConn(t, baseinfo+" sslcert=''")
+ if pge, ok := err.(*Error); ok {
+ if pge.Code.Name() != "invalid_authorization_specification" {
+ t.Fatalf("unexpected error code '%s'", pge.Code.Name())
+ }
+ } else {
+ t.Fatalf("expected %T, got %v", (*Error)(nil), err)
+ }
+ }
+
+ // Non-existent certificate specified, should fail
+ {
+ _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist")
+ if pge, ok := err.(*Error); ok {
+ if pge.Code.Name() != "invalid_authorization_specification" {
+ t.Fatalf("unexpected error code '%s'", pge.Code.Name())
+ }
+ } else {
+ t.Fatalf("expected %T, got %v", (*Error)(nil), err)
+ }
+ }
+
+ certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH")
+ if !ok {
+ t.Fatalf("PQSSLCERTTEST_PATH not present in environment")
+ }
+
+ sslcert := filepath.Join(certpath, "postgresql.crt")
+
+ // Cert present, key not specified, should fail
+ {
+ _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert)
+ if _, ok := err.(*os.PathError); !ok {
+ t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
+ }
+ }
+
+ // Cert present, empty key specified, should fail
+ {
+ _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''")
+ if _, ok := err.(*os.PathError); !ok {
+ t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
+ }
+ }
+
+ // Cert present, non-existent key, should fail
+ {
+ _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist")
+ if _, ok := err.(*os.PathError); !ok {
+ t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err)
+ }
+ }
+
+ // Key has wrong permissions (passing the cert as the key), should fail
+ if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions {
+ t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err)
+ }
+
+ sslkey := filepath.Join(certpath, "postgresql.key")
+
+ // Should work
+ if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil {
+ t.Fatal(err)
+ } else {
+ rows, err := db.Query("SELECT 1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ if err := db.Close(); err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+// Check that clint sends SNI data when `sslsni` is not disabled
+func TestSNISupport(t *testing.T) {
+ t.Parallel()
+ tests := []struct {
+ name string
+ conn_param string
+ hostname string
+ expected_sni string
+ }{
+ {
+ name: "SNI is set by default",
+ conn_param: "",
+ hostname: "localhost",
+ expected_sni: "localhost",
+ },
+ {
+ name: "SNI is passed when asked for",
+ conn_param: "sslsni=1",
+ hostname: "localhost",
+ expected_sni: "localhost",
+ },
+ {
+ name: "SNI is not passed when disabled",
+ conn_param: "sslsni=0",
+ hostname: "localhost",
+ expected_sni: "",
+ },
+ {
+ name: "SNI is not set for IPv4",
+ conn_param: "",
+ hostname: "127.0.0.1",
+ expected_sni: "",
+ },
+ }
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Start mock postgres server on OS-provided port
+ listener, err := net.Listen("tcp", "127.0.0.1:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ serverErrChan := make(chan error, 1)
+ serverSNINameChan := make(chan string, 1)
+ go mockPostgresSSL(listener, serverErrChan, serverSNINameChan)
+
+ defer listener.Close()
+ defer close(serverErrChan)
+ defer close(serverSNINameChan)
+
+ // Try to establish a connection with the mock server. Connection will error out after TLS
+ // clientHello, but it is enough to catch SNI data on the server side
+ port := strings.Split(listener.Addr().String(), ":")[1]
+ connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param)
+
+ // We are okay to skip this error as we are polling serverErrChan and we'll get an error
+ // or timeout from the server side in case of problems here.
+ db, _ := sql.Open("postgres", connStr)
+ _, _ = db.Exec("SELECT 1")
+
+ // Check SNI data
+ select {
+ case sniHost := <-serverSNINameChan:
+ if sniHost != tt.expected_sni {
+ t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost)
+ }
+ case err = <-serverErrChan:
+ t.Fatalf("mock server failed with error: %+v", err)
+ case <-time.After(time.Second):
+ t.Fatal("exceeded connection timeout without erroring out")
+ }
+ })
+ }
+}
+
+// Make a postgres mock server to test TLS SNI
+//
+// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.
+// While reading clientHello catch passed SNI data and report it to nameChan.
+func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) {
+ var sniHost string
+
+ conn, err := listener.Accept()
+ if err != nil {
+ errChan <- err
+ return
+ }
+ defer conn.Close()
+
+ err = conn.SetDeadline(time.Now().Add(time.Second))
+ if err != nil {
+ errChan <- err
+ return
+ }
+
+ // Receive StartupMessage with SSL Request
+ startupMessage := make([]byte, 8)
+ if _, err := io.ReadFull(conn, startupMessage); err != nil {
+ errChan <- err
+ return
+ }
+ // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber
+ if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) {
+ errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
+ return
+ }
+
+ // Respond with SSLOk
+ _, err = conn.Write([]byte("S"))
+ if err != nil {
+ errChan <- err
+ return
+ }
+
+ // Set up TLS context to catch clientHello. It will always error out during handshake
+ // as no certificate is set.
+ srv := tls.Server(conn, &tls.Config{
+ GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
+ sniHost = argHello.ServerName
+ return nil, nil
+ },
+ })
+ defer srv.Close()
+
+ // Do the TLS handshake ignoring errors
+ _ = srv.Handshake()
+
+ nameChan <- sniHost
+}