summaryrefslogtreecommitdiffstats
path: root/src/net/splice_test.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/net/splice_test.go417
1 files changed, 417 insertions, 0 deletions
diff --git a/src/net/splice_test.go b/src/net/splice_test.go
new file mode 100644
index 0000000..fa14c95
--- /dev/null
+++ b/src/net/splice_test.go
@@ -0,0 +1,417 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+
+package net
+
+import (
+ "io"
+ "log"
+ "os"
+ "os/exec"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestSplice(t *testing.T) {
+ t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
+ if !testableNetwork("unixgram") {
+ t.Skip("skipping unix-to-tcp tests")
+ }
+ t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+ t.Run("no-unixpacket", testSpliceNoUnixpacket)
+ t.Run("no-unixgram", testSpliceNoUnixgram)
+}
+
+func testSplice(t *testing.T, upNet, downNet string) {
+ t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
+ t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
+ t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
+ t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
+ t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
+ t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
+ t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
+ t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
+}
+
+type spliceTestCase struct {
+ upNet, downNet string
+
+ chunkSize, totalSize int
+ limitReadSize int
+}
+
+func (tc spliceTestCase) test(t *testing.T) {
+ clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
+ defer serverUp.Close()
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+ clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
+ defer serverDown.Close()
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+ var (
+ r io.Reader = serverUp
+ size = tc.totalSize
+ )
+ if tc.limitReadSize > 0 {
+ if tc.limitReadSize < size {
+ size = tc.limitReadSize
+ }
+
+ r = &io.LimitedReader{
+ N: int64(tc.limitReadSize),
+ R: serverUp,
+ }
+ defer serverUp.Close()
+ }
+ n, err := io.Copy(serverDown, r)
+ serverDown.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := int64(size); want != n {
+ t.Errorf("want %d bytes spliced, got %d", want, n)
+ }
+
+ if tc.limitReadSize > 0 {
+ wantN := 0
+ if tc.limitReadSize > size {
+ wantN = tc.limitReadSize - size
+ }
+
+ if n := r.(*io.LimitedReader).N; n != int64(wantN) {
+ t.Errorf("r.N = %d, want %d", n, wantN)
+ }
+ }
+}
+
+func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+ clientUp, serverUp := spliceTestSocketPair(t, upNet)
+ defer clientUp.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, downNet)
+ defer clientDown.Close()
+
+ serverUp.Close()
+
+ // We'd like to call net.splice here and check the handled return
+ // value, but we disable splice on old Linux kernels.
+ //
+ // In that case, poll.Splice and net.splice return a non-nil error
+ // and handled == false. We'd ideally like to see handled == true
+ // because the source reader is at EOF, but if we're running on an old
+ // kernel, and splice is disabled, we won't see EOF from net.splice,
+ // because we won't touch the reader at all.
+ //
+ // Trying to untangle the errors from net.splice and match them
+ // against the errors created by the poll package would be brittle,
+ // so this is a higher level test.
+ //
+ // The following ReadFrom should return immediately, regardless of
+ // whether splice is disabled or not. The other side should then
+ // get a goodbye signal. Test for the goodbye signal.
+ msg := "bye"
+ go func() {
+ serverDown.(io.ReaderFrom).ReadFrom(serverUp)
+ io.WriteString(serverDown, msg)
+ serverDown.Close()
+ }()
+
+ buf := make([]byte, 3)
+ _, err := io.ReadFull(clientDown, buf)
+ if err != nil {
+ t.Errorf("clientDown: %v", err)
+ }
+ if string(buf) != msg {
+ t.Errorf("clientDown got %q, want %q", buf, msg)
+ }
+}
+
+func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
+ front := newLocalListener(t, upNet)
+ defer front.Close()
+ back := newLocalListener(t, downNet)
+ defer back.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ proxy := func() {
+ src, err := front.Accept()
+ if err != nil {
+ return
+ }
+ dst, err := Dial(downNet, back.Addr().String())
+ if err != nil {
+ return
+ }
+ defer dst.Close()
+ defer src.Close()
+ go func() {
+ io.Copy(src, dst)
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(dst, src)
+ wg.Done()
+ }()
+ }
+
+ go proxy()
+
+ toFront, err := Dial(upNet, front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ io.WriteString(toFront, "foo")
+ toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+
+ _, err = io.ReadAll(fromProxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Wait()
+}
+
+func testSpliceNoUnixpacket(t *testing.T) {
+ clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
+ defer clientDown.Close()
+ defer serverDown.Close()
+ // If splice called poll.Splice here, we'd get err == syscall.EINVAL
+ // and handled == false. If poll.Splice gets an EINVAL on the first
+ // try, it assumes the kernel it's running on doesn't support splice
+ // for unix sockets and returns handled == false. This works for our
+ // purposes by somewhat of an accident, but is not entirely correct.
+ //
+ // What we want is err == nil and handled == false, i.e. we never
+ // called poll.Splice, because we know the unix socket's network.
+ _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ if err != nil || handled != false {
+ t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
+ }
+}
+
+func testSpliceNoUnixgram(t *testing.T) {
+ addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(addr.Name)
+ up, err := ListenUnixgram("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer up.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
+ defer clientDown.Close()
+ defer serverDown.Close()
+ // Analogous to testSpliceNoUnixpacket.
+ _, err, handled := splice(serverDown.(*TCPConn).fd, up)
+ if err != nil || handled != false {
+ t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
+ }
+}
+
+func BenchmarkSplice(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
+ b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
+}
+
+func benchSplice(b *testing.B, upNet, downNet string) {
+ for i := 0; i <= 10; i++ {
+ chunkSize := 1 << uint(i+10)
+ tc := spliceTestCase{
+ upNet: upNet,
+ downNet: downNet,
+ chunkSize: chunkSize,
+ }
+
+ b.Run(strconv.Itoa(chunkSize), tc.bench)
+ }
+}
+
+func (tc spliceTestCase) bench(b *testing.B) {
+ // To benchmark the genericReadFrom code path, set this to false.
+ useSplice := true
+
+ clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
+ defer serverUp.Close()
+
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cleanup()
+
+ clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
+ defer serverDown.Close()
+
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cleanup()
+
+ b.SetBytes(int64(tc.chunkSize))
+ b.ResetTimer()
+
+ if useSplice {
+ _, err := io.Copy(serverDown, serverUp)
+ if err != nil {
+ b.Fatal(err)
+ }
+ } else {
+ type onlyReader struct {
+ io.Reader
+ }
+ _, err := io.Copy(serverDown, onlyReader{serverUp})
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
+ t.Helper()
+ ln := newLocalListener(t, net)
+ defer ln.Close()
+ var cerr, serr error
+ acceptDone := make(chan struct{})
+ go func() {
+ server, serr = ln.Accept()
+ acceptDone <- struct{}{}
+ }()
+ client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
+ <-acceptDone
+ if cerr != nil {
+ if server != nil {
+ server.Close()
+ }
+ t.Fatal(cerr)
+ }
+ if serr != nil {
+ if client != nil {
+ client.Close()
+ }
+ t.Fatal(serr)
+ }
+ return client, server
+}
+
+func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
+ f, err := conn.(interface{ File() (*os.File, error) }).File()
+ if err != nil {
+ return nil, err
+ }
+
+ cmd := exec.Command(os.Args[0], os.Args[1:]...)
+ cmd.Env = []string{
+ "GO_NET_TEST_SPLICE=1",
+ "GO_NET_TEST_SPLICE_OP=" + op,
+ "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
+ "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
+ "TMPDIR=" + os.Getenv("TMPDIR"),
+ }
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+
+ donec := make(chan struct{})
+ go func() {
+ cmd.Wait()
+ conn.Close()
+ f.Close()
+ close(donec)
+ }()
+
+ return func() {
+ select {
+ case <-donec:
+ case <-time.After(5 * time.Second):
+ log.Printf("killing splice client after 5 second shutdown timeout")
+ cmd.Process.Kill()
+ select {
+ case <-donec:
+ case <-time.After(5 * time.Second):
+ log.Printf("splice client didn't die after 10 seconds")
+ }
+ }
+ }, nil
+}
+
+func init() {
+ if os.Getenv("GO_NET_TEST_SPLICE") == "" {
+ return
+ }
+ defer os.Exit(0)
+
+ f := os.NewFile(uintptr(3), "splice-test-conn")
+ defer f.Close()
+
+ conn, err := FileConn(f)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ var chunkSize int
+ if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+ buf := make([]byte, chunkSize)
+
+ var totalSize int
+ if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+
+ var fn func([]byte) (int, error)
+ switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
+ case "r":
+ fn = conn.Read
+ case "w":
+ defer conn.Close()
+
+ fn = conn.Write
+ default:
+ log.Fatalf("unknown op %q", op)
+ }
+
+ var n int
+ for count := 0; count < totalSize; count += n {
+ if count+chunkSize > totalSize {
+ buf = buf[:totalSize-count]
+ }
+
+ var err error
+ if n, err = fn(buf); err != nil {
+ return
+ }
+ }
+}