summaryrefslogtreecommitdiffstats
path: root/src/crypto/tls/handshake_client_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/tls/handshake_client_test.go')
-rw-r--r--src/crypto/tls/handshake_client_test.go2597
1 files changed, 2597 insertions, 0 deletions
diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go
new file mode 100644
index 0000000..749c9fc
--- /dev/null
+++ b/src/crypto/tls/handshake_client_test.go
@@ -0,0 +1,2597 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+ "bytes"
+ "context"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
+ "math/big"
+ "net"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+// Note: see comment in handshake_test.go for details of how the reference
+// tests work.
+
+// opensslInputEvent enumerates possible inputs that can be sent to an `openssl
+// s_client` process.
+type opensslInputEvent int
+
+const (
+ // opensslRenegotiate causes OpenSSL to request a renegotiation of the
+ // connection.
+ opensslRenegotiate opensslInputEvent = iota
+
+ // opensslSendBanner causes OpenSSL to send the contents of
+ // opensslSentinel on the connection.
+ opensslSendSentinel
+
+ // opensslKeyUpdate causes OpenSSL to send a key update message to the
+ // client and request one back.
+ opensslKeyUpdate
+)
+
+const opensslSentinel = "SENTINEL\n"
+
+type opensslInput chan opensslInputEvent
+
+func (i opensslInput) Read(buf []byte) (n int, err error) {
+ for event := range i {
+ switch event {
+ case opensslRenegotiate:
+ return copy(buf, []byte("R\n")), nil
+ case opensslKeyUpdate:
+ return copy(buf, []byte("K\n")), nil
+ case opensslSendSentinel:
+ return copy(buf, []byte(opensslSentinel)), nil
+ default:
+ panic("unknown event")
+ }
+ }
+
+ return 0, io.EOF
+}
+
+// opensslOutputSink is an io.Writer that receives the stdout and stderr from an
+// `openssl` process and sends a value to handshakeComplete or readKeyUpdate
+// when certain messages are seen.
+type opensslOutputSink struct {
+ handshakeComplete chan struct{}
+ readKeyUpdate chan struct{}
+ all []byte
+ line []byte
+}
+
+func newOpensslOutputSink() *opensslOutputSink {
+ return &opensslOutputSink{make(chan struct{}), make(chan struct{}), nil, nil}
+}
+
+// opensslEndOfHandshake is a message that the “openssl s_server” tool will
+// print when a handshake completes if run with “-state”.
+const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"
+
+// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
+// print when a KeyUpdate message is received if run with “-state”.
+const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"
+
+func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
+ o.line = append(o.line, data...)
+ o.all = append(o.all, data...)
+
+ for {
+ line, next, ok := bytes.Cut(o.line, []byte("\n"))
+ if !ok {
+ break
+ }
+
+ if bytes.Equal([]byte(opensslEndOfHandshake), line) {
+ o.handshakeComplete <- struct{}{}
+ }
+ if bytes.Equal([]byte(opensslReadKeyUpdate), line) {
+ o.readKeyUpdate <- struct{}{}
+ }
+ o.line = next
+ }
+
+ return len(data), nil
+}
+
+func (o *opensslOutputSink) String() string {
+ return string(o.all)
+}
+
+// clientTest represents a test of the TLS client handshake against a reference
+// implementation.
+type clientTest struct {
+ // name is a freeform string identifying the test and the file in which
+ // the expected results will be stored.
+ name string
+ // args, if not empty, contains a series of arguments for the
+ // command to run for the reference server.
+ args []string
+ // config, if not nil, contains a custom Config to use for this test.
+ config *Config
+ // cert, if not empty, contains a DER-encoded certificate for the
+ // reference server.
+ cert []byte
+ // key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
+ // *ecdsa.PrivateKey which is the private key for the reference server.
+ key any
+ // extensions, if not nil, contains a list of extension data to be returned
+ // from the ServerHello. The data should be in standard TLS format with
+ // a 2-byte uint16 type, 2-byte data length, followed by the extension data.
+ extensions [][]byte
+ // validate, if not nil, is a function that will be called with the
+ // ConnectionState of the resulting connection. It returns a non-nil
+ // error if the ConnectionState is unacceptable.
+ validate func(ConnectionState) error
+ // numRenegotiations is the number of times that the connection will be
+ // renegotiated.
+ numRenegotiations int
+ // renegotiationExpectedToFail, if not zero, is the number of the
+ // renegotiation attempt that is expected to fail.
+ renegotiationExpectedToFail int
+ // checkRenegotiationError, if not nil, is called with any error
+ // arising from renegotiation. It can map expected errors to nil to
+ // ignore them.
+ checkRenegotiationError func(renegotiationNum int, err error) error
+ // sendKeyUpdate will cause the server to send a KeyUpdate message.
+ sendKeyUpdate bool
+}
+
+var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
+
+// connFromCommand starts the reference server process, connects to it and
+// returns a recordingConn for the connection. The stdin return value is an
+// opensslInput for the stdin of the child process. It must be closed before
+// Waiting for child.
+func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
+ cert := testRSACertificate
+ if len(test.cert) > 0 {
+ cert = test.cert
+ }
+ certPath := tempFile(string(cert))
+ defer os.Remove(certPath)
+
+ var key any = testRSAPrivateKey
+ if test.key != nil {
+ key = test.key
+ }
+ derBytes, err := x509.MarshalPKCS8PrivateKey(key)
+ if err != nil {
+ panic(err)
+ }
+
+ var pemOut bytes.Buffer
+ pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
+
+ keyPath := tempFile(pemOut.String())
+ defer os.Remove(keyPath)
+
+ var command []string
+ command = append(command, serverCommand...)
+ command = append(command, test.args...)
+ command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
+ // serverPort contains the port that OpenSSL will listen on. OpenSSL
+ // can't take "0" as an argument here so we have to pick a number and
+ // hope that it's not in use on the machine. Since this only occurs
+ // when -update is given and thus when there's a human watching the
+ // test, this isn't too bad.
+ const serverPort = 24323
+ command = append(command, "-accept", strconv.Itoa(serverPort))
+
+ if len(test.extensions) > 0 {
+ var serverInfo bytes.Buffer
+ for _, ext := range test.extensions {
+ pem.Encode(&serverInfo, &pem.Block{
+ Type: fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
+ Bytes: ext,
+ })
+ }
+ serverInfoPath := tempFile(serverInfo.String())
+ defer os.Remove(serverInfoPath)
+ command = append(command, "-serverinfo", serverInfoPath)
+ }
+
+ if test.numRenegotiations > 0 || test.sendKeyUpdate {
+ found := false
+ for _, flag := range command[1:] {
+ if flag == "-state" {
+ found = true
+ break
+ }
+ }
+
+ if !found {
+ panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
+ }
+ }
+
+ cmd := exec.Command(command[0], command[1:]...)
+ stdin = opensslInput(make(chan opensslInputEvent))
+ cmd.Stdin = stdin
+ out := newOpensslOutputSink()
+ cmd.Stdout = out
+ cmd.Stderr = out
+ if err := cmd.Start(); err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ // OpenSSL does print an "ACCEPT" banner, but it does so *before*
+ // opening the listening socket, so we can't use that to wait until it
+ // has started listening. Thus we are forced to poll until we get a
+ // connection.
+ var tcpConn net.Conn
+ for i := uint(0); i < 5; i++ {
+ tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
+ IP: net.IPv4(127, 0, 0, 1),
+ Port: serverPort,
+ })
+ if err == nil {
+ break
+ }
+ time.Sleep((1 << i) * 5 * time.Millisecond)
+ }
+ if err != nil {
+ close(stdin)
+ cmd.Process.Kill()
+ err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
+ return nil, nil, nil, nil, err
+ }
+
+ record := &recordingConn{
+ Conn: tcpConn,
+ }
+
+ return record, cmd, stdin, out, nil
+}
+
+func (test *clientTest) dataPath() string {
+ return filepath.Join("testdata", "Client-"+test.name)
+}
+
+func (test *clientTest) loadData() (flows [][]byte, err error) {
+ in, err := os.Open(test.dataPath())
+ if err != nil {
+ return nil, err
+ }
+ defer in.Close()
+ return parseTestData(in)
+}
+
+func (test *clientTest) run(t *testing.T, write bool) {
+ var clientConn, serverConn net.Conn
+ var recordingConn *recordingConn
+ var childProcess *exec.Cmd
+ var stdin opensslInput
+ var stdout *opensslOutputSink
+
+ if write {
+ var err error
+ recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
+ if err != nil {
+ t.Fatalf("Failed to start subcommand: %s", err)
+ }
+ clientConn = recordingConn
+ defer func() {
+ if t.Failed() {
+ t.Logf("OpenSSL output:\n\n%s", stdout.all)
+ }
+ }()
+ } else {
+ clientConn, serverConn = localPipe(t)
+ }
+
+ doneChan := make(chan bool)
+ defer func() {
+ clientConn.Close()
+ <-doneChan
+ }()
+ go func() {
+ defer close(doneChan)
+
+ config := test.config
+ if config == nil {
+ config = testConfig
+ }
+ client := Client(clientConn, config)
+ defer client.Close()
+
+ if _, err := client.Write([]byte("hello\n")); err != nil {
+ t.Errorf("Client.Write failed: %s", err)
+ return
+ }
+
+ for i := 1; i <= test.numRenegotiations; i++ {
+ // The initial handshake will generate a
+ // handshakeComplete signal which needs to be quashed.
+ if i == 1 && write {
+ <-stdout.handshakeComplete
+ }
+
+ // OpenSSL will try to interleave application data and
+ // a renegotiation if we send both concurrently.
+ // Therefore: ask OpensSSL to start a renegotiation, run
+ // a goroutine to call client.Read and thus process the
+ // renegotiation request, watch for OpenSSL's stdout to
+ // indicate that the handshake is complete and,
+ // finally, have OpenSSL write something to cause
+ // client.Read to complete.
+ if write {
+ stdin <- opensslRenegotiate
+ }
+
+ signalChan := make(chan struct{})
+
+ go func() {
+ defer close(signalChan)
+
+ buf := make([]byte, 256)
+ n, err := client.Read(buf)
+
+ if test.checkRenegotiationError != nil {
+ newErr := test.checkRenegotiationError(i, err)
+ if err != nil && newErr == nil {
+ return
+ }
+ err = newErr
+ }
+
+ if err != nil {
+ t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
+ return
+ }
+
+ buf = buf[:n]
+ if !bytes.Equal([]byte(opensslSentinel), buf) {
+ t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
+ }
+
+ if expected := i + 1; client.handshakes != expected {
+ t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
+ }
+ }()
+
+ if write && test.renegotiationExpectedToFail != i {
+ <-stdout.handshakeComplete
+ stdin <- opensslSendSentinel
+ }
+ <-signalChan
+ }
+
+ if test.sendKeyUpdate {
+ if write {
+ <-stdout.handshakeComplete
+ stdin <- opensslKeyUpdate
+ }
+
+ doneRead := make(chan struct{})
+
+ go func() {
+ defer close(doneRead)
+
+ buf := make([]byte, 256)
+ n, err := client.Read(buf)
+
+ if err != nil {
+ t.Errorf("Client.Read failed after KeyUpdate: %s", err)
+ return
+ }
+
+ buf = buf[:n]
+ if !bytes.Equal([]byte(opensslSentinel), buf) {
+ t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
+ }
+ }()
+
+ if write {
+ // There's no real reason to wait for the client KeyUpdate to
+ // send data with the new server keys, except that s_server
+ // drops writes if they are sent at the wrong time.
+ <-stdout.readKeyUpdate
+ stdin <- opensslSendSentinel
+ }
+ <-doneRead
+
+ if _, err := client.Write([]byte("hello again\n")); err != nil {
+ t.Errorf("Client.Write failed: %s", err)
+ return
+ }
+ }
+
+ if test.validate != nil {
+ if err := test.validate(client.ConnectionState()); err != nil {
+ t.Errorf("validate callback returned error: %s", err)
+ }
+ }
+
+ // If the server sent us an alert after our last flight, give it a
+ // chance to arrive.
+ if write && test.renegotiationExpectedToFail == 0 {
+ if err := peekError(client); err != nil {
+ t.Errorf("final Read returned an error: %s", err)
+ }
+ }
+ }()
+
+ if !write {
+ flows, err := test.loadData()
+ if err != nil {
+ t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
+ }
+ for i, b := range flows {
+ if i%2 == 1 {
+ if *fast {
+ serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
+ } else {
+ serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
+ }
+ serverConn.Write(b)
+ continue
+ }
+ bb := make([]byte, len(b))
+ if *fast {
+ serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
+ } else {
+ serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
+ }
+ _, err := io.ReadFull(serverConn, bb)
+ if err != nil {
+ t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
+ }
+ if !bytes.Equal(b, bb) {
+ t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
+ }
+ }
+ }
+
+ <-doneChan
+ if !write {
+ serverConn.Close()
+ }
+
+ if write {
+ path := test.dataPath()
+ out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
+ if err != nil {
+ t.Fatalf("Failed to create output file: %s", err)
+ }
+ defer out.Close()
+ recordingConn.Close()
+ close(stdin)
+ childProcess.Process.Kill()
+ childProcess.Wait()
+ if len(recordingConn.flows) < 3 {
+ t.Fatalf("Client connection didn't work")
+ }
+ recordingConn.WriteTo(out)
+ t.Logf("Wrote %s\n", path)
+ }
+}
+
+// peekError does a read with a short timeout to check if the next read would
+// cause an error, for example if there is an alert waiting on the wire.
+func peekError(conn net.Conn) error {
+ conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ if n, err := conn.Read(make([]byte, 1)); n != 0 {
+ return errors.New("unexpectedly read data")
+ } else if err != nil {
+ if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return err
+ }
+ }
+ return nil
+}
+
+func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
+ // Make a deep copy of the template before going parallel.
+ test := *template
+ if template.config != nil {
+ test.config = template.config.Clone()
+ }
+ test.name = version + "-" + test.name
+ test.args = append([]string{option}, test.args...)
+
+ runTestAndUpdateIfNeeded(t, version, test.run, false)
+}
+
+func runClientTestTLS10(t *testing.T, template *clientTest) {
+ runClientTestForVersion(t, template, "TLSv10", "-tls1")
+}
+
+func runClientTestTLS11(t *testing.T, template *clientTest) {
+ runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
+}
+
+func runClientTestTLS12(t *testing.T, template *clientTest) {
+ runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
+}
+
+func runClientTestTLS13(t *testing.T, template *clientTest) {
+ runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
+}
+
+func TestHandshakeClientRSARC4(t *testing.T) {
+ test := &clientTest{
+ name: "RSA-RC4",
+ args: []string{"-cipher", "RC4-SHA"},
+ }
+ runClientTestTLS10(t, test)
+ runClientTestTLS11(t, test)
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientRSAAES128GCM(t *testing.T) {
+ test := &clientTest{
+ name: "AES128-GCM-SHA256",
+ args: []string{"-cipher", "AES128-GCM-SHA256"},
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientRSAAES256GCM(t *testing.T) {
+ test := &clientTest{
+ name: "AES256-GCM-SHA384",
+ args: []string{"-cipher", "AES256-GCM-SHA384"},
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHERSAAES(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-RSA-AES",
+ args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
+ }
+ runClientTestTLS10(t, test)
+ runClientTestTLS11(t, test)
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-ECDSA-AES",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+ runClientTestTLS10(t, test)
+ runClientTestTLS11(t, test)
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-ECDSA-AES-GCM",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-ECDSA-AES256-GCM-SHA384",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
+ test := &clientTest{
+ name: "AES128-SHA256",
+ args: []string{"-cipher", "AES128-SHA256"},
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-RSA-AES128-SHA256",
+ args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
+ test := &clientTest{
+ name: "ECDHE-ECDSA-AES128-SHA256",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientX25519(t *testing.T) {
+ config := testConfig.Clone()
+ config.CurvePreferences = []CurveID{X25519}
+
+ test := &clientTest{
+ name: "X25519-ECDHE",
+ args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
+ config: config,
+ }
+
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientP256(t *testing.T) {
+ config := testConfig.Clone()
+ config.CurvePreferences = []CurveID{CurveP256}
+
+ test := &clientTest{
+ name: "P256-ECDHE",
+ args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
+ config: config,
+ }
+
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientHelloRetryRequest(t *testing.T) {
+ config := testConfig.Clone()
+ config.CurvePreferences = []CurveID{X25519, CurveP256}
+
+ test := &clientTest{
+ name: "HelloRetryRequest",
+ args: []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
+ config: config,
+ }
+
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
+ config := testConfig.Clone()
+ config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
+
+ test := &clientTest{
+ name: "ECDHE-RSA-CHACHA20-POLY1305",
+ args: []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
+ config: config,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
+ config := testConfig.Clone()
+ config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
+
+ test := &clientTest{
+ name: "ECDHE-ECDSA-CHACHA20-POLY1305",
+ args: []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
+ config: config,
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientAES128SHA256(t *testing.T) {
+ test := &clientTest{
+ name: "AES128-SHA256",
+ args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
+ }
+ runClientTestTLS13(t, test)
+}
+func TestHandshakeClientAES256SHA384(t *testing.T) {
+ test := &clientTest{
+ name: "AES256-SHA384",
+ args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
+ }
+ runClientTestTLS13(t, test)
+}
+func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
+ test := &clientTest{
+ name: "CHACHA20-SHA256",
+ args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
+ }
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientECDSATLS13(t *testing.T) {
+ test := &clientTest{
+ name: "ECDSA",
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientEd25519(t *testing.T) {
+ test := &clientTest{
+ name: "Ed25519",
+ cert: testEd25519Certificate,
+ key: testEd25519PrivateKey,
+ }
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+
+ config := testConfig.Clone()
+ cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
+ config.Certificates = []Certificate{cert}
+
+ test = &clientTest{
+ name: "ClientCert-Ed25519",
+ args: []string{"-Verify", "1"},
+ config: config,
+ }
+
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientCertRSA(t *testing.T) {
+ config := testConfig.Clone()
+ cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
+ config.Certificates = []Certificate{cert}
+
+ test := &clientTest{
+ name: "ClientCert-RSA-RSA",
+ args: []string{"-cipher", "AES128", "-Verify", "1"},
+ config: config,
+ }
+
+ runClientTestTLS10(t, test)
+ runClientTestTLS12(t, test)
+
+ test = &clientTest{
+ name: "ClientCert-RSA-ECDSA",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
+ config: config,
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+
+ runClientTestTLS10(t, test)
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+
+ test = &clientTest{
+ name: "ClientCert-RSA-AES256-GCM-SHA384",
+ args: []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
+ config: config,
+ cert: testRSACertificate,
+ key: testRSAPrivateKey,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientCertECDSA(t *testing.T) {
+ config := testConfig.Clone()
+ cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
+ config.Certificates = []Certificate{cert}
+
+ test := &clientTest{
+ name: "ClientCert-ECDSA-RSA",
+ args: []string{"-cipher", "AES128", "-Verify", "1"},
+ config: config,
+ }
+
+ runClientTestTLS10(t, test)
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+
+ test = &clientTest{
+ name: "ClientCert-ECDSA-ECDSA",
+ args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
+ config: config,
+ cert: testECDSACertificate,
+ key: testECDSAPrivateKey,
+ }
+
+ runClientTestTLS10(t, test)
+ runClientTestTLS12(t, test)
+}
+
+// TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
+// client and server certificates. It also serves from both sides a certificate
+// signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
+// works.
+func TestHandshakeClientCertRSAPSS(t *testing.T) {
+ cert, err := x509.ParseCertificate(testRSAPSSCertificate)
+ if err != nil {
+ panic(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(cert)
+
+ config := testConfig.Clone()
+ // Use GetClientCertificate to bypass the client certificate selection logic.
+ config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
+ return &Certificate{
+ Certificate: [][]byte{testRSAPSSCertificate},
+ PrivateKey: testRSAPrivateKey,
+ }, nil
+ }
+ config.RootCAs = rootCAs
+
+ test := &clientTest{
+ name: "ClientCert-RSA-RSAPSS",
+ args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
+ "rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
+ config: config,
+ cert: testRSAPSSCertificate,
+ key: testRSAPrivateKey,
+ }
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
+ config := testConfig.Clone()
+ cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
+ config.Certificates = []Certificate{cert}
+
+ test := &clientTest{
+ name: "ClientCert-RSA-RSAPKCS1v15",
+ args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
+ "rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
+ config: config,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestClientKeyUpdate(t *testing.T) {
+ test := &clientTest{
+ name: "KeyUpdate",
+ args: []string{"-state"},
+ sendKeyUpdate: true,
+ }
+ runClientTestTLS13(t, test)
+}
+
+func TestResumption(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
+}
+
+func testResumption(t *testing.T, version uint16) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ serverConfig := &Config{
+ MaxVersion: version,
+ CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
+ Certificates: testConfig.Certificates,
+ }
+
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ panic(err)
+ }
+
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ clientConfig := &Config{
+ MaxVersion: version,
+ CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+ ClientSessionCache: NewLRUClientSessionCache(32),
+ RootCAs: rootCAs,
+ ServerName: "example.golang",
+ }
+
+ testResumeState := func(test string, didResume bool) {
+ _, hs, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("%s: handshake failed: %s", test, err)
+ }
+ if hs.DidResume != didResume {
+ t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
+ }
+ if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
+ t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
+ }
+ if got, want := hs.ServerName, clientConfig.ServerName; got != want {
+ t.Errorf("%s: server name %s, want %s", test, got, want)
+ }
+ }
+
+ getTicket := func() []byte {
+ return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
+ }
+ deleteTicket := func() {
+ ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
+ clientConfig.ClientSessionCache.Put(ticketKey, nil)
+ }
+ corruptTicket := func() {
+ clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
+ }
+ randomKey := func() [32]byte {
+ var k [32]byte
+ if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
+ t.Fatalf("Failed to read new SessionTicketKey: %s", err)
+ }
+ return k
+ }
+
+ testResumeState("Handshake", false)
+ ticket := getTicket()
+ testResumeState("Resume", true)
+ if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
+ t.Fatal("first ticket doesn't match ticket after resumption")
+ }
+ if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
+ t.Fatal("ticket didn't change after resumption")
+ }
+
+ // An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key.
+ serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
+ testResumeState("ResumeWithOldTicket", true)
+ if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) {
+ t.Fatal("old first ticket matches the fresh one")
+ }
+
+ // Now the session tickey key is expired, so a full handshake should occur.
+ serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
+ testResumeState("ResumeWithExpiredTicket", false)
+ if bytes.Equal(ticket, getTicket()) {
+ t.Fatal("expired first ticket matches the fresh one")
+ }
+
+ serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
+ key1 := randomKey()
+ serverConfig.SetSessionTicketKeys([][32]byte{key1})
+
+ testResumeState("InvalidSessionTicketKey", false)
+ testResumeState("ResumeAfterInvalidSessionTicketKey", true)
+
+ key2 := randomKey()
+ serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
+ ticket = getTicket()
+ testResumeState("KeyChange", true)
+ if bytes.Equal(ticket, getTicket()) {
+ t.Fatal("new ticket wasn't included while resuming")
+ }
+ testResumeState("KeyChangeFinish", true)
+
+ // Age the session ticket a bit, but not yet expired.
+ serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
+ testResumeState("OldSessionTicket", true)
+ ticket = getTicket()
+ // Expire the session ticket, which would force a full handshake.
+ serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
+ testResumeState("ExpiredSessionTicket", false)
+ if bytes.Equal(ticket, getTicket()) {
+ t.Fatal("new ticket wasn't provided after old ticket expired")
+ }
+
+ // Age the session ticket a bit at a time, but don't expire it.
+ d := 0 * time.Hour
+ for i := 0; i < 13; i++ {
+ d += 12 * time.Hour
+ serverConfig.Time = func() time.Time { return time.Now().Add(d) }
+ testResumeState("OldSessionTicket", true)
+ }
+ // Expire it (now a little more than 7 days) and make sure a full
+ // handshake occurs for TLS 1.2. Resumption should still occur for
+ // TLS 1.3 since the client should be using a fresh ticket sent over
+ // by the server.
+ d += 12 * time.Hour
+ serverConfig.Time = func() time.Time { return time.Now().Add(d) }
+ if version == VersionTLS13 {
+ testResumeState("ExpiredSessionTicket", true)
+ } else {
+ testResumeState("ExpiredSessionTicket", false)
+ }
+ if bytes.Equal(ticket, getTicket()) {
+ t.Fatal("new ticket wasn't provided after old ticket expired")
+ }
+
+ // Reset serverConfig to ensure that calling SetSessionTicketKeys
+ // before the serverConfig is used works.
+ serverConfig = &Config{
+ MaxVersion: version,
+ CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
+ Certificates: testConfig.Certificates,
+ }
+ serverConfig.SetSessionTicketKeys([][32]byte{key2})
+
+ testResumeState("FreshConfig", true)
+
+ // In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
+ // hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
+ if version != VersionTLS13 {
+ clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
+ testResumeState("DifferentCipherSuite", false)
+ testResumeState("DifferentCipherSuiteRecovers", true)
+ }
+
+ deleteTicket()
+ testResumeState("WithoutSessionTicket", false)
+
+ // Session resumption should work when using client certificates
+ deleteTicket()
+ serverConfig.ClientCAs = rootCAs
+ serverConfig.ClientAuth = RequireAndVerifyClientCert
+ clientConfig.Certificates = serverConfig.Certificates
+ testResumeState("InitialHandshake", false)
+ testResumeState("WithClientCertificates", true)
+ serverConfig.ClientAuth = NoClientCert
+
+ // Tickets should be removed from the session cache on TLS handshake
+ // failure, and the client should recover from a corrupted PSK
+ testResumeState("FetchTicketToCorrupt", false)
+ corruptTicket()
+ _, _, err = testHandshake(t, clientConfig, serverConfig)
+ if err == nil {
+ t.Fatalf("handshake did not fail with a corrupted client secret")
+ }
+ testResumeState("AfterHandshakeFailure", false)
+
+ clientConfig.ClientSessionCache = nil
+ testResumeState("WithoutSessionCache", false)
+}
+
+func TestLRUClientSessionCache(t *testing.T) {
+ // Initialize cache of capacity 4.
+ cache := NewLRUClientSessionCache(4)
+ cs := make([]ClientSessionState, 6)
+ keys := []string{"0", "1", "2", "3", "4", "5", "6"}
+
+ // Add 4 entries to the cache and look them up.
+ for i := 0; i < 4; i++ {
+ cache.Put(keys[i], &cs[i])
+ }
+ for i := 0; i < 4; i++ {
+ if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
+ t.Fatalf("session cache failed lookup for added key: %s", keys[i])
+ }
+ }
+
+ // Add 2 more entries to the cache. First 2 should be evicted.
+ for i := 4; i < 6; i++ {
+ cache.Put(keys[i], &cs[i])
+ }
+ for i := 0; i < 2; i++ {
+ if s, ok := cache.Get(keys[i]); ok || s != nil {
+ t.Fatalf("session cache should have evicted key: %s", keys[i])
+ }
+ }
+
+ // Touch entry 2. LRU should evict 3 next.
+ cache.Get(keys[2])
+ cache.Put(keys[0], &cs[0])
+ if s, ok := cache.Get(keys[3]); ok || s != nil {
+ t.Fatalf("session cache should have evicted key 3")
+ }
+
+ // Update entry 0 in place.
+ cache.Put(keys[0], &cs[3])
+ if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
+ t.Fatalf("session cache failed update for key 0")
+ }
+
+ // Calling Put with a nil entry deletes the key.
+ cache.Put(keys[0], nil)
+ if _, ok := cache.Get(keys[0]); ok {
+ t.Fatalf("session cache failed to delete key 0")
+ }
+
+ // Delete entry 2. LRU should keep 4 and 5
+ cache.Put(keys[2], nil)
+ if _, ok := cache.Get(keys[2]); ok {
+ t.Fatalf("session cache failed to delete key 4")
+ }
+ for i := 4; i < 6; i++ {
+ if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
+ t.Fatalf("session cache should not have deleted key: %s", keys[i])
+ }
+ }
+}
+
+func TestKeyLogTLS12(t *testing.T) {
+ var serverBuf, clientBuf bytes.Buffer
+
+ clientConfig := testConfig.Clone()
+ clientConfig.KeyLogWriter = &clientBuf
+ clientConfig.MaxVersion = VersionTLS12
+
+ serverConfig := testConfig.Clone()
+ serverConfig.KeyLogWriter = &serverBuf
+ serverConfig.MaxVersion = VersionTLS12
+
+ c, s := localPipe(t)
+ done := make(chan bool)
+
+ go func() {
+ defer close(done)
+
+ if err := Server(s, serverConfig).Handshake(); err != nil {
+ t.Errorf("server: %s", err)
+ return
+ }
+ s.Close()
+ }()
+
+ if err := Client(c, clientConfig).Handshake(); err != nil {
+ t.Fatalf("client: %s", err)
+ }
+
+ c.Close()
+ <-done
+
+ checkKeylogLine := func(side, loggedLine string) {
+ if len(loggedLine) == 0 {
+ t.Fatalf("%s: no keylog line was produced", side)
+ }
+ const expectedLen = 13 /* "CLIENT_RANDOM" */ +
+ 1 /* space */ +
+ 32*2 /* hex client nonce */ +
+ 1 /* space */ +
+ 48*2 /* hex master secret */ +
+ 1 /* new line */
+ if len(loggedLine) != expectedLen {
+ t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
+ }
+ if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
+ t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
+ }
+ }
+
+ checkKeylogLine("client", clientBuf.String())
+ checkKeylogLine("server", serverBuf.String())
+}
+
+func TestKeyLogTLS13(t *testing.T) {
+ var serverBuf, clientBuf bytes.Buffer
+
+ clientConfig := testConfig.Clone()
+ clientConfig.KeyLogWriter = &clientBuf
+
+ serverConfig := testConfig.Clone()
+ serverConfig.KeyLogWriter = &serverBuf
+
+ c, s := localPipe(t)
+ done := make(chan bool)
+
+ go func() {
+ defer close(done)
+
+ if err := Server(s, serverConfig).Handshake(); err != nil {
+ t.Errorf("server: %s", err)
+ return
+ }
+ s.Close()
+ }()
+
+ if err := Client(c, clientConfig).Handshake(); err != nil {
+ t.Fatalf("client: %s", err)
+ }
+
+ c.Close()
+ <-done
+
+ checkKeylogLines := func(side, loggedLines string) {
+ loggedLines = strings.TrimSpace(loggedLines)
+ lines := strings.Split(loggedLines, "\n")
+ if len(lines) != 4 {
+ t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
+ }
+ }
+
+ checkKeylogLines("client", clientBuf.String())
+ checkKeylogLines("server", serverBuf.String())
+}
+
+func TestHandshakeClientALPNMatch(t *testing.T) {
+ config := testConfig.Clone()
+ config.NextProtos = []string{"proto2", "proto1"}
+
+ test := &clientTest{
+ name: "ALPN",
+ // Note that this needs OpenSSL 1.0.2 because that is the first
+ // version that supports the -alpn flag.
+ args: []string{"-alpn", "proto1,proto2"},
+ config: config,
+ validate: func(state ConnectionState) error {
+ // The server's preferences should override the client.
+ if state.NegotiatedProtocol != "proto1" {
+ return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
+ }
+ return nil
+ },
+ }
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
+ // This checks that the server can't select an application protocol that the
+ // client didn't offer.
+
+ c, s := localPipe(t)
+ errChan := make(chan error, 1)
+
+ go func() {
+ client := Client(c, &Config{
+ ServerName: "foo",
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ NextProtos: []string{"http", "something-else"},
+ })
+ errChan <- client.Handshake()
+ }()
+
+ var header [5]byte
+ if _, err := io.ReadFull(s, header[:]); err != nil {
+ t.Fatal(err)
+ }
+ recordLen := int(header[3])<<8 | int(header[4])
+
+ record := make([]byte, recordLen)
+ if _, err := io.ReadFull(s, record); err != nil {
+ t.Fatal(err)
+ }
+
+ serverHello := &serverHelloMsg{
+ vers: VersionTLS12,
+ random: make([]byte, 32),
+ cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256,
+ alpnProtocol: "how-about-this",
+ }
+ serverHelloBytes := mustMarshal(t, serverHello)
+
+ s.Write([]byte{
+ byte(recordTypeHandshake),
+ byte(VersionTLS12 >> 8),
+ byte(VersionTLS12 & 0xff),
+ byte(len(serverHelloBytes) >> 8),
+ byte(len(serverHelloBytes)),
+ })
+ s.Write(serverHelloBytes)
+ s.Close()
+
+ if err := <-errChan; !strings.Contains(err.Error(), "server selected unadvertised ALPN protocol") {
+ t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
+ }
+}
+
+// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`
+const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
+
+func TestHandshakClientSCTs(t *testing.T) {
+ config := testConfig.Clone()
+
+ scts, err := base64.StdEncoding.DecodeString(sctsBase64)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Note that this needs OpenSSL 1.0.2 because that is the first
+ // version that supports the -serverinfo flag.
+ test := &clientTest{
+ name: "SCT",
+ config: config,
+ extensions: [][]byte{scts},
+ validate: func(state ConnectionState) error {
+ expectedSCTs := [][]byte{
+ scts[8:125],
+ scts[127:245],
+ scts[247:],
+ }
+ if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
+ return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
+ }
+ for i, expected := range expectedSCTs {
+ if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
+ return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
+ }
+ }
+ return nil
+ },
+ }
+ runClientTestTLS12(t, test)
+
+ // TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
+ // supports ServerHello extensions.
+}
+
+func TestRenegotiationRejected(t *testing.T) {
+ config := testConfig.Clone()
+ test := &clientTest{
+ name: "RenegotiationRejected",
+ args: []string{"-state"},
+ config: config,
+ numRenegotiations: 1,
+ renegotiationExpectedToFail: 1,
+ checkRenegotiationError: func(renegotiationNum int, err error) error {
+ if err == nil {
+ return errors.New("expected error from renegotiation but got nil")
+ }
+ if !strings.Contains(err.Error(), "no renegotiation") {
+ return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
+ }
+ return nil
+ },
+ }
+ runClientTestTLS12(t, test)
+}
+
+func TestRenegotiateOnce(t *testing.T) {
+ config := testConfig.Clone()
+ config.Renegotiation = RenegotiateOnceAsClient
+
+ test := &clientTest{
+ name: "RenegotiateOnce",
+ args: []string{"-state"},
+ config: config,
+ numRenegotiations: 1,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestRenegotiateTwice(t *testing.T) {
+ config := testConfig.Clone()
+ config.Renegotiation = RenegotiateFreelyAsClient
+
+ test := &clientTest{
+ name: "RenegotiateTwice",
+ args: []string{"-state"},
+ config: config,
+ numRenegotiations: 2,
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestRenegotiateTwiceRejected(t *testing.T) {
+ config := testConfig.Clone()
+ config.Renegotiation = RenegotiateOnceAsClient
+
+ test := &clientTest{
+ name: "RenegotiateTwiceRejected",
+ args: []string{"-state"},
+ config: config,
+ numRenegotiations: 2,
+ renegotiationExpectedToFail: 2,
+ checkRenegotiationError: func(renegotiationNum int, err error) error {
+ if renegotiationNum == 1 {
+ return err
+ }
+
+ if err == nil {
+ return errors.New("expected error from renegotiation but got nil")
+ }
+ if !strings.Contains(err.Error(), "no renegotiation") {
+ return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
+ }
+ return nil
+ },
+ }
+
+ runClientTestTLS12(t, test)
+}
+
+func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
+ test := &clientTest{
+ name: "ExportKeyingMaterial",
+ config: testConfig.Clone(),
+ validate: func(state ConnectionState) error {
+ if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
+ return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
+ } else if len(km) != 42 {
+ return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
+ }
+ return nil
+ },
+ }
+ runClientTestTLS10(t, test)
+ runClientTestTLS12(t, test)
+ runClientTestTLS13(t, test)
+}
+
+var hostnameInSNITests = []struct {
+ in, out string
+}{
+ // Opaque string
+ {"", ""},
+ {"localhost", "localhost"},
+ {"foo, bar, baz and qux", "foo, bar, baz and qux"},
+
+ // DNS hostname
+ {"golang.org", "golang.org"},
+ {"golang.org.", "golang.org"},
+
+ // Literal IPv4 address
+ {"1.2.3.4", ""},
+
+ // Literal IPv6 address
+ {"::1", ""},
+ {"::1%lo0", ""}, // with zone identifier
+ {"[::1]", ""}, // as per RFC 5952 we allow the [] style as IPv6 literal
+ {"[::1%lo0]", ""},
+}
+
+func TestHostnameInSNI(t *testing.T) {
+ for _, tt := range hostnameInSNITests {
+ c, s := localPipe(t)
+
+ go func(host string) {
+ Client(c, &Config{ServerName: host, InsecureSkipVerify: true}).Handshake()
+ }(tt.in)
+
+ var header [5]byte
+ if _, err := io.ReadFull(s, header[:]); err != nil {
+ t.Fatal(err)
+ }
+ recordLen := int(header[3])<<8 | int(header[4])
+
+ record := make([]byte, recordLen)
+ if _, err := io.ReadFull(s, record[:]); err != nil {
+ t.Fatal(err)
+ }
+
+ c.Close()
+ s.Close()
+
+ var m clientHelloMsg
+ if !m.unmarshal(record) {
+ t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
+ continue
+ }
+ if tt.in != tt.out && m.serverName == tt.in {
+ t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
+ }
+ if m.serverName != tt.out {
+ t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
+ }
+ }
+}
+
+func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
+ // This checks that the server can't select a cipher suite that the
+ // client didn't offer. See #13174.
+
+ c, s := localPipe(t)
+ errChan := make(chan error, 1)
+
+ go func() {
+ client := Client(c, &Config{
+ ServerName: "foo",
+ CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
+ })
+ errChan <- client.Handshake()
+ }()
+
+ var header [5]byte
+ if _, err := io.ReadFull(s, header[:]); err != nil {
+ t.Fatal(err)
+ }
+ recordLen := int(header[3])<<8 | int(header[4])
+
+ record := make([]byte, recordLen)
+ if _, err := io.ReadFull(s, record); err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a ServerHello that selects a different cipher suite than the
+ // sole one that the client offered.
+ serverHello := &serverHelloMsg{
+ vers: VersionTLS12,
+ random: make([]byte, 32),
+ cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
+ }
+ serverHelloBytes := mustMarshal(t, serverHello)
+
+ s.Write([]byte{
+ byte(recordTypeHandshake),
+ byte(VersionTLS12 >> 8),
+ byte(VersionTLS12 & 0xff),
+ byte(len(serverHelloBytes) >> 8),
+ byte(len(serverHelloBytes)),
+ })
+ s.Write(serverHelloBytes)
+ s.Close()
+
+ if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
+ t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
+ }
+}
+
+func TestVerifyConnection(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
+}
+
+func testVerifyConnection(t *testing.T, version uint16) {
+ checkFields := func(c ConnectionState, called *int, errorType string) error {
+ if c.Version != version {
+ return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
+ }
+ if c.HandshakeComplete {
+ return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
+ }
+ if c.ServerName != "example.golang" {
+ return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
+ }
+ if c.NegotiatedProtocol != "protocol1" {
+ return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
+ }
+ if c.CipherSuite == 0 {
+ return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
+ }
+ wantDidResume := false
+ if *called == 2 { // if this is the second time, then it should be a resumption
+ wantDidResume = true
+ }
+ if c.DidResume != wantDidResume {
+ return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
+ }
+ return nil
+ }
+
+ tests := []struct {
+ name string
+ configureServer func(*Config, *int)
+ configureClient func(*Config, *int)
+ }{
+ {
+ name: "RequireAndVerifyClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequireAndVerifyClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ return checkFields(c, called, "server")
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Response are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called, "client")
+ }
+ },
+ },
+ {
+ name: "InsecureSkipVerify",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequireAnyClientCert
+ config.InsecureSkipVerify = true
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if c.VerifiedChains != nil {
+ return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+ }
+ return checkFields(c, called, "server")
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.InsecureSkipVerify = true
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if c.VerifiedChains != nil {
+ return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Response are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called, "client")
+ }
+ },
+ },
+ {
+ name: "NoClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = NoClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called, "server")
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called, "client")
+ }
+ },
+ },
+ {
+ name: "RequestClientCert",
+ configureServer: func(config *Config, called *int) {
+ config.ClientAuth = RequestClientCert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ return checkFields(c, called, "server")
+ }
+ },
+ configureClient: func(config *Config, called *int) {
+ config.Certificates = nil // clear the client cert
+ config.VerifyConnection = func(c ConnectionState) error {
+ *called++
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if c.DidResume {
+ return nil
+ // The SCTs and OCSP Response are dropped on resumption.
+ // See http://golang.org/issue/39075.
+ }
+ if len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ if len(c.SignedCertificateTimestamps) == 0 {
+ return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
+ }
+ return checkFields(c, called, "client")
+ }
+ },
+ },
+ }
+ for _, test := range tests {
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ panic(err)
+ }
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ var serverCalled, clientCalled int
+
+ serverConfig := &Config{
+ MaxVersion: version,
+ Certificates: []Certificate{testConfig.Certificates[0]},
+ ClientCAs: rootCAs,
+ NextProtos: []string{"protocol1"},
+ }
+ serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+ serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
+ test.configureServer(serverConfig, &serverCalled)
+
+ clientConfig := &Config{
+ MaxVersion: version,
+ ClientSessionCache: NewLRUClientSessionCache(32),
+ RootCAs: rootCAs,
+ ServerName: "example.golang",
+ Certificates: []Certificate{testConfig.Certificates[0]},
+ NextProtos: []string{"protocol1"},
+ }
+ test.configureClient(clientConfig, &clientCalled)
+
+ testHandshakeState := func(name string, didResume bool) {
+ _, hs, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("%s: handshake failed: %s", name, err)
+ }
+ if hs.DidResume != didResume {
+ t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
+ }
+ wantCalled := 1
+ if didResume {
+ wantCalled = 2 // resumption would mean this is the second time it was called in this test
+ }
+ if clientCalled != wantCalled {
+ t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
+ }
+ if serverCalled != wantCalled {
+ t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
+ }
+ }
+ testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
+ testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
+ }
+}
+
+func TestVerifyPeerCertificate(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
+}
+
+func testVerifyPeerCertificate(t *testing.T, version uint16) {
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ panic(err)
+ }
+
+ rootCAs := x509.NewCertPool()
+ rootCAs.AddCert(issuer)
+
+ now := func() time.Time { return time.Unix(1476984729, 0) }
+
+ sentinelErr := errors.New("TestVerifyPeerCertificate")
+
+ verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ if l := len(rawCerts); l != 1 {
+ return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
+ }
+ if len(validatedChains) == 0 {
+ return errors.New("got len(validatedChains) = 0, wanted non-zero")
+ }
+ *called = true
+ return nil
+ }
+ verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
+ if l := len(c.PeerCertificates); l != 1 {
+ return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
+ }
+ if len(c.VerifiedChains) == 0 {
+ return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
+ }
+ if isClient && len(c.OCSPResponse) == 0 {
+ return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
+ }
+ *called = true
+ return nil
+ }
+
+ tests := []struct {
+ configureServer func(*Config, *bool)
+ configureClient func(*Config, *bool)
+ validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
+ }{
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != nil {
+ t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
+ }
+ if serverErr != nil {
+ t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ if !serverCalled {
+ t.Errorf("test[%d]: server did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return sentinelErr
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.VerifyPeerCertificate = nil
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if serverErr != sentinelErr {
+ t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return sentinelErr
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != sentinelErr {
+ t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = true
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ if l := len(rawCerts); l != 1 {
+ return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
+ }
+ // With InsecureSkipVerify set, this
+ // callback should still be called but
+ // validatedChains must be empty.
+ if l := len(validatedChains); l != 0 {
+ return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
+ }
+ *called = true
+ return nil
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != nil {
+ t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
+ }
+ if serverErr != nil {
+ t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return verifyConnectionCallback(called, false, c)
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return verifyConnectionCallback(called, true, c)
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != nil {
+ t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
+ }
+ if serverErr != nil {
+ t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ if !serverCalled {
+ t.Errorf("test[%d]: server did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = nil
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if serverErr != sentinelErr {
+ t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = nil
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != sentinelErr {
+ t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = nil
+ config.VerifyConnection = nil
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if serverErr != sentinelErr {
+ t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
+ }
+ if !serverCalled {
+ t.Errorf("test[%d]: server did not call callback", testNo)
+ }
+ },
+ },
+ {
+ configureServer: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = nil
+ config.VerifyConnection = nil
+ },
+ configureClient: func(config *Config, called *bool) {
+ config.InsecureSkipVerify = false
+ config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
+ return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
+ }
+ config.VerifyConnection = func(c ConnectionState) error {
+ return sentinelErr
+ }
+ },
+ validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
+ if clientErr != sentinelErr {
+ t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
+ }
+ if !clientCalled {
+ t.Errorf("test[%d]: client did not call callback", testNo)
+ }
+ },
+ },
+ }
+
+ for i, test := range tests {
+ c, s := localPipe(t)
+ done := make(chan error)
+
+ var clientCalled, serverCalled bool
+
+ go func() {
+ config := testConfig.Clone()
+ config.ServerName = "example.golang"
+ config.ClientAuth = RequireAndVerifyClientCert
+ config.ClientCAs = rootCAs
+ config.Time = now
+ config.MaxVersion = version
+ config.Certificates = make([]Certificate, 1)
+ config.Certificates[0].Certificate = [][]byte{testRSACertificate}
+ config.Certificates[0].PrivateKey = testRSAPrivateKey
+ config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
+ config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
+ test.configureServer(config, &serverCalled)
+
+ err = Server(s, config).Handshake()
+ s.Close()
+ done <- err
+ }()
+
+ config := testConfig.Clone()
+ config.ServerName = "example.golang"
+ config.RootCAs = rootCAs
+ config.Time = now
+ config.MaxVersion = version
+ test.configureClient(config, &clientCalled)
+ clientErr := Client(c, config).Handshake()
+ c.Close()
+ serverErr := <-done
+
+ test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
+ }
+}
+
+// brokenConn wraps a net.Conn and causes all Writes after a certain number to
+// fail with brokenConnErr.
+type brokenConn struct {
+ net.Conn
+
+ // breakAfter is the number of successful writes that will be allowed
+ // before all subsequent writes fail.
+ breakAfter int
+
+ // numWrites is the number of writes that have been done.
+ numWrites int
+}
+
+// brokenConnErr is the error that brokenConn returns once exhausted.
+var brokenConnErr = errors.New("too many writes to brokenConn")
+
+func (b *brokenConn) Write(data []byte) (int, error) {
+ if b.numWrites >= b.breakAfter {
+ return 0, brokenConnErr
+ }
+
+ b.numWrites++
+ return b.Conn.Write(data)
+}
+
+func TestFailedWrite(t *testing.T) {
+ // Test that a write error during the handshake is returned.
+ for _, breakAfter := range []int{0, 1} {
+ c, s := localPipe(t)
+ done := make(chan bool)
+
+ go func() {
+ Server(s, testConfig).Handshake()
+ s.Close()
+ done <- true
+ }()
+
+ brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
+ err := Client(brokenC, testConfig).Handshake()
+ if err != brokenConnErr {
+ t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
+ }
+ brokenC.Close()
+
+ <-done
+ }
+}
+
+// writeCountingConn wraps a net.Conn and counts the number of Write calls.
+type writeCountingConn struct {
+ net.Conn
+
+ // numWrites is the number of writes that have been done.
+ numWrites int
+}
+
+func (wcc *writeCountingConn) Write(data []byte) (int, error) {
+ wcc.numWrites++
+ return wcc.Conn.Write(data)
+}
+
+func TestBuffering(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
+}
+
+func testBuffering(t *testing.T, version uint16) {
+ c, s := localPipe(t)
+ done := make(chan bool)
+
+ clientWCC := &writeCountingConn{Conn: c}
+ serverWCC := &writeCountingConn{Conn: s}
+
+ go func() {
+ config := testConfig.Clone()
+ config.MaxVersion = version
+ Server(serverWCC, config).Handshake()
+ serverWCC.Close()
+ done <- true
+ }()
+
+ err := Client(clientWCC, testConfig).Handshake()
+ if err != nil {
+ t.Fatal(err)
+ }
+ clientWCC.Close()
+ <-done
+
+ var expectedClient, expectedServer int
+ if version == VersionTLS13 {
+ expectedClient = 2
+ expectedServer = 1
+ } else {
+ expectedClient = 2
+ expectedServer = 2
+ }
+
+ if n := clientWCC.numWrites; n != expectedClient {
+ t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
+ }
+
+ if n := serverWCC.numWrites; n != expectedServer {
+ t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
+ }
+}
+
+func TestAlertFlushing(t *testing.T) {
+ c, s := localPipe(t)
+ done := make(chan bool)
+
+ clientWCC := &writeCountingConn{Conn: c}
+ serverWCC := &writeCountingConn{Conn: s}
+
+ serverConfig := testConfig.Clone()
+
+ // Cause a signature-time error
+ brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
+ brokenKey.D = big.NewInt(42)
+ serverConfig.Certificates = []Certificate{{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: &brokenKey,
+ }}
+
+ go func() {
+ Server(serverWCC, serverConfig).Handshake()
+ serverWCC.Close()
+ done <- true
+ }()
+
+ err := Client(clientWCC, testConfig).Handshake()
+ if err == nil {
+ t.Fatal("client unexpectedly returned no error")
+ }
+
+ const expectedError = "remote error: tls: internal error"
+ if e := err.Error(); !strings.Contains(e, expectedError) {
+ t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
+ }
+ clientWCC.Close()
+ <-done
+
+ if n := serverWCC.numWrites; n != 1 {
+ t.Errorf("expected server handshake to complete with one write, but saw %d", n)
+ }
+}
+
+func TestHandshakeRace(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ t.Parallel()
+ // This test races a Read and Write to try and complete a handshake in
+ // order to provide some evidence that there are no races or deadlocks
+ // in the handshake locking.
+ for i := 0; i < 32; i++ {
+ c, s := localPipe(t)
+
+ go func() {
+ server := Server(s, testConfig)
+ if err := server.Handshake(); err != nil {
+ panic(err)
+ }
+
+ var request [1]byte
+ if n, err := server.Read(request[:]); err != nil || n != 1 {
+ panic(err)
+ }
+
+ server.Write(request[:])
+ server.Close()
+ }()
+
+ startWrite := make(chan struct{})
+ startRead := make(chan struct{})
+ readDone := make(chan struct{}, 1)
+
+ client := Client(c, testConfig)
+ go func() {
+ <-startWrite
+ var request [1]byte
+ client.Write(request[:])
+ }()
+
+ go func() {
+ <-startRead
+ var reply [1]byte
+ if _, err := io.ReadFull(client, reply[:]); err != nil {
+ panic(err)
+ }
+ c.Close()
+ readDone <- struct{}{}
+ }()
+
+ if i&1 == 1 {
+ startWrite <- struct{}{}
+ startRead <- struct{}{}
+ } else {
+ startRead <- struct{}{}
+ startWrite <- struct{}{}
+ }
+ <-readDone
+ }
+}
+
+var getClientCertificateTests = []struct {
+ setup func(*Config, *Config)
+ expectedClientError string
+ verify func(*testing.T, int, *ConnectionState)
+}{
+ {
+ func(clientConfig, serverConfig *Config) {
+ // Returning a Certificate with no certificate data
+ // should result in an empty message being sent to the
+ // server.
+ serverConfig.ClientCAs = nil
+ clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+ if len(cri.SignatureSchemes) == 0 {
+ panic("empty SignatureSchemes")
+ }
+ if len(cri.AcceptableCAs) != 0 {
+ panic("AcceptableCAs should have been empty")
+ }
+ return new(Certificate), nil
+ }
+ },
+ "",
+ func(t *testing.T, testNum int, cs *ConnectionState) {
+ if l := len(cs.PeerCertificates); l != 0 {
+ t.Errorf("#%d: expected no certificates but got %d", testNum, l)
+ }
+ },
+ },
+ {
+ func(clientConfig, serverConfig *Config) {
+ // With TLS 1.1, the SignatureSchemes should be
+ // synthesised from the supported certificate types.
+ clientConfig.MaxVersion = VersionTLS11
+ clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+ if len(cri.SignatureSchemes) == 0 {
+ panic("empty SignatureSchemes")
+ }
+ return new(Certificate), nil
+ }
+ },
+ "",
+ func(t *testing.T, testNum int, cs *ConnectionState) {
+ if l := len(cs.PeerCertificates); l != 0 {
+ t.Errorf("#%d: expected no certificates but got %d", testNum, l)
+ }
+ },
+ },
+ {
+ func(clientConfig, serverConfig *Config) {
+ // Returning an error should abort the handshake with
+ // that error.
+ clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+ return nil, errors.New("GetClientCertificate")
+ }
+ },
+ "GetClientCertificate",
+ func(t *testing.T, testNum int, cs *ConnectionState) {
+ },
+ },
+ {
+ func(clientConfig, serverConfig *Config) {
+ clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
+ if len(cri.AcceptableCAs) == 0 {
+ panic("empty AcceptableCAs")
+ }
+ cert := &Certificate{
+ Certificate: [][]byte{testRSACertificate},
+ PrivateKey: testRSAPrivateKey,
+ }
+ return cert, nil
+ }
+ },
+ "",
+ func(t *testing.T, testNum int, cs *ConnectionState) {
+ if len(cs.VerifiedChains) == 0 {
+ t.Errorf("#%d: expected some verified chains, but found none", testNum)
+ }
+ },
+ },
+}
+
+func TestGetClientCertificate(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
+}
+
+func testGetClientCertificate(t *testing.T, version uint16) {
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ panic(err)
+ }
+
+ for i, test := range getClientCertificateTests {
+ serverConfig := testConfig.Clone()
+ serverConfig.ClientAuth = VerifyClientCertIfGiven
+ serverConfig.RootCAs = x509.NewCertPool()
+ serverConfig.RootCAs.AddCert(issuer)
+ serverConfig.ClientCAs = serverConfig.RootCAs
+ serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
+ serverConfig.MaxVersion = version
+
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = version
+
+ test.setup(clientConfig, serverConfig)
+
+ type serverResult struct {
+ cs ConnectionState
+ err error
+ }
+
+ c, s := localPipe(t)
+ done := make(chan serverResult)
+
+ go func() {
+ defer s.Close()
+ server := Server(s, serverConfig)
+ err := server.Handshake()
+
+ var cs ConnectionState
+ if err == nil {
+ cs = server.ConnectionState()
+ }
+ done <- serverResult{cs, err}
+ }()
+
+ clientErr := Client(c, clientConfig).Handshake()
+ c.Close()
+
+ result := <-done
+
+ if clientErr != nil {
+ if len(test.expectedClientError) == 0 {
+ t.Errorf("#%d: client error: %v", i, clientErr)
+ } else if got := clientErr.Error(); got != test.expectedClientError {
+ t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
+ } else {
+ test.verify(t, i, &result.cs)
+ }
+ } else if len(test.expectedClientError) > 0 {
+ t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
+ } else if err := result.err; err != nil {
+ t.Errorf("#%d: server error: %v", i, err)
+ } else {
+ test.verify(t, i, &result.cs)
+ }
+ }
+}
+
+func TestRSAPSSKeyError(t *testing.T) {
+ // crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
+ // public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
+ // the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
+ // parse, or that they don't carry *rsa.PublicKey keys.
+ b, _ := pem.Decode([]byte(`
+-----BEGIN CERTIFICATE-----
+MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
+MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
+AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
+MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
+ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
+/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
+b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
+QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
+czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
+JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
+AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
+OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
+AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
+sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
+H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
+KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
+bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
+HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
+RwBA9Xk1KBNF
+-----END CERTIFICATE-----`))
+ if b == nil {
+ t.Fatal("Failed to decode certificate")
+ }
+ cert, err := x509.ParseCertificate(b.Bytes)
+ if err != nil {
+ return
+ }
+ if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
+ t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
+ }
+}
+
+func TestCloseClientConnectionOnIdleServer(t *testing.T) {
+ clientConn, serverConn := localPipe(t)
+ client := Client(clientConn, testConfig.Clone())
+ go func() {
+ var b [1]byte
+ serverConn.Read(b[:])
+ client.Close()
+ }()
+ client.SetWriteDeadline(time.Now().Add(time.Minute))
+ err := client.Handshake()
+ if err != nil {
+ if err, ok := err.(net.Error); ok && err.Timeout() {
+ t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
+ }
+ } else {
+ t.Errorf("Error expected, but no error returned")
+ }
+}
+
+func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
+ defer func() { testingOnlyForceDowngradeCanary = false }()
+ testingOnlyForceDowngradeCanary = true
+
+ clientConfig := testConfig.Clone()
+ clientConfig.MaxVersion = clientVersion
+ serverConfig := testConfig.Clone()
+ serverConfig.MaxVersion = serverVersion
+ _, _, err := testHandshake(t, clientConfig, serverConfig)
+ return err
+}
+
+func TestDowngradeCanary(t *testing.T) {
+ if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
+ t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
+ }
+ if testing.Short() {
+ t.Skip("skipping the rest of the checks in short mode")
+ }
+ if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
+ t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
+ }
+ if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
+ t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
+ }
+ if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
+ t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
+ }
+ if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
+ t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
+ }
+ if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
+ t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
+ }
+ if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
+ t.Errorf("client didn't ignore expected TLS 1.2 canary")
+ }
+ if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
+ t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
+ }
+ if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
+ t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
+ }
+}
+
+func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
+ t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
+ t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
+}
+
+func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
+ issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
+ if err != nil {
+ t.Fatalf("failed to parse test issuer")
+ }
+ roots := x509.NewCertPool()
+ roots.AddCert(issuer)
+ clientConfig := &Config{
+ MaxVersion: ver,
+ ClientSessionCache: NewLRUClientSessionCache(32),
+ ServerName: "example.golang",
+ RootCAs: roots,
+ }
+ serverConfig := testConfig.Clone()
+ serverConfig.MaxVersion = ver
+ serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
+ serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
+
+ _, ccs, err := testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("handshake failed: %s", err)
+ }
+ // after a new session we expect to see OCSPResponse and
+ // SignedCertificateTimestamps populated as usual
+ if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
+ t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
+ serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
+ }
+ if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
+ t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
+ serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
+ }
+
+ // if the server doesn't send any SCTs, repopulate the old SCTs
+ oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
+ serverConfig.Certificates[0].SignedCertificateTimestamps = nil
+ _, ccs, err = testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("handshake failed: %s", err)
+ }
+ if !ccs.DidResume {
+ t.Fatalf("expected session to be resumed")
+ }
+ // after a resumed session we also expect to see OCSPResponse
+ // and SignedCertificateTimestamps populated
+ if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
+ t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
+ serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
+ }
+ if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
+ t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
+ oldSCTs, ccs.SignedCertificateTimestamps)
+ }
+
+ // Only test overriding the SCTs for TLS 1.2, since in 1.3
+ // the server won't send the message containing them
+ if ver == VersionTLS13 {
+ return
+ }
+
+ // if the server changes the SCTs it sends, they should override the saved SCTs
+ serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
+ _, ccs, err = testHandshake(t, clientConfig, serverConfig)
+ if err != nil {
+ t.Fatalf("handshake failed: %s", err)
+ }
+ if !ccs.DidResume {
+ t.Fatalf("expected session to be resumed")
+ }
+ if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
+ t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
+ serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
+ }
+}
+
+// TestClientHandshakeContextCancellation tests that canceling
+// the context given to the client side conn.HandshakeContext
+// interrupts the in-progress handshake.
+func TestClientHandshakeContextCancellation(t *testing.T) {
+ c, s := localPipe(t)
+ ctx, cancel := context.WithCancel(context.Background())
+ unblockServer := make(chan struct{})
+ defer close(unblockServer)
+ go func() {
+ cancel()
+ <-unblockServer
+ _ = s.Close()
+ }()
+ cli := Client(c, testConfig)
+ // Initiates client side handshake, which will block until the client hello is read
+ // by the server, unless the cancellation works.
+ err := cli.HandshakeContext(ctx)
+ if err == nil {
+ t.Fatal("Client handshake did not error when the context was canceled")
+ }
+ if err != context.Canceled {
+ t.Errorf("Unexpected client handshake error: %v", err)
+ }
+ if runtime.GOARCH == "wasm" {
+ t.Skip("conn.Close does not error as expected when called multiple times on WASM")
+ }
+ err = cli.Close()
+ if err == nil {
+ t.Error("Client connection was not closed when the context was canceled")
+ }
+}