diff options
Diffstat (limited to '')
-rw-r--r-- | src/net/splice_test.go | 417 |
1 files changed, 417 insertions, 0 deletions
diff --git a/src/net/splice_test.go b/src/net/splice_test.go new file mode 100644 index 0000000..fa14c95 --- /dev/null +++ b/src/net/splice_test.go @@ -0,0 +1,417 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build linux + +package net + +import ( + "io" + "log" + "os" + "os/exec" + "strconv" + "sync" + "testing" + "time" +) + +func TestSplice(t *testing.T) { + t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") }) + if !testableNetwork("unixgram") { + t.Skip("skipping unix-to-tcp tests") + } + t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) + t.Run("no-unixpacket", testSpliceNoUnixpacket) + t.Run("no-unixgram", testSpliceNoUnixgram) +} + +func testSplice(t *testing.T, upNet, downNet string) { + t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test) + t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) + t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) + t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) + t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) + t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test) + t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) }) + t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) }) +} + +type spliceTestCase struct { + upNet, downNet string + + chunkSize, totalSize int + limitReadSize int +} + +func (tc spliceTestCase) test(t *testing.T) { + clientUp, serverUp := spliceTestSocketPair(t, tc.upNet) + defer serverUp.Close() + cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize) + if err != nil { + t.Fatal(err) + } + defer cleanup() + clientDown, serverDown := spliceTestSocketPair(t, tc.downNet) + defer serverDown.Close() + cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize) + if err != nil { + t.Fatal(err) + } + defer cleanup() + var ( + r io.Reader = serverUp + size = tc.totalSize + ) + if tc.limitReadSize > 0 { + if tc.limitReadSize < size { + size = tc.limitReadSize + } + + r = &io.LimitedReader{ + N: int64(tc.limitReadSize), + R: serverUp, + } + defer serverUp.Close() + } + n, err := io.Copy(serverDown, r) + serverDown.Close() + if err != nil { + t.Fatal(err) + } + if want := int64(size); want != n { + t.Errorf("want %d bytes spliced, got %d", want, n) + } + + if tc.limitReadSize > 0 { + wantN := 0 + if tc.limitReadSize > size { + wantN = tc.limitReadSize - size + } + + if n := r.(*io.LimitedReader).N; n != int64(wantN) { + t.Errorf("r.N = %d, want %d", n, wantN) + } + } +} + +func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { + clientUp, serverUp := spliceTestSocketPair(t, upNet) + defer clientUp.Close() + clientDown, serverDown := spliceTestSocketPair(t, downNet) + defer clientDown.Close() + + serverUp.Close() + + // We'd like to call net.splice here and check the handled return + // value, but we disable splice on old Linux kernels. + // + // In that case, poll.Splice and net.splice return a non-nil error + // and handled == false. We'd ideally like to see handled == true + // because the source reader is at EOF, but if we're running on an old + // kernel, and splice is disabled, we won't see EOF from net.splice, + // because we won't touch the reader at all. + // + // Trying to untangle the errors from net.splice and match them + // against the errors created by the poll package would be brittle, + // so this is a higher level test. + // + // The following ReadFrom should return immediately, regardless of + // whether splice is disabled or not. The other side should then + // get a goodbye signal. Test for the goodbye signal. + msg := "bye" + go func() { + serverDown.(io.ReaderFrom).ReadFrom(serverUp) + io.WriteString(serverDown, msg) + serverDown.Close() + }() + + buf := make([]byte, 3) + _, err := io.ReadFull(clientDown, buf) + if err != nil { + t.Errorf("clientDown: %v", err) + } + if string(buf) != msg { + t.Errorf("clientDown got %q, want %q", buf, msg) + } +} + +func testSpliceIssue25985(t *testing.T, upNet, downNet string) { + front := newLocalListener(t, upNet) + defer front.Close() + back := newLocalListener(t, downNet) + defer back.Close() + + var wg sync.WaitGroup + wg.Add(2) + + proxy := func() { + src, err := front.Accept() + if err != nil { + return + } + dst, err := Dial(downNet, back.Addr().String()) + if err != nil { + return + } + defer dst.Close() + defer src.Close() + go func() { + io.Copy(src, dst) + wg.Done() + }() + go func() { + io.Copy(dst, src) + wg.Done() + }() + } + + go proxy() + + toFront, err := Dial(upNet, front.Addr().String()) + if err != nil { + t.Fatal(err) + } + + io.WriteString(toFront, "foo") + toFront.Close() + + fromProxy, err := back.Accept() + if err != nil { + t.Fatal(err) + } + defer fromProxy.Close() + + _, err = io.ReadAll(fromProxy) + if err != nil { + t.Fatal(err) + } + + wg.Wait() +} + +func testSpliceNoUnixpacket(t *testing.T) { + clientUp, serverUp := spliceTestSocketPair(t, "unixpacket") + defer clientUp.Close() + defer serverUp.Close() + clientDown, serverDown := spliceTestSocketPair(t, "tcp") + defer clientDown.Close() + defer serverDown.Close() + // If splice called poll.Splice here, we'd get err == syscall.EINVAL + // and handled == false. If poll.Splice gets an EINVAL on the first + // try, it assumes the kernel it's running on doesn't support splice + // for unix sockets and returns handled == false. This works for our + // purposes by somewhat of an accident, but is not entirely correct. + // + // What we want is err == nil and handled == false, i.e. we never + // called poll.Splice, because we know the unix socket's network. + _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) + if err != nil || handled != false { + t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) + } +} + +func testSpliceNoUnixgram(t *testing.T) { + addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t)) + if err != nil { + t.Fatal(err) + } + defer os.Remove(addr.Name) + up, err := ListenUnixgram("unixgram", addr) + if err != nil { + t.Fatal(err) + } + defer up.Close() + clientDown, serverDown := spliceTestSocketPair(t, "tcp") + defer clientDown.Close() + defer serverDown.Close() + // Analogous to testSpliceNoUnixpacket. + _, err, handled := splice(serverDown.(*TCPConn).fd, up) + if err != nil || handled != false { + t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) + } +} + +func BenchmarkSplice(b *testing.B) { + testHookUninstaller.Do(uninstallTestHooks) + + b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) + b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) +} + +func benchSplice(b *testing.B, upNet, downNet string) { + for i := 0; i <= 10; i++ { + chunkSize := 1 << uint(i+10) + tc := spliceTestCase{ + upNet: upNet, + downNet: downNet, + chunkSize: chunkSize, + } + + b.Run(strconv.Itoa(chunkSize), tc.bench) + } +} + +func (tc spliceTestCase) bench(b *testing.B) { + // To benchmark the genericReadFrom code path, set this to false. + useSplice := true + + clientUp, serverUp := spliceTestSocketPair(b, tc.upNet) + defer serverUp.Close() + + cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) + if err != nil { + b.Fatal(err) + } + defer cleanup() + + clientDown, serverDown := spliceTestSocketPair(b, tc.downNet) + defer serverDown.Close() + + cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) + if err != nil { + b.Fatal(err) + } + defer cleanup() + + b.SetBytes(int64(tc.chunkSize)) + b.ResetTimer() + + if useSplice { + _, err := io.Copy(serverDown, serverUp) + if err != nil { + b.Fatal(err) + } + } else { + type onlyReader struct { + io.Reader + } + _, err := io.Copy(serverDown, onlyReader{serverUp}) + if err != nil { + b.Fatal(err) + } + } +} + +func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) { + t.Helper() + ln := newLocalListener(t, net) + defer ln.Close() + var cerr, serr error + acceptDone := make(chan struct{}) + go func() { + server, serr = ln.Accept() + acceptDone <- struct{}{} + }() + client, cerr = Dial(ln.Addr().Network(), ln.Addr().String()) + <-acceptDone + if cerr != nil { + if server != nil { + server.Close() + } + t.Fatal(cerr) + } + if serr != nil { + if client != nil { + client.Close() + } + t.Fatal(serr) + } + return client, server +} + +func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) { + f, err := conn.(interface{ File() (*os.File, error) }).File() + if err != nil { + return nil, err + } + + cmd := exec.Command(os.Args[0], os.Args[1:]...) + cmd.Env = []string{ + "GO_NET_TEST_SPLICE=1", + "GO_NET_TEST_SPLICE_OP=" + op, + "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize), + "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize), + "TMPDIR=" + os.Getenv("TMPDIR"), + } + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return nil, err + } + + donec := make(chan struct{}) + go func() { + cmd.Wait() + conn.Close() + f.Close() + close(donec) + }() + + return func() { + select { + case <-donec: + case <-time.After(5 * time.Second): + log.Printf("killing splice client after 5 second shutdown timeout") + cmd.Process.Kill() + select { + case <-donec: + case <-time.After(5 * time.Second): + log.Printf("splice client didn't die after 10 seconds") + } + } + }, nil +} + +func init() { + if os.Getenv("GO_NET_TEST_SPLICE") == "" { + return + } + defer os.Exit(0) + + f := os.NewFile(uintptr(3), "splice-test-conn") + defer f.Close() + + conn, err := FileConn(f) + if err != nil { + log.Fatal(err) + } + + var chunkSize int + if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil { + log.Fatal(err) + } + buf := make([]byte, chunkSize) + + var totalSize int + if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil { + log.Fatal(err) + } + + var fn func([]byte) (int, error) + switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op { + case "r": + fn = conn.Read + case "w": + defer conn.Close() + + fn = conn.Write + default: + log.Fatalf("unknown op %q", op) + } + + var n int + for count := 0; count < totalSize; count += n { + if count+chunkSize > totalSize { + buf = buf[:totalSize-count] + } + + var err error + if n, err = fn(buf); err != nil { + return + } + } +} |