summaryrefslogtreecommitdiffstats
path: root/src/os/writeto_linux_test.go
blob: 5ffab88a2ab52630ba24c9855d4cda8c1515df2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package os_test

import (
	"bytes"
	"internal/poll"
	"io"
	"math/rand"
	"net"
	. "os"
	"strconv"
	"syscall"
	"testing"
	"time"
)

// TestSendFile runs the sendfile test for a Unix-domain destination socket
// and then for a TCP destination socket. Each protocol is run over several
// payload sizes, including one byte past the system page size.
func TestSendFile(t *testing.T) {
	sizes := []int{1, 42, 1025, syscall.Getpagesize() + 1, 32769}
	for _, proto := range []string{"unix", "tcp"} {
		t.Run("sendfile-to-"+proto, func(t *testing.T) {
			for _, size := range sizes {
				t.Run(strconv.Itoa(size), func(t *testing.T) {
					testSendFile(t, proto, int64(size))
				})
			}
		})
	}
}

// testSendFile copies a temporary file of the given size into a socket of the
// given protocol via io.Copy. It then checks three things:
//   - poll.SendFile was invoked (when any bytes were copied).
//   - It was invoked with the source file's and destination socket's
//     descriptors.
//   - The peer received exactly the file's contents.
func testSendFile(t *testing.T, proto string, size int64) {
	dst, src, recv, data, hook := newSendFileTest(t, proto, size)

	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile.
	n, err := io.Copy(dst, src)
	if err != nil {
		t.Fatalf("io.Copy error: %v", err)
	}

	// We should have called poll.SendFile with the right file descriptor arguments.
	if n > 0 && !hook.called {
		t.Fatal("expected poll.SendFile to be called")
	}
	if hook.called && hook.srcfd != int(src.Fd()) {
		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
	}
	sc, ok := dst.(syscall.Conn)
	if !ok {
		t.Fatalf("destination is not a syscall.Conn")
	}
	rc, err := sc.SyscallConn()
	if err != nil {
		t.Fatalf("destination SyscallConn error: %v", err)
	}
	// Only capture the descriptor inside the callback. The comparison (and any
	// t.Fatalf) happens after Control has returned.
	var dstfd int
	if err = rc.Control(func(fd uintptr) {
		dstfd = int(fd)
	}); err != nil {
		t.Fatalf("destination Conn Control error: %v", err)
	}
	if hook.called && hook.dstfd != dstfd {
		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dstfd)
	}

	// Check the count reported by io.Copy before reading from the peer.
	// If fewer bytes were sent than expected, io.ReadFull below may block
	// waiting for data that never arrives, because the connection is still open.
	dataSize := len(data)
	if n != int64(dataSize) {
		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
	}

	// Verify the data size and content.
	dstData := make([]byte, dataSize)
	m, err := io.ReadFull(recv, dstData)
	if err != nil {
		t.Fatalf("server Conn Read error: %v", err)
	}
	if m != dataSize {
		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
	}
	if !bytes.Equal(dstData, data) {
		// The payload is random binary data. Report where it first diverges
		// rather than dumping raw bytes into the test log.
		for i := range data {
			if dstData[i] != data[i] {
				t.Errorf("data mismatch at offset %d of %d: got 0x%02x, want 0x%02x", i, dataSize, dstData[i], data[i])
				break
			}
		}
	}
}

// newSendFileTest sets up the fixtures for one sendfile test.
//
// In order, it does three things:
//   - Installs a hook around package os' call to poll.SendFile.
//   - Obtains a client/server socket pair for proto via createSocketPair.
//   - Creates a temporary source file holding size bytes of random data.
//
// It returns five values, in order:
//   - the client end, used as the copy destination;
//   - the source file;
//   - the server end, from which the copied bytes are read;
//   - the data written to the file;
//   - the hook, so callers can inspect how poll.SendFile was invoked.
func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
	t.Helper()

	h := hookSendFile(t)
	conn, peer := createSocketPair(t, proto)
	srcFile, payload := createTempFile(t, size)
	return conn, srcFile, peer, payload, h
}

// hookSendFile installs a new sendFileHook and returns it. It registers the
// hook's uninstall with t.Cleanup, so the previous function is restored
// when t finishes.
func hookSendFile(t *testing.T) *sendFileHook {
	hook := &sendFileHook{}
	hook.install()
	t.Cleanup(hook.uninstall)
	return hook
}

// sendFileHook wraps the function stored in *PollSendFile. It records the
// arguments and results of the most recent call so tests can inspect them.
type sendFileHook struct {
	// Arguments seen on the most recent call. called is set to true the
	// first time the wrapper runs.
	called bool
	dstfd  int
	srcfd  int
	remain int64

	// Values returned by the wrapped function on the most recent call.
	written int64
	handled bool
	err     error

	// original is the function *PollSendFile held before install replaced it.
	original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
}

// install saves the current *PollSendFile and replaces it with a wrapper.
// The wrapper records each call in h and then delegates to the saved function.
func (h *sendFileHook) install() {
	h.original = *PollSendFile
	wrapper := func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
		h.called = true
		h.dstfd = dst.Sysfd
		h.srcfd = src
		h.remain = remain
		written, err, handled := h.original(dst, src, remain)
		h.written, h.err, h.handled = written, err, handled
		return written, err, handled
	}
	*PollSendFile = wrapper
}

// uninstall puts back the function that install saved.
func (h *sendFileHook) uninstall() {
	*PollSendFile = h.original
}

// createTempFile creates a file in t.TempDir and fills it with size bytes of
// pseudo-random data. It flushes the data with Sync and rewinds the file to
// its start.
//
// It returns the open file together with the bytes written to it. The file is
// closed when t finishes. The random seed is logged to help reproduce a
// failing run.
func createTempFile(t *testing.T, size int64) (*File, []byte) {
	t.Helper()

	f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() {
		// Best-effort close; the directory itself is removed by t.TempDir.
		f.Close()
	})

	randSeed := time.Now().Unix()
	t.Logf("random data seed: %d\n", randSeed)
	prng := rand.New(rand.NewSource(randSeed))
	data := make([]byte, size)
	prng.Read(data)
	if _, err := f.Write(data); err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	if err := f.Sync(); err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	// Rewind so that the caller's reads (or sendfile) start at offset 0.
	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}

	return f, data
}