diff options
Diffstat (limited to 'src/crypto/tls/handshake_client_test.go')
-rw-r--r-- | src/crypto/tls/handshake_client_test.go | 2597 |
1 files changed, 2597 insertions, 0 deletions
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go new file mode 100644 index 0000000..749c9fc --- /dev/null +++ b/src/crypto/tls/handshake_client_test.go @@ -0,0 +1,2597 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/binary" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "net" + "os" + "os/exec" + "path/filepath" + "reflect" + "runtime" + "strconv" + "strings" + "testing" + "time" +) + +// Note: see comment in handshake_test.go for details of how the reference +// tests work. + +// opensslInputEvent enumerates possible inputs that can be sent to an `openssl +// s_client` process. +type opensslInputEvent int + +const ( + // opensslRenegotiate causes OpenSSL to request a renegotiation of the + // connection. + opensslRenegotiate opensslInputEvent = iota + + // opensslSendBanner causes OpenSSL to send the contents of + // opensslSentinel on the connection. + opensslSendSentinel + + // opensslKeyUpdate causes OpenSSL to send a key update message to the + // client and request one back. + opensslKeyUpdate +) + +const opensslSentinel = "SENTINEL\n" + +type opensslInput chan opensslInputEvent + +func (i opensslInput) Read(buf []byte) (n int, err error) { + for event := range i { + switch event { + case opensslRenegotiate: + return copy(buf, []byte("R\n")), nil + case opensslKeyUpdate: + return copy(buf, []byte("K\n")), nil + case opensslSendSentinel: + return copy(buf, []byte(opensslSentinel)), nil + default: + panic("unknown event") + } + } + + return 0, io.EOF +} + +// opensslOutputSink is an io.Writer that receives the stdout and stderr from an +// `openssl` process and sends a value to handshakeComplete or readKeyUpdate +// when certain messages are seen. +type opensslOutputSink struct { + handshakeComplete chan struct{} + readKeyUpdate chan struct{} + all []byte + line []byte +} + +func newOpensslOutputSink() *opensslOutputSink { + return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil} +} + +// opensslEndOfHandshake is a message that the “openssl s_server” tool will +// print when a handshake completes if run with “-state”. +const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished" + +// opensslReadKeyUpdate is a message that the “openssl s_server” tool will +// print when a KeyUpdate message is received if run with “-state”. +const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update" + +func (o *opensslOutputSink) Write(data []byte) (n int, err error) { + o.line = append(o.line, data...) + o.all = append(o.all, data...) + + for { + line, next, ok := bytes.Cut(o.line, []byte("\n")) + if !ok { + break + } + + if bytes.Equal([]byte(opensslEndOfHandshake), line) { + o.handshakeComplete <- struct{}{} + } + if bytes.Equal([]byte(opensslReadKeyUpdate), line) { + o.readKeyUpdate <- struct{}{} + } + o.line = next + } + + return len(data), nil +} + +func (o *opensslOutputSink) String() string { + return string(o.all) +} + +// clientTest represents a test of the TLS client handshake against a reference +// implementation. +type clientTest struct { + // name is a freeform string identifying the test and the file in which + // the expected results will be stored. + name string + // args, if not empty, contains a series of arguments for the + // command to run for the reference server. + args []string + // config, if not nil, contains a custom Config to use for this test. + config *Config + // cert, if not empty, contains a DER-encoded certificate for the + // reference server. + cert []byte + // key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or + // *ecdsa.PrivateKey which is the private key for the reference server. + key any + // extensions, if not nil, contains a list of extension data to be returned + // from the ServerHello. The data should be in standard TLS format with + // a 2-byte uint16 type, 2-byte data length, followed by the extension data. + extensions [][]byte + // validate, if not nil, is a function that will be called with the + // ConnectionState of the resulting connection. It returns a non-nil + // error if the ConnectionState is unacceptable. + validate func(ConnectionState) error + // numRenegotiations is the number of times that the connection will be + // renegotiated. + numRenegotiations int + // renegotiationExpectedToFail, if not zero, is the number of the + // renegotiation attempt that is expected to fail. + renegotiationExpectedToFail int + // checkRenegotiationError, if not nil, is called with any error + // arising from renegotiation. It can map expected errors to nil to + // ignore them. + checkRenegotiationError func(renegotiationNum int, err error) error + // sendKeyUpdate will cause the server to send a KeyUpdate message. + sendKeyUpdate bool +} + +var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"} + +// connFromCommand starts the reference server process, connects to it and +// returns a recordingConn for the connection. The stdin return value is an +// opensslInput for the stdin of the child process. It must be closed before +// Waiting for child. +func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) { + cert := testRSACertificate + if len(test.cert) > 0 { + cert = test.cert + } + certPath := tempFile(string(cert)) + defer os.Remove(certPath) + + var key any = testRSAPrivateKey + if test.key != nil { + key = test.key + } + derBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + panic(err) + } + + var pemOut bytes.Buffer + pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes}) + + keyPath := tempFile(pemOut.String()) + defer os.Remove(keyPath) + + var command []string + command = append(command, serverCommand...) + command = append(command, test.args...) + command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath) + // serverPort contains the port that OpenSSL will listen on. OpenSSL + // can't take "0" as an argument here so we have to pick a number and + // hope that it's not in use on the machine. Since this only occurs + // when -update is given and thus when there's a human watching the + // test, this isn't too bad. + const serverPort = 24323 + command = append(command, "-accept", strconv.Itoa(serverPort)) + + if len(test.extensions) > 0 { + var serverInfo bytes.Buffer + for _, ext := range test.extensions { + pem.Encode(&serverInfo, &pem.Block{ + Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)), + Bytes: ext, + }) + } + serverInfoPath := tempFile(serverInfo.String()) + defer os.Remove(serverInfoPath) + command = append(command, "-serverinfo", serverInfoPath) + } + + if test.numRenegotiations > 0 || test.sendKeyUpdate { + found := false + for _, flag := range command[1:] { + if flag == "-state" { + found = true + break + } + } + + if !found { + panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate") + } + } + + cmd := exec.Command(command[0], command[1:]...) + stdin = opensslInput(make(chan opensslInputEvent)) + cmd.Stdin = stdin + out := newOpensslOutputSink() + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + return nil, nil, nil, nil, err + } + + // OpenSSL does print an "ACCEPT" banner, but it does so *before* + // opening the listening socket, so we can't use that to wait until it + // has started listening. Thus we are forced to poll until we get a + // connection. + var tcpConn net.Conn + for i := uint(0); i < 5; i++ { + tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: serverPort, + }) + if err == nil { + break + } + time.Sleep((1 << i) * 5 * time.Millisecond) + } + if err != nil { + close(stdin) + cmd.Process.Kill() + err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out) + return nil, nil, nil, nil, err + } + + record := &recordingConn{ + Conn: tcpConn, + } + + return record, cmd, stdin, out, nil +} + +func (test *clientTest) dataPath() string { + return filepath.Join("testdata", "Client-"+test.name) +} + +func (test *clientTest) loadData() (flows [][]byte, err error) { + in, err := os.Open(test.dataPath()) + if err != nil { + return nil, err + } + defer in.Close() + return parseTestData(in) +} + +func (test *clientTest) run(t *testing.T, write bool) { + var clientConn, serverConn net.Conn + var recordingConn *recordingConn + var childProcess *exec.Cmd + var stdin opensslInput + var stdout *opensslOutputSink + + if write { + var err error + recordingConn, childProcess, stdin, stdout, err = test.connFromCommand() + if err != nil { + t.Fatalf("Failed to start subcommand: %s", err) + } + clientConn = recordingConn + defer func() { + if t.Failed() { + t.Logf("OpenSSL output:\n\n%s", stdout.all) + } + }() + } else { + clientConn, serverConn = localPipe(t) + } + + doneChan := make(chan bool) + defer func() { + clientConn.Close() + <-doneChan + }() + go func() { + defer close(doneChan) + + config := test.config + if config == nil { + config = testConfig + } + client := Client(clientConn, config) + defer client.Close() + + if _, err := client.Write([]byte("hello\n")); err != nil { + t.Errorf("Client.Write failed: %s", err) + return + } + + for i := 1; i <= test.numRenegotiations; i++ { + // The initial handshake will generate a + // handshakeComplete signal which needs to be quashed. + if i == 1 && write { + <-stdout.handshakeComplete + } + + // OpenSSL will try to interleave application data and + // a renegotiation if we send both concurrently. + // Therefore: ask OpensSSL to start a renegotiation, run + // a goroutine to call client.Read and thus process the + // renegotiation request, watch for OpenSSL's stdout to + // indicate that the handshake is complete and, + // finally, have OpenSSL write something to cause + // client.Read to complete. + if write { + stdin <- opensslRenegotiate + } + + signalChan := make(chan struct{}) + + go func() { + defer close(signalChan) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if test.checkRenegotiationError != nil { + newErr := test.checkRenegotiationError(i, err) + if err != nil && newErr == nil { + return + } + err = newErr + } + + if err != nil { + t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + + if expected := i + 1; client.handshakes != expected { + t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes) + } + }() + + if write && test.renegotiationExpectedToFail != i { + <-stdout.handshakeComplete + stdin <- opensslSendSentinel + } + <-signalChan + } + + if test.sendKeyUpdate { + if write { + <-stdout.handshakeComplete + stdin <- opensslKeyUpdate + } + + doneRead := make(chan struct{}) + + go func() { + defer close(doneRead) + + buf := make([]byte, 256) + n, err := client.Read(buf) + + if err != nil { + t.Errorf("Client.Read failed after KeyUpdate: %s", err) + return + } + + buf = buf[:n] + if !bytes.Equal([]byte(opensslSentinel), buf) { + t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel) + } + }() + + if write { + // There's no real reason to wait for the client KeyUpdate to + // send data with the new server keys, except that s_server + // drops writes if they are sent at the wrong time. + <-stdout.readKeyUpdate + stdin <- opensslSendSentinel + } + <-doneRead + + if _, err := client.Write([]byte("hello again\n")); err != nil { + t.Errorf("Client.Write failed: %s", err) + return + } + } + + if test.validate != nil { + if err := test.validate(client.ConnectionState()); err != nil { + t.Errorf("validate callback returned error: %s", err) + } + } + + // If the server sent us an alert after our last flight, give it a + // chance to arrive. + if write && test.renegotiationExpectedToFail == 0 { + if err := peekError(client); err != nil { + t.Errorf("final Read returned an error: %s", err) + } + } + }() + + if !write { + flows, err := test.loadData() + if err != nil { + t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err) + } + for i, b := range flows { + if i%2 == 1 { + if *fast { + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second)) + } else { + serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute)) + } + serverConn.Write(b) + continue + } + bb := make([]byte, len(b)) + if *fast { + serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)) + } else { + serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute)) + } + _, err := io.ReadFull(serverConn, bb) + if err != nil { + t.Fatalf("%s, flow %d: %s", test.name, i+1, err) + } + if !bytes.Equal(b, bb) { + t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) + } + } + } + + <-doneChan + if !write { + serverConn.Close() + } + + if write { + path := test.dataPath() + out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + t.Fatalf("Failed to create output file: %s", err) + } + defer out.Close() + recordingConn.Close() + close(stdin) + childProcess.Process.Kill() + childProcess.Wait() + if len(recordingConn.flows) < 3 { + t.Fatalf("Client connection didn't work") + } + recordingConn.WriteTo(out) + t.Logf("Wrote %s\n", path) + } +} + +// peekError does a read with a short timeout to check if the next read would +// cause an error, for example if there is an alert waiting on the wire. +func peekError(conn net.Conn) error { + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + if n, err := conn.Read(make([]byte, 1)); n != 0 { + return errors.New("unexpectedly read data") + } else if err != nil { + if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return err + } + } + return nil +} + +func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) { + // Make a deep copy of the template before going parallel. + test := *template + if template.config != nil { + test.config = template.config.Clone() + } + test.name = version + "-" + test.name + test.args = append([]string{option}, test.args...) + + runTestAndUpdateIfNeeded(t, version, test.run, false) +} + +func runClientTestTLS10(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv10", "-tls1") +} + +func runClientTestTLS11(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv11", "-tls1_1") +} + +func runClientTestTLS12(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv12", "-tls1_2") +} + +func runClientTestTLS13(t *testing.T, template *clientTest) { + runClientTestForVersion(t, template, "TLSv13", "-tls1_3") +} + +func TestHandshakeClientRSARC4(t *testing.T) { + test := &clientTest{ + name: "RSA-RC4", + args: []string{"-cipher", "RC4-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientRSAAES128GCM(t *testing.T) { + test := &clientTest{ + name: "AES128-GCM-SHA256", + args: []string{"-cipher", "AES128-GCM-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientRSAAES256GCM(t *testing.T) { + test := &clientTest{ + name: "AES256-GCM-SHA384", + args: []string{"-cipher", "AES256-GCM-SHA384"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHERSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-RSA-AES", + args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"}, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAES(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS10(t, test) + runClientTestTLS11(t, test) + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES-GCM", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES256GCMSHA384(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES256-GCM-SHA384", + args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "AES128-SHA256", + args: []string{"-cipher", "AES128-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "ECDHE-RSA-AES128-SHA256", + args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"}, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) { + test := &clientTest{ + name: "ECDHE-ECDSA-AES128-SHA256", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"}, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS12(t, test) +} + +func TestHandshakeClientX25519(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{X25519} + + test := &clientTest{ + name: "X25519-ECDHE", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientP256(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{CurveP256} + + test := &clientTest{ + name: "P256-ECDHE", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientHelloRetryRequest(t *testing.T) { + config := testConfig.Clone() + config.CurvePreferences = []CurveID{X25519, CurveP256} + + test := &clientTest{ + name: "HelloRetryRequest", + args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"}, + config: config, + } + + runClientTestTLS13(t, test) +} + +func TestHandshakeClientECDHERSAChaCha20(t *testing.T) { + config := testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305} + + test := &clientTest{ + name: "ECDHE-RSA-CHACHA20-POLY1305", + args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"}, + config: config, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) { + config := testConfig.Clone() + config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305} + + test := &clientTest{ + name: "ECDHE-ECDSA-CHACHA20-POLY1305", + args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientAES128SHA256(t *testing.T) { + test := &clientTest{ + name: "AES128-SHA256", + args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"}, + } + runClientTestTLS13(t, test) +} +func TestHandshakeClientAES256SHA384(t *testing.T) { + test := &clientTest{ + name: "AES256-SHA384", + args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"}, + } + runClientTestTLS13(t, test) +} +func TestHandshakeClientCHACHA20SHA256(t *testing.T) { + test := &clientTest{ + name: "CHACHA20-SHA256", + args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"}, + } + runClientTestTLS13(t, test) +} + +func TestHandshakeClientECDSATLS13(t *testing.T) { + test := &clientTest{ + name: "ECDSA", + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + runClientTestTLS13(t, test) +} + +func TestHandshakeClientEd25519(t *testing.T) { + test := &clientTest{ + name: "Ed25519", + cert: testEd25519Certificate, + key: testEd25519PrivateKey, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM)) + config.Certificates = []Certificate{cert} + + test = &clientTest{ + name: "ClientCert-Ed25519", + args: []string{"-Verify", "1"}, + config: config, + } + + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientCertRSA(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-RSA-RSA", + args: []string{"-cipher", "AES128", "-Verify", "1"}, + config: config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + + test = &clientTest{ + name: "ClientCert-RSA-ECDSA", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + test = &clientTest{ + name: "ClientCert-RSA-AES256-GCM-SHA384", + args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"}, + config: config, + cert: testRSACertificate, + key: testRSAPrivateKey, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientCertECDSA(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-ECDSA-RSA", + args: []string{"-cipher", "AES128", "-Verify", "1"}, + config: config, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) + + test = &clientTest{ + name: "ClientCert-ECDSA-ECDSA", + args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"}, + config: config, + cert: testECDSACertificate, + key: testECDSAPrivateKey, + } + + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) +} + +// TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both +// client and server certificates. It also serves from both sides a certificate +// signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation +// works. +func TestHandshakeClientCertRSAPSS(t *testing.T) { + cert, err := x509.ParseCertificate(testRSAPSSCertificate) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(cert) + + config := testConfig.Clone() + // Use GetClientCertificate to bypass the client certificate selection logic. + config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) { + return &Certificate{ + Certificate: [][]byte{testRSAPSSCertificate}, + PrivateKey: testRSAPrivateKey, + }, nil + } + config.RootCAs = rootCAs + + test := &clientTest{ + name: "ClientCert-RSA-RSAPSS", + args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", + "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"}, + config: config, + cert: testRSAPSSCertificate, + key: testRSAPrivateKey, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) { + config := testConfig.Clone() + cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM)) + config.Certificates = []Certificate{cert} + + test := &clientTest{ + name: "ClientCert-RSA-RSAPKCS1v15", + args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs", + "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"}, + config: config, + } + + runClientTestTLS12(t, test) +} + +func TestClientKeyUpdate(t *testing.T) { + test := &clientTest{ + name: "KeyUpdate", + args: []string{"-state"}, + sendKeyUpdate: true, + } + runClientTestTLS13(t, test) +} + +func TestResumption(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) }) +} + +func testResumption(t *testing.T, version uint16) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + serverConfig := &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + } + + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + clientConfig := &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + ClientSessionCache: NewLRUClientSessionCache(32), + RootCAs: rootCAs, + ServerName: "example.golang", + } + + testResumeState := func(test string, didResume bool) { + _, hs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", test, err) + } + if hs.DidResume != didResume { + t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume) + } + if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { + t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) + } + if got, want := hs.ServerName, clientConfig.ServerName; got != want { + t.Errorf("%s: server name %s, want %s", test, got, want) + } + } + + getTicket := func() []byte { + return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket + } + deleteTicket := func() { + ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey + clientConfig.ClientSessionCache.Put(ticketKey, nil) + } + corruptTicket := func() { + clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff + } + randomKey := func() [32]byte { + var k [32]byte + if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil { + t.Fatalf("Failed to read new SessionTicketKey: %s", err) + } + return k + } + + testResumeState("Handshake", false) + ticket := getTicket() + testResumeState("Resume", true) + if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 { + t.Fatal("first ticket doesn't match ticket after resumption") + } + if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 { + t.Fatal("ticket didn't change after resumption") + } + + // An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key. + serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } + testResumeState("ResumeWithOldTicket", true) + if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) { + t.Fatal("old first ticket matches the fresh one") + } + + // Now the session tickey key is expired, so a full handshake should occur. + serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } + testResumeState("ResumeWithExpiredTicket", false) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("expired first ticket matches the fresh one") + } + + serverConfig.Time = func() time.Time { return time.Now() } // reset the time back + key1 := randomKey() + serverConfig.SetSessionTicketKeys([][32]byte{key1}) + + testResumeState("InvalidSessionTicketKey", false) + testResumeState("ResumeAfterInvalidSessionTicketKey", true) + + key2 := randomKey() + serverConfig.SetSessionTicketKeys([][32]byte{key2, key1}) + ticket = getTicket() + testResumeState("KeyChange", true) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't included while resuming") + } + testResumeState("KeyChangeFinish", true) + + // Age the session ticket a bit, but not yet expired. + serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) } + testResumeState("OldSessionTicket", true) + ticket = getTicket() + // Expire the session ticket, which would force a full handshake. + serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) } + testResumeState("ExpiredSessionTicket", false) + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't provided after old ticket expired") + } + + // Age the session ticket a bit at a time, but don't expire it. + d := 0 * time.Hour + for i := 0; i < 13; i++ { + d += 12 * time.Hour + serverConfig.Time = func() time.Time { return time.Now().Add(d) } + testResumeState("OldSessionTicket", true) + } + // Expire it (now a little more than 7 days) and make sure a full + // handshake occurs for TLS 1.2. Resumption should still occur for + // TLS 1.3 since the client should be using a fresh ticket sent over + // by the server. + d += 12 * time.Hour + serverConfig.Time = func() time.Time { return time.Now().Add(d) } + if version == VersionTLS13 { + testResumeState("ExpiredSessionTicket", true) + } else { + testResumeState("ExpiredSessionTicket", false) + } + if bytes.Equal(ticket, getTicket()) { + t.Fatal("new ticket wasn't provided after old ticket expired") + } + + // Reset serverConfig to ensure that calling SetSessionTicketKeys + // before the serverConfig is used works. + serverConfig = &Config{ + MaxVersion: version, + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + Certificates: testConfig.Certificates, + } + serverConfig.SetSessionTicketKeys([][32]byte{key2}) + + testResumeState("FreshConfig", true) + + // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF + // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3. + if version != VersionTLS13 { + clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA} + testResumeState("DifferentCipherSuite", false) + testResumeState("DifferentCipherSuiteRecovers", true) + } + + deleteTicket() + testResumeState("WithoutSessionTicket", false) + + // Session resumption should work when using client certificates + deleteTicket() + serverConfig.ClientCAs = rootCAs + serverConfig.ClientAuth = RequireAndVerifyClientCert + clientConfig.Certificates = serverConfig.Certificates + testResumeState("InitialHandshake", false) + testResumeState("WithClientCertificates", true) + serverConfig.ClientAuth = NoClientCert + + // Tickets should be removed from the session cache on TLS handshake + // failure, and the client should recover from a corrupted PSK + testResumeState("FetchTicketToCorrupt", false) + corruptTicket() + _, _, err = testHandshake(t, clientConfig, serverConfig) + if err == nil { + t.Fatalf("handshake did not fail with a corrupted client secret") + } + testResumeState("AfterHandshakeFailure", false) + + clientConfig.ClientSessionCache = nil + testResumeState("WithoutSessionCache", false) +} + +func TestLRUClientSessionCache(t *testing.T) { + // Initialize cache of capacity 4. + cache := NewLRUClientSessionCache(4) + cs := make([]ClientSessionState, 6) + keys := []string{"0", "1", "2", "3", "4", "5", "6"} + + // Add 4 entries to the cache and look them up. + for i := 0; i < 4; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 4; i++ { + if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { + t.Fatalf("session cache failed lookup for added key: %s", keys[i]) + } + } + + // Add 2 more entries to the cache. First 2 should be evicted. + for i := 4; i < 6; i++ { + cache.Put(keys[i], &cs[i]) + } + for i := 0; i < 2; i++ { + if s, ok := cache.Get(keys[i]); ok || s != nil { + t.Fatalf("session cache should have evicted key: %s", keys[i]) + } + } + + // Touch entry 2. LRU should evict 3 next. + cache.Get(keys[2]) + cache.Put(keys[0], &cs[0]) + if s, ok := cache.Get(keys[3]); ok || s != nil { + t.Fatalf("session cache should have evicted key 3") + } + + // Update entry 0 in place. + cache.Put(keys[0], &cs[3]) + if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] { + t.Fatalf("session cache failed update for key 0") + } + + // Calling Put with a nil entry deletes the key. + cache.Put(keys[0], nil) + if _, ok := cache.Get(keys[0]); ok { + t.Fatalf("session cache failed to delete key 0") + } + + // Delete entry 2. LRU should keep 4 and 5 + cache.Put(keys[2], nil) + if _, ok := cache.Get(keys[2]); ok { + t.Fatalf("session cache failed to delete key 4") + } + for i := 4; i < 6; i++ { + if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] { + t.Fatalf("session cache should not have deleted key: %s", keys[i]) + } + } +} + +func TestKeyLogTLS12(t *testing.T) { + var serverBuf, clientBuf bytes.Buffer + + clientConfig := testConfig.Clone() + clientConfig.KeyLogWriter = &clientBuf + clientConfig.MaxVersion = VersionTLS12 + + serverConfig := testConfig.Clone() + serverConfig.KeyLogWriter = &serverBuf + serverConfig.MaxVersion = VersionTLS12 + + c, s := localPipe(t) + done := make(chan bool) + + go func() { + defer close(done) + + if err := Server(s, serverConfig).Handshake(); err != nil { + t.Errorf("server: %s", err) + return + } + s.Close() + }() + + if err := Client(c, clientConfig).Handshake(); err != nil { + t.Fatalf("client: %s", err) + } + + c.Close() + <-done + + checkKeylogLine := func(side, loggedLine string) { + if len(loggedLine) == 0 { + t.Fatalf("%s: no keylog line was produced", side) + } + const expectedLen = 13 /* "CLIENT_RANDOM" */ + + 1 /* space */ + + 32*2 /* hex client nonce */ + + 1 /* space */ + + 48*2 /* hex master secret */ + + 1 /* new line */ + if len(loggedLine) != expectedLen { + t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine) + } + if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") { + t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine) + } + } + + checkKeylogLine("client", clientBuf.String()) + checkKeylogLine("server", serverBuf.String()) +} + +func TestKeyLogTLS13(t *testing.T) { + var serverBuf, clientBuf bytes.Buffer + + clientConfig := testConfig.Clone() + clientConfig.KeyLogWriter = &clientBuf + + serverConfig := testConfig.Clone() + serverConfig.KeyLogWriter = &serverBuf + + c, s := localPipe(t) + done := make(chan bool) + + go func() { + defer close(done) + + if err := Server(s, serverConfig).Handshake(); err != nil { + t.Errorf("server: %s", err) + return + } + s.Close() + }() + + if err := Client(c, clientConfig).Handshake(); err != nil { + t.Fatalf("client: %s", err) + } + + c.Close() + <-done + + checkKeylogLines := func(side, loggedLines string) { + loggedLines = strings.TrimSpace(loggedLines) + lines := strings.Split(loggedLines, "\n") + if len(lines) != 4 { + t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines)) + } + } + + checkKeylogLines("client", clientBuf.String()) + checkKeylogLines("server", serverBuf.String()) +} + +func TestHandshakeClientALPNMatch(t *testing.T) { + config := testConfig.Clone() + config.NextProtos = []string{"proto2", "proto1"} + + test := &clientTest{ + name: "ALPN", + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -alpn flag. + args: []string{"-alpn", "proto1,proto2"}, + config: config, + validate: func(state ConnectionState) error { + // The server's preferences should override the client. + if state.NegotiatedProtocol != "proto1" { + return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol) + } + return nil + }, + } + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) { + // This checks that the server can't select an application protocol that the + // client didn't offer. + + c, s := localPipe(t) + errChan := make(chan error, 1) + + go func() { + client := Client(c, &Config{ + ServerName: "foo", + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + NextProtos: []string{"http", "something-else"}, + }) + errChan <- client.Handshake() + }() + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record); err != nil { + t.Fatal(err) + } + + serverHello := &serverHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256, + alpnProtocol: "how-about-this", + } + serverHelloBytes := mustMarshal(t, serverHello) + + s.Write([]byte{ + byte(recordTypeHandshake), + byte(VersionTLS12 >> 8), + byte(VersionTLS12 & 0xff), + byte(len(serverHelloBytes) >> 8), + byte(len(serverHelloBytes)), + }) + s.Write(serverHelloBytes) + s.Close() + + if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") { + t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) + } +} + +// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443` +const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0=" + +func TestHandshakClientSCTs(t *testing.T) { + config := testConfig.Clone() + + scts, err := base64.StdEncoding.DecodeString(sctsBase64) + if err != nil { + t.Fatal(err) + } + + // Note that this needs OpenSSL 1.0.2 because that is the first + // version that supports the -serverinfo flag. + test := &clientTest{ + name: "SCT", + config: config, + extensions: [][]byte{scts}, + validate: func(state ConnectionState) error { + expectedSCTs := [][]byte{ + scts[8:125], + scts[127:245], + scts[247:], + } + if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) { + return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs)) + } + for i, expected := range expectedSCTs { + if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) { + return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected) + } + } + return nil + }, + } + runClientTestTLS12(t, test) + + // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only + // supports ServerHello extensions. +} + +func TestRenegotiationRejected(t *testing.T) { + config := testConfig.Clone() + test := &clientTest{ + name: "RenegotiationRejected", + args: []string{"-state"}, + config: config, + numRenegotiations: 1, + renegotiationExpectedToFail: 1, + checkRenegotiationError: func(renegotiationNum int, err error) error { + if err == nil { + return errors.New("expected error from renegotiation but got nil") + } + if !strings.Contains(err.Error(), "no renegotiation") { + return fmt.Errorf("expected renegotiation to be rejected but got %q", err) + } + return nil + }, + } + runClientTestTLS12(t, test) +} + +func TestRenegotiateOnce(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateOnceAsClient + + test := &clientTest{ + name: "RenegotiateOnce", + args: []string{"-state"}, + config: config, + numRenegotiations: 1, + } + + runClientTestTLS12(t, test) +} + +func TestRenegotiateTwice(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateFreelyAsClient + + test := &clientTest{ + name: "RenegotiateTwice", + args: []string{"-state"}, + config: config, + numRenegotiations: 2, + } + + runClientTestTLS12(t, test) +} + +func TestRenegotiateTwiceRejected(t *testing.T) { + config := testConfig.Clone() + config.Renegotiation = RenegotiateOnceAsClient + + test := &clientTest{ + name: "RenegotiateTwiceRejected", + args: []string{"-state"}, + config: config, + numRenegotiations: 2, + renegotiationExpectedToFail: 2, + checkRenegotiationError: func(renegotiationNum int, err error) error { + if renegotiationNum == 1 { + return err + } + + if err == nil { + return errors.New("expected error from renegotiation but got nil") + } + if !strings.Contains(err.Error(), "no renegotiation") { + return fmt.Errorf("expected renegotiation to be rejected but got %q", err) + } + return nil + }, + } + + runClientTestTLS12(t, test) +} + +func TestHandshakeClientExportKeyingMaterial(t *testing.T) { + test := &clientTest{ + name: "ExportKeyingMaterial", + config: testConfig.Clone(), + validate: func(state ConnectionState) error { + if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil { + return fmt.Errorf("ExportKeyingMaterial failed: %v", err) + } else if len(km) != 42 { + return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42) + } + return nil + }, + } + runClientTestTLS10(t, test) + runClientTestTLS12(t, test) + runClientTestTLS13(t, test) +} + +var hostnameInSNITests = []struct { + in, out string +}{ + // Opaque string + {"", ""}, + {"localhost", "localhost"}, + {"foo, bar, baz and qux", "foo, bar, baz and qux"}, + + // DNS hostname + {"golang.org", "golang.org"}, + {"golang.org.", "golang.org"}, + + // Literal IPv4 address + {"1.2.3.4", ""}, + + // Literal IPv6 address + {"::1", ""}, + {"::1%lo0", ""}, // with zone identifier + {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal + {"[::1%lo0]", ""}, +} + +func TestHostnameInSNI(t *testing.T) { + for _, tt := range hostnameInSNITests { + c, s := localPipe(t) + + go func(host string) { + Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake() + }(tt.in) + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record[:]); err != nil { + t.Fatal(err) + } + + c.Close() + s.Close() + + var m clientHelloMsg + if !m.unmarshal(record) { + t.Errorf("unmarshaling ClientHello for %q failed", tt.in) + continue + } + if tt.in != tt.out && m.serverName == tt.in { + t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record) + } + if m.serverName != tt.out { + t.Errorf("expected %q not found in ClientHello: %x", tt.out, record) + } + } +} + +func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { + // This checks that the server can't select a cipher suite that the + // client didn't offer. See #13174. + + c, s := localPipe(t) + errChan := make(chan error, 1) + + go func() { + client := Client(c, &Config{ + ServerName: "foo", + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + }) + errChan <- client.Handshake() + }() + + var header [5]byte + if _, err := io.ReadFull(s, header[:]); err != nil { + t.Fatal(err) + } + recordLen := int(header[3])<<8 | int(header[4]) + + record := make([]byte, recordLen) + if _, err := io.ReadFull(s, record); err != nil { + t.Fatal(err) + } + + // Create a ServerHello that selects a different cipher suite than the + // sole one that the client offered. + serverHello := &serverHelloMsg{ + vers: VersionTLS12, + random: make([]byte, 32), + cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, + } + serverHelloBytes := mustMarshal(t, serverHello) + + s.Write([]byte{ + byte(recordTypeHandshake), + byte(VersionTLS12 >> 8), + byte(VersionTLS12 & 0xff), + byte(len(serverHelloBytes) >> 8), + byte(len(serverHelloBytes)), + }) + s.Write(serverHelloBytes) + s.Close() + + if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") { + t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) + } +} + +func TestVerifyConnection(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) }) +} + +func testVerifyConnection(t *testing.T, version uint16) { + checkFields := func(c ConnectionState, called *int, errorType string) error { + if c.Version != version { + return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version) + } + if c.HandshakeComplete { + return fmt.Errorf("%s: got HandshakeComplete, want false", errorType) + } + if c.ServerName != "example.golang" { + return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang") + } + if c.NegotiatedProtocol != "protocol1" { + return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1") + } + if c.CipherSuite == 0 { + return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType) + } + wantDidResume := false + if *called == 2 { // if this is the second time, then it should be a resumption + wantDidResume = true + } + if c.DidResume != wantDidResume { + return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume) + } + return nil + } + + tests := []struct { + name string + configureServer func(*Config, *int) + configureClient func(*Config, *int) + }{ + { + name: "RequireAndVerifyClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAndVerifyClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") + } + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Response are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + { + name: "InsecureSkipVerify", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequireAnyClientCert + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.InsecureSkipVerify = true + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if c.VerifiedChains != nil { + return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) + } + if c.DidResume { + return nil + // The SCTs and OCSP Response are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + { + name: "NoClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = NoClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "client") + } + }, + }, + { + name: "RequestClientCert", + configureServer: func(config *Config, called *int) { + config.ClientAuth = RequestClientCert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + return checkFields(c, called, "server") + } + }, + configureClient: func(config *Config, called *int) { + config.Certificates = nil // clear the client cert + config.VerifyConnection = func(c ConnectionState) error { + *called++ + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero") + } + if c.DidResume { + return nil + // The SCTs and OCSP Response are dropped on resumption. + // See http://golang.org/issue/39075. + } + if len(c.OCSPResponse) == 0 { + return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero") + } + if len(c.SignedCertificateTimestamps) == 0 { + return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") + } + return checkFields(c, called, "client") + } + }, + }, + } + for _, test := range tests { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + var serverCalled, clientCalled int + + serverConfig := &Config{ + MaxVersion: version, + Certificates: []Certificate{testConfig.Certificates[0]}, + ClientCAs: rootCAs, + NextProtos: []string{"protocol1"}, + } + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp") + test.configureServer(serverConfig, &serverCalled) + + clientConfig := &Config{ + MaxVersion: version, + ClientSessionCache: NewLRUClientSessionCache(32), + RootCAs: rootCAs, + ServerName: "example.golang", + Certificates: []Certificate{testConfig.Certificates[0]}, + NextProtos: []string{"protocol1"}, + } + test.configureClient(clientConfig, &clientCalled) + + testHandshakeState := func(name string, didResume bool) { + _, hs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("%s: handshake failed: %s", name, err) + } + if hs.DidResume != didResume { + t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume) + } + wantCalled := 1 + if didResume { + wantCalled = 2 // resumption would mean this is the second time it was called in this test + } + if clientCalled != wantCalled { + t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled) + } + if serverCalled != wantCalled { + t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled) + } + } + testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false) + testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true) + } +} + +func TestVerifyPeerCertificate(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) }) +} + +func testVerifyPeerCertificate(t *testing.T, version uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + now := func() time.Time { return time.Unix(1476984729, 0) } + + sentinelErr := errors.New("TestVerifyPeerCertificate") + + verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + if len(validatedChains) == 0 { + return errors.New("got len(validatedChains) = 0, wanted non-zero") + } + *called = true + return nil + } + verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error { + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero") + } + if isClient && len(c.OCSPResponse) == 0 { + return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero") + } + *called = true + return nil + } + + tests := []struct { + configureServer func(*Config, *bool) + configureClient func(*Config, *bool) + validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) + }{ + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = true + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + // With InsecureSkipVerify set, this + // callback should still be called but + // validatedChains must be empty. + if l := len(validatedChains); l != 0 { + return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l) + } + *called = true + return nil + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, false, c) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, true, c) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, + } + + for i, test := range tests { + c, s := localPipe(t) + done := make(chan error) + + var clientCalled, serverCalled bool + + go func() { + config := testConfig.Clone() + config.ServerName = "example.golang" + config.ClientAuth = RequireAndVerifyClientCert + config.ClientCAs = rootCAs + config.Time = now + config.MaxVersion = version + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testRSACertificate} + config.Certificates[0].PrivateKey = testRSAPrivateKey + config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + config.Certificates[0].OCSPStaple = []byte("dummy ocsp") + test.configureServer(config, &serverCalled) + + err = Server(s, config).Handshake() + s.Close() + done <- err + }() + + config := testConfig.Clone() + config.ServerName = "example.golang" + config.RootCAs = rootCAs + config.Time = now + config.MaxVersion = version + test.configureClient(config, &clientCalled) + clientErr := Client(c, config).Handshake() + c.Close() + serverErr := <-done + + test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr) + } +} + +// brokenConn wraps a net.Conn and causes all Writes after a certain number to +// fail with brokenConnErr. +type brokenConn struct { + net.Conn + + // breakAfter is the number of successful writes that will be allowed + // before all subsequent writes fail. + breakAfter int + + // numWrites is the number of writes that have been done. + numWrites int +} + +// brokenConnErr is the error that brokenConn returns once exhausted. +var brokenConnErr = errors.New("too many writes to brokenConn") + +func (b *brokenConn) Write(data []byte) (int, error) { + if b.numWrites >= b.breakAfter { + return 0, brokenConnErr + } + + b.numWrites++ + return b.Conn.Write(data) +} + +func TestFailedWrite(t *testing.T) { + // Test that a write error during the handshake is returned. + for _, breakAfter := range []int{0, 1} { + c, s := localPipe(t) + done := make(chan bool) + + go func() { + Server(s, testConfig).Handshake() + s.Close() + done <- true + }() + + brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} + err := Client(brokenC, testConfig).Handshake() + if err != brokenConnErr { + t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) + } + brokenC.Close() + + <-done + } +} + +// writeCountingConn wraps a net.Conn and counts the number of Write calls. +type writeCountingConn struct { + net.Conn + + // numWrites is the number of writes that have been done. + numWrites int +} + +func (wcc *writeCountingConn) Write(data []byte) (int, error) { + wcc.numWrites++ + return wcc.Conn.Write(data) +} + +func TestBuffering(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) }) +} + +func testBuffering(t *testing.T, version uint16) { + c, s := localPipe(t) + done := make(chan bool) + + clientWCC := &writeCountingConn{Conn: c} + serverWCC := &writeCountingConn{Conn: s} + + go func() { + config := testConfig.Clone() + config.MaxVersion = version + Server(serverWCC, config).Handshake() + serverWCC.Close() + done <- true + }() + + err := Client(clientWCC, testConfig).Handshake() + if err != nil { + t.Fatal(err) + } + clientWCC.Close() + <-done + + var expectedClient, expectedServer int + if version == VersionTLS13 { + expectedClient = 2 + expectedServer = 1 + } else { + expectedClient = 2 + expectedServer = 2 + } + + if n := clientWCC.numWrites; n != expectedClient { + t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n) + } + + if n := serverWCC.numWrites; n != expectedServer { + t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n) + } +} + +func TestAlertFlushing(t *testing.T) { + c, s := localPipe(t) + done := make(chan bool) + + clientWCC := &writeCountingConn{Conn: c} + serverWCC := &writeCountingConn{Conn: s} + + serverConfig := testConfig.Clone() + + // Cause a signature-time error + brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey} + brokenKey.D = big.NewInt(42) + serverConfig.Certificates = []Certificate{{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: &brokenKey, + }} + + go func() { + Server(serverWCC, serverConfig).Handshake() + serverWCC.Close() + done <- true + }() + + err := Client(clientWCC, testConfig).Handshake() + if err == nil { + t.Fatal("client unexpectedly returned no error") + } + + const expectedError = "remote error: tls: internal error" + if e := err.Error(); !strings.Contains(e, expectedError) { + t.Fatalf("expected to find %q in error but error was %q", expectedError, e) + } + clientWCC.Close() + <-done + + if n := serverWCC.numWrites; n != 1 { + t.Errorf("expected server handshake to complete with one write, but saw %d", n) + } +} + +func TestHandshakeRace(t *testing.T) { + if testing.Short() { + t.Skip("skipping in -short mode") + } + t.Parallel() + // This test races a Read and Write to try and complete a handshake in + // order to provide some evidence that there are no races or deadlocks + // in the handshake locking. + for i := 0; i < 32; i++ { + c, s := localPipe(t) + + go func() { + server := Server(s, testConfig) + if err := server.Handshake(); err != nil { + panic(err) + } + + var request [1]byte + if n, err := server.Read(request[:]); err != nil || n != 1 { + panic(err) + } + + server.Write(request[:]) + server.Close() + }() + + startWrite := make(chan struct{}) + startRead := make(chan struct{}) + readDone := make(chan struct{}, 1) + + client := Client(c, testConfig) + go func() { + <-startWrite + var request [1]byte + client.Write(request[:]) + }() + + go func() { + <-startRead + var reply [1]byte + if _, err := io.ReadFull(client, reply[:]); err != nil { + panic(err) + } + c.Close() + readDone <- struct{}{} + }() + + if i&1 == 1 { + startWrite <- struct{}{} + startRead <- struct{}{} + } else { + startRead <- struct{}{} + startWrite <- struct{}{} + } + <-readDone + } +} + +var getClientCertificateTests = []struct { + setup func(*Config, *Config) + expectedClientError string + verify func(*testing.T, int, *ConnectionState) +}{ + { + func(clientConfig, serverConfig *Config) { + // Returning a Certificate with no certificate data + // should result in an empty message being sent to the + // server. + serverConfig.ClientCAs = nil + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + if len(cri.AcceptableCAs) != 0 { + panic("AcceptableCAs should have been empty") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig, serverConfig *Config) { + // With TLS 1.1, the SignatureSchemes should be + // synthesised from the supported certificate types. + clientConfig.MaxVersion = VersionTLS11 + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig, serverConfig *Config) { + // Returning an error should abort the handshake with + // that error. + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + return nil, errors.New("GetClientCertificate") + } + }, + "GetClientCertificate", + func(t *testing.T, testNum int, cs *ConnectionState) { + }, + }, + { + func(clientConfig, serverConfig *Config) { + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.AcceptableCAs) == 0 { + panic("empty AcceptableCAs") + } + cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + return cert, nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if len(cs.VerifiedChains) == 0 { + t.Errorf("#%d: expected some verified chains, but found none", testNum) + } + }, + }, +} + +func TestGetClientCertificate(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) }) +} + +func testGetClientCertificate(t *testing.T, version uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + for i, test := range getClientCertificateTests { + serverConfig := testConfig.Clone() + serverConfig.ClientAuth = VerifyClientCertIfGiven + serverConfig.RootCAs = x509.NewCertPool() + serverConfig.RootCAs.AddCert(issuer) + serverConfig.ClientCAs = serverConfig.RootCAs + serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } + serverConfig.MaxVersion = version + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = version + + test.setup(clientConfig, serverConfig) + + type serverResult struct { + cs ConnectionState + err error + } + + c, s := localPipe(t) + done := make(chan serverResult) + + go func() { + defer s.Close() + server := Server(s, serverConfig) + err := server.Handshake() + + var cs ConnectionState + if err == nil { + cs = server.ConnectionState() + } + done <- serverResult{cs, err} + }() + + clientErr := Client(c, clientConfig).Handshake() + c.Close() + + result := <-done + + if clientErr != nil { + if len(test.expectedClientError) == 0 { + t.Errorf("#%d: client error: %v", i, clientErr) + } else if got := clientErr.Error(); got != test.expectedClientError { + t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got) + } else { + test.verify(t, i, &result.cs) + } + } else if len(test.expectedClientError) > 0 { + t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError) + } else if err := result.err; err != nil { + t.Errorf("#%d: server error: %v", i, err) + } else { + test.verify(t, i, &result.cs) + } + } +} + +func TestRSAPSSKeyError(t *testing.T) { + // crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for + // public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with + // the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't + // parse, or that they don't carry *rsa.PublicKey keys. + b, _ := pem.Decode([]byte(` +-----BEGIN CERTIFICATE----- +MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK +MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC +AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3 +MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP +ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z +/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5 +b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL +QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou +czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT +JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz +AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn +OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME +AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab +sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z +H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1 +KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ +bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD +HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi +RwBA9Xk1KBNF +-----END CERTIFICATE-----`)) + if b == nil { + t.Fatal("Failed to decode certificate") + } + cert, err := x509.ParseCertificate(b.Bytes) + if err != nil { + return + } + if _, ok := cert.PublicKey.(*rsa.PublicKey); ok { + t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms") + } +} + +func TestCloseClientConnectionOnIdleServer(t *testing.T) { + clientConn, serverConn := localPipe(t) + client := Client(clientConn, testConfig.Clone()) + go func() { + var b [1]byte + serverConn.Read(b[:]) + client.Close() + }() + client.SetWriteDeadline(time.Now().Add(time.Minute)) + err := client.Handshake() + if err != nil { + if err, ok := err.(net.Error); ok && err.Timeout() { + t.Errorf("Expected a closed network connection error but got '%s'", err.Error()) + } + } else { + t.Errorf("Error expected, but no error returned") + } +} + +func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error { + defer func() { testingOnlyForceDowngradeCanary = false }() + testingOnlyForceDowngradeCanary = true + + clientConfig := testConfig.Clone() + clientConfig.MaxVersion = clientVersion + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = serverVersion + _, _, err := testHandshake(t, clientConfig, serverConfig) + return err +} + +func TestDowngradeCanary(t *testing.T) { + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected") + } + if testing.Short() { + t.Skip("skipping the rest of the checks in short mode") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil { + t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil { + t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil { + t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected") + } + if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil { + t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3") + } + if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil { + t.Errorf("client didn't ignore expected TLS 1.2 canary") + } + if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil { + t.Errorf("client unexpectedly reacted to a canary in TLS 1.1") + } + if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil { + t.Errorf("client unexpectedly reacted to a canary in TLS 1.0") + } +} + +func TestResumptionKeepsOCSPAndSCT(t *testing.T) { + t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) }) + t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) }) +} + +func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + t.Fatalf("failed to parse test issuer") + } + roots := x509.NewCertPool() + roots.AddCert(issuer) + clientConfig := &Config{ + MaxVersion: ver, + ClientSessionCache: NewLRUClientSessionCache(32), + ServerName: "example.golang", + RootCAs: roots, + } + serverConfig := testConfig.Clone() + serverConfig.MaxVersion = ver + serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3} + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}} + + _, ccs, err := testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + // after a new session we expect to see OCSPResponse and + // SignedCertificateTimestamps populated as usual + if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { + t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v", + serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v", + serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) + } + + // if the server doesn't send any SCTs, repopulate the old SCTs + oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps + serverConfig.Certificates[0].SignedCertificateTimestamps = nil + _, ccs, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !ccs.DidResume { + t.Fatalf("expected session to be resumed") + } + // after a resumed session we also expect to see OCSPResponse + // and SignedCertificateTimestamps populated + if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) { + t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v", + serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse) + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", + oldSCTs, ccs.SignedCertificateTimestamps) + } + + // Only test overriding the SCTs for TLS 1.2, since in 1.3 + // the server won't send the message containing them + if ver == VersionTLS13 { + return + } + + // if the server changes the SCTs it sends, they should override the saved SCTs + serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}} + _, ccs, err = testHandshake(t, clientConfig, serverConfig) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + if !ccs.DidResume { + t.Fatalf("expected session to be resumed") + } + if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) { + t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v", + serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps) + } +} + +// TestClientHandshakeContextCancellation tests that canceling +// the context given to the client side conn.HandshakeContext +// interrupts the in-progress handshake. +func TestClientHandshakeContextCancellation(t *testing.T) { + c, s := localPipe(t) + ctx, cancel := context.WithCancel(context.Background()) + unblockServer := make(chan struct{}) + defer close(unblockServer) + go func() { + cancel() + <-unblockServer + _ = s.Close() + }() + cli := Client(c, testConfig) + // Initiates client side handshake, which will block until the client hello is read + // by the server, unless the cancellation works. + err := cli.HandshakeContext(ctx) + if err == nil { + t.Fatal("Client handshake did not error when the context was canceled") + } + if err != context.Canceled { + t.Errorf("Unexpected client handshake error: %v", err) + } + if runtime.GOARCH == "wasm" { + t.Skip("conn.Close does not error as expected when called multiple times on WASM") + } + err = cli.Close() + if err == nil { + t.Error("Client connection was not closed when the context was canceled") + } +} |