diff options
Diffstat (limited to 'src/crypto/tls/quic_test.go')
-rw-r--r-- | src/crypto/tls/quic_test.go | 489 |
1 files changed, 489 insertions, 0 deletions
diff --git a/src/crypto/tls/quic_test.go b/src/crypto/tls/quic_test.go new file mode 100644 index 0000000..323906a --- /dev/null +++ b/src/crypto/tls/quic_test.go @@ -0,0 +1,489 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "context" + "errors" + "reflect" + "testing" +) + +type testQUICConn struct { + t *testing.T + conn *QUICConn + readSecret map[QUICEncryptionLevel]suiteSecret + writeSecret map[QUICEncryptionLevel]suiteSecret + gotParams []byte + complete bool +} + +func newTestQUICClient(t *testing.T, config *Config) *testQUICConn { + q := &testQUICConn{t: t} + q.conn = QUICClient(&QUICConfig{ + TLSConfig: config, + }) + t.Cleanup(func() { + q.conn.Close() + }) + return q +} + +func newTestQUICServer(t *testing.T, config *Config) *testQUICConn { + q := &testQUICConn{t: t} + q.conn = QUICServer(&QUICConfig{ + TLSConfig: config, + }) + t.Cleanup(func() { + q.conn.Close() + }) + return q +} + +type suiteSecret struct { + suite uint16 + secret []byte +} + +func (q *testQUICConn) setReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + if _, ok := q.writeSecret[level]; !ok { + q.t.Errorf("SetReadSecret for level %v called before SetWriteSecret", level) + } + if level == QUICEncryptionLevelApplication && !q.complete { + q.t.Errorf("SetReadSecret for level %v called before HandshakeComplete", level) + } + if _, ok := q.readSecret[level]; ok { + q.t.Errorf("SetReadSecret for level %v called twice", level) + } + if q.readSecret == nil { + q.readSecret = map[QUICEncryptionLevel]suiteSecret{} + } + switch level { + case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: + q.readSecret[level] = suiteSecret{suite, secret} + default: + q.t.Errorf("SetReadSecret for unexpected level %v", level) + } +} + +func (q *testQUICConn) setWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) { + if _, ok := q.writeSecret[level]; ok { + q.t.Errorf("SetWriteSecret for level %v called twice", level) + } + if q.writeSecret == nil { + q.writeSecret = map[QUICEncryptionLevel]suiteSecret{} + } + switch level { + case QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication: + q.writeSecret[level] = suiteSecret{suite, secret} + default: + q.t.Errorf("SetWriteSecret for unexpected level %v", level) + } +} + +var errTransportParametersRequired = errors.New("transport parameters required") + +func runTestQUICConnection(ctx context.Context, cli, srv *testQUICConn, onEvent func(e QUICEvent, src, dst *testQUICConn) bool) error { + a, b := cli, srv + for _, c := range []*testQUICConn{a, b} { + if !c.conn.conn.quic.started { + if err := c.conn.Start(ctx); err != nil { + return err + } + } + } + idleCount := 0 + for { + e := a.conn.NextEvent() + if onEvent != nil && onEvent(e, a, b) { + continue + } + switch e.Kind { + case QUICNoEvent: + idleCount++ + if idleCount == 2 { + if !a.complete || !b.complete { + return errors.New("handshake incomplete") + } + return nil + } + a, b = b, a + case QUICSetReadSecret: + a.setReadSecret(e.Level, e.Suite, e.Data) + case QUICSetWriteSecret: + a.setWriteSecret(e.Level, e.Suite, e.Data) + case QUICWriteData: + if err := b.conn.HandleData(e.Level, e.Data); err != nil { + return err + } + case QUICTransportParameters: + a.gotParams = e.Data + if a.gotParams == nil { + a.gotParams = []byte{} + } + case QUICTransportParametersRequired: + return errTransportParametersRequired + case QUICHandshakeDone: + a.complete = true + if a == srv { + opts := QUICSessionTicketOptions{} + if err := srv.conn.SendSessionTicket(opts); err != nil { + return err + } + } + } + if e.Kind != QUICNoEvent { + idleCount = 0 + } + } +} + +func TestQUICConnection(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if _, ok := cli.readSecret[QUICEncryptionLevelHandshake]; !ok { + t.Errorf("client has no Handshake secret") + } + if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; !ok { + t.Errorf("client has no Application secret") + } + if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; !ok { + t.Errorf("server has no Handshake secret") + } + if _, ok := srv.readSecret[QUICEncryptionLevelApplication]; !ok { + t.Errorf("server has no Application secret") + } + for _, level := range []QUICEncryptionLevel{QUICEncryptionLevelHandshake, QUICEncryptionLevelApplication} { + if _, ok := cli.readSecret[level]; !ok { + t.Errorf("client has no %v read secret", level) + } + if _, ok := srv.readSecret[level]; !ok { + t.Errorf("server has no %v read secret", level) + } + if !reflect.DeepEqual(cli.readSecret[level], srv.writeSecret[level]) { + t.Errorf("client read secret does not match server write secret for level %v", level) + } + if !reflect.DeepEqual(cli.writeSecret[level], srv.readSecret[level]) { + t.Errorf("client write secret does not match server read secret for level %v", level) + } + } +} + +func TestQUICSessionResumption(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during first connection handshake: %v", err) + } + if cli.conn.ConnectionState().DidResume { + t.Errorf("first connection unexpectedly used session resumption") + } + + cli2 := newTestQUICClient(t, clientConfig) + cli2.conn.SetTransportParameters(nil) + srv2 := newTestQUICServer(t, serverConfig) + srv2.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli2, srv2, nil); err != nil { + t.Fatalf("error during second connection handshake: %v", err) + } + if !cli2.conn.ConnectionState().DidResume { + t.Errorf("second connection did not use session resumption") + } +} + +func TestQUICFragmentaryData(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { + if e.Kind == QUICWriteData { + // Provide the data one byte at a time. + for i := range e.Data { + if err := dst.conn.HandleData(e.Level, e.Data[i:i+1]); err != nil { + t.Errorf("HandleData: %v", err) + break + } + } + return true + } + return false + } + if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { + t.Fatalf("error during first connection handshake: %v", err) + } +} + +func TestQUICPostHandshakeClientAuthentication(t *testing.T) { + // RFC 9001, Section 4.4. + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + certReq := new(certificateRequestMsgTLS13) + certReq.ocspStapling = true + certReq.scts = true + certReq.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + certReqBytes, err := certReq.marshal() + if err != nil { + t.Fatal(err) + } + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ + byte(typeCertificateRequest), + byte(0), byte(0), byte(len(certReqBytes)), + }, certReqBytes...)); err == nil { + t.Fatalf("post-handshake authentication request: got no error, want one") + } +} + +func TestQUICPostHandshakeKeyUpdate(t *testing.T) { + // RFC 9001, Section 6. + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + keyUpdate := new(keyUpdateMsg) + keyUpdateBytes, err := keyUpdate.marshal() + if err != nil { + t.Fatal(err) + } + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, append([]byte{ + byte(typeKeyUpdate), + byte(0), byte(0), byte(len(keyUpdateBytes)), + }, keyUpdateBytes...)); !errors.Is(err, alertUnexpectedMessage) { + t.Fatalf("key update request: got error %v, want alertUnexpectedMessage", err) + } +} + +func TestQUICPostHandshakeMessageTooLarge(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + size := maxHandshake + 1 + if err := cli.conn.HandleData(QUICEncryptionLevelApplication, []byte{ + byte(typeNewSessionTicket), + byte(size >> 16), + byte(size >> 8), + byte(size), + }); err == nil { + t.Fatalf("%v-byte post-handshake message: got no error, want one", size) + } +} + +func TestQUICHandshakeError(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.InsecureSkipVerify = false + clientConfig.ServerName = "name" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, clientConfig) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, serverConfig) + srv.conn.SetTransportParameters(nil) + err := runTestQUICConnection(context.Background(), cli, srv, nil) + if !errors.Is(err, AlertError(alertBadCertificate)) { + t.Errorf("connection handshake terminated with error %q, want alertBadCertificate", err) + } + var e *CertificateVerificationError + if !errors.As(err, &e) { + t.Errorf("connection handshake terminated with error %q, want CertificateVerificationError", err) + } +} + +// Test that QUICConn.ConnectionState can be used during the handshake, +// and that it reports the application protocol as soon as it has been +// negotiated. +func TestQUICConnectionState(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + config.NextProtos = []string{"h3"} + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + onEvent := func(e QUICEvent, src, dst *testQUICConn) bool { + cliCS := cli.conn.ConnectionState() + if _, ok := cli.readSecret[QUICEncryptionLevelApplication]; ok { + if want, got := cliCS.NegotiatedProtocol, "h3"; want != got { + t.Errorf("cli.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } + } + srvCS := srv.conn.ConnectionState() + if _, ok := srv.readSecret[QUICEncryptionLevelHandshake]; ok { + if want, got := srvCS.NegotiatedProtocol, "h3"; want != got { + t.Errorf("srv.ConnectionState().NegotiatedProtocol = %q, want %q", want, got) + } + } + return false + } + if err := runTestQUICConnection(context.Background(), cli, srv, onEvent); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } +} + +func TestQUICStartContextPropagation(t *testing.T) { + const key = "key" + const value = "value" + ctx := context.WithValue(context.Background(), key, value) + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + calls := 0 + config.GetConfigForClient = func(info *ClientHelloInfo) (*Config, error) { + calls++ + got, _ := info.Context().Value(key).(string) + if got != value { + t.Errorf("GetConfigForClient context key %q has value %q, want %q", key, got, value) + } + return nil, nil + } + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(ctx, cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + if calls != 1 { + t.Errorf("GetConfigForClient called %v times, want 1", calls) + } +} + +func TestQUICDelayedTransportParameters(t *testing.T) { + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS13 + clientConfig.ClientSessionCache = NewLRUClientSessionCache(1) + clientConfig.ServerName = "example.go.dev" + + serverConfig := testConfig.Clone() + serverConfig.MinVersion = VersionTLS13 + + cliParams := "client params" + srvParams := "server params" + + cli := newTestQUICClient(t, clientConfig) + srv := newTestQUICServer(t, serverConfig) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { + t.Fatalf("handshake with no client parameters: %v; want errTransportParametersRequired", err) + } + cli.conn.SetTransportParameters([]byte(cliParams)) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != errTransportParametersRequired { + t.Fatalf("handshake with no server parameters: %v; want errTransportParametersRequired", err) + } + srv.conn.SetTransportParameters([]byte(srvParams)) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if got, want := string(cli.gotParams), srvParams; got != want { + t.Errorf("client got transport params: %q, want %q", got, want) + } + if got, want := string(srv.gotParams), cliParams; got != want { + t.Errorf("server got transport params: %q, want %q", got, want) + } +} + +func TestQUICEmptyTransportParameters(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + srv := newTestQUICServer(t, config) + srv.conn.SetTransportParameters(nil) + if err := runTestQUICConnection(context.Background(), cli, srv, nil); err != nil { + t.Fatalf("error during connection handshake: %v", err) + } + + if cli.gotParams == nil { + t.Errorf("client did not get transport params") + } + if srv.gotParams == nil { + t.Errorf("server did not get transport params") + } + if len(cli.gotParams) != 0 { + t.Errorf("client got transport params: %v, want empty", cli.gotParams) + } + if len(srv.gotParams) != 0 { + t.Errorf("server got transport params: %v, want empty", srv.gotParams) + } +} + +func TestQUICCanceledWaitingForData(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.SetTransportParameters(nil) + cli.conn.Start(context.Background()) + for cli.conn.NextEvent().Kind != QUICNoEvent { + } + err := cli.conn.Close() + if !errors.Is(err, alertCloseNotify) { + t.Errorf("conn.Close() = %v, want alertCloseNotify", err) + } +} + +func TestQUICCanceledWaitingForTransportParams(t *testing.T) { + config := testConfig.Clone() + config.MinVersion = VersionTLS13 + cli := newTestQUICClient(t, config) + cli.conn.Start(context.Background()) + for cli.conn.NextEvent().Kind != QUICTransportParametersRequired { + } + err := cli.conn.Close() + if !errors.Is(err, alertCloseNotify) { + t.Errorf("conn.Close() = %v, want alertCloseNotify", err) + } +} |