Diffstat (limited to 'src/net')
-rw-r--r--  src/net/addrselect.go  376
-rw-r--r--  src/net/addrselect_test.go  312
-rw-r--r--  src/net/cgo_aix.go  24
-rw-r--r--  src/net/cgo_android.go  12
-rw-r--r--  src/net/cgo_bsd.go  14
-rw-r--r--  src/net/cgo_darwin.go  9
-rw-r--r--  src/net/cgo_linux.go  20
-rw-r--r--  src/net/cgo_netbsd.go  14
-rw-r--r--  src/net/cgo_openbsd.go  14
-rw-r--r--  src/net/cgo_resnew.go  22
-rw-r--r--  src/net/cgo_resold.go  22
-rw-r--r--  src/net/cgo_socknew.go  32
-rw-r--r--  src/net/cgo_sockold.go  32
-rw-r--r--  src/net/cgo_solaris.go  15
-rw-r--r--  src/net/cgo_stub.go  40
-rw-r--r--  src/net/cgo_unix.go  370
-rw-r--r--  src/net/cgo_unix_cgo.go  80
-rw-r--r--  src/net/cgo_unix_cgo_darwin.go  21
-rw-r--r--  src/net/cgo_unix_cgo_res.go  38
-rw-r--r--  src/net/cgo_unix_cgo_resn.go  39
-rw-r--r--  src/net/cgo_unix_syscall.go  102
-rw-r--r--  src/net/cgo_unix_test.go  69
-rw-r--r--  src/net/conf.go  523
-rw-r--r--  src/net/conf_test.go  461
-rw-r--r--  src/net/conn_test.go  64
-rw-r--r--  src/net/dial.go  837
-rw-r--r--  src/net/dial_test.go  1088
-rw-r--r--  src/net/dial_unix_test.go  113
-rw-r--r--  src/net/dnsclient.go  228
-rw-r--r--  src/net/dnsclient_test.go  66
-rw-r--r--  src/net/dnsclient_unix.go  879
-rw-r--r--  src/net/dnsclient_unix_test.go  2600
-rw-r--r--  src/net/dnsconfig.go  45
-rw-r--r--  src/net/dnsconfig_unix.go  167
-rw-r--r--  src/net/dnsconfig_unix_test.go  314
-rw-r--r--  src/net/dnsconfig_windows.go  63
-rw-r--r--  src/net/dnsname_test.go  86
-rw-r--r--  src/net/error_plan9.go  9
-rw-r--r--  src/net/error_plan9_test.go  23
-rw-r--r--  src/net/error_posix.go  21
-rw-r--r--  src/net/error_posix_test.go  34
-rw-r--r--  src/net/error_test.go  810
-rw-r--r--  src/net/error_unix.go  16
-rw-r--r--  src/net/error_unix_test.go  39
-rw-r--r--  src/net/error_windows.go  14
-rw-r--r--  src/net/error_windows_test.go  29
-rw-r--r--  src/net/example_test.go  387
-rw-r--r--  src/net/external_test.go  168
-rw-r--r--  src/net/fd_plan9.go  187
-rw-r--r--  src/net/fd_posix.go  147
-rw-r--r--  src/net/fd_unix.go  206
-rw-r--r--  src/net/fd_wasip1.go  184
-rw-r--r--  src/net/fd_windows.go  205
-rw-r--r--  src/net/file.go  51
-rw-r--r--  src/net/file_plan9.go  135
-rw-r--r--  src/net/file_stub.go  16
-rw-r--r--  src/net/file_test.go  340
-rw-r--r--  src/net/file_unix.go  119
-rw-r--r--  src/net/file_unix_test.go  101
-rw-r--r--  src/net/file_wasip1.go  102
-rw-r--r--  src/net/file_wasip1_test.go  112
-rw-r--r--  src/net/file_windows.go  25
-rw-r--r--  src/net/hook.go  26
-rw-r--r--  src/net/hook_plan9.go  9
-rw-r--r--  src/net/hook_unix.go  20
-rw-r--r--  src/net/hook_windows.go  21
-rw-r--r--  src/net/hosts.go  165
-rw-r--r--  src/net/hosts_test.go  214
-rw-r--r--  src/net/http/alpn_test.go  132
-rw-r--r--  src/net/http/cgi/child.go  222
-rw-r--r--  src/net/http/cgi/child_test.go  208
-rw-r--r--  src/net/http/cgi/host.go  413
-rw-r--r--  src/net/http/cgi/host_test.go  577
-rw-r--r--  src/net/http/cgi/integration_test.go  272
-rw-r--r--  src/net/http/cgi/plan9_test.go  17
-rw-r--r--  src/net/http/cgi/posix_test.go  20
-rw-r--r--  src/net/http/cgi/testdata/test.cgi  95
-rw-r--r--  src/net/http/client.go  1038
-rw-r--r--  src/net/http/client_test.go  2144
-rw-r--r--  src/net/http/clientserver_test.go  1760
-rw-r--r--  src/net/http/clone.go  74
-rw-r--r--  src/net/http/cookie.go  468
-rw-r--r--  src/net/http/cookie_test.go  652
-rw-r--r--  src/net/http/cookiejar/dummy_publicsuffix_test.go  21
-rw-r--r--  src/net/http/cookiejar/example_test.go  65
-rw-r--r--  src/net/http/cookiejar/jar.go  547
-rw-r--r--  src/net/http/cookiejar/jar_test.go  1355
-rw-r--r--  src/net/http/cookiejar/punycode.go  151
-rw-r--r--  src/net/http/cookiejar/punycode_test.go  161
-rw-r--r--  src/net/http/doc.go  110
-rw-r--r--  src/net/http/example_filesystem_test.go  71
-rw-r--r--  src/net/http/example_handle_test.go  29
-rw-r--r--  src/net/http/example_test.go  195
-rw-r--r--  src/net/http/export_test.go  317
-rw-r--r--  src/net/http/fcgi/child.go  395
-rw-r--r--  src/net/http/fcgi/fcgi.go  277
-rw-r--r--  src/net/http/fcgi/fcgi_test.go  453
-rw-r--r--  src/net/http/filetransport.go  123
-rw-r--r--  src/net/http/filetransport_test.go  64
-rw-r--r--  src/net/http/fs.go  988
-rw-r--r--  src/net/http/fs_test.go  1561
-rw-r--r--  src/net/http/h2_bundle.go  11493
-rw-r--r--  src/net/http/h2_error.go  38
-rw-r--r--  src/net/http/h2_error_test.go  44
-rw-r--r--  src/net/http/header.go  280
-rw-r--r--  src/net/http/header_test.go  272
-rw-r--r--  src/net/http/http.go  165
-rw-r--r--  src/net/http/http_test.go  201
-rw-r--r--  src/net/http/httptest/example_test.go  99
-rw-r--r--  src/net/http/httptest/httptest.go  90
-rw-r--r--  src/net/http/httptest/httptest_test.go  179
-rw-r--r--  src/net/http/httptest/recorder.go  255
-rw-r--r--  src/net/http/httptest/recorder_test.go  371
-rw-r--r--  src/net/http/httptest/server.go  385
-rw-r--r--  src/net/http/httptest/server_test.go  294
-rw-r--r--  src/net/http/httptrace/example_test.go  29
-rw-r--r--  src/net/http/httptrace/trace.go  255
-rw-r--r--  src/net/http/httptrace/trace_test.go  89
-rw-r--r--  src/net/http/httputil/dump.go  337
-rw-r--r--  src/net/http/httputil/dump_test.go  532
-rw-r--r--  src/net/http/httputil/example_test.go  128
-rw-r--r--  src/net/http/httputil/httputil.go  41
-rw-r--r--  src/net/http/httputil/persist.go  431
-rw-r--r--  src/net/http/httputil/reverseproxy.go  834
-rw-r--r--  src/net/http/httputil/reverseproxy_test.go  1863
-rw-r--r--  src/net/http/internal/ascii/print.go  61
-rw-r--r--  src/net/http/internal/ascii/print_test.go  95
-rw-r--r--  src/net/http/internal/chunked.go  284
-rw-r--r--  src/net/http/internal/chunked_test.go  300
-rw-r--r--  src/net/http/internal/testcert/testcert.go  65
-rw-r--r--  src/net/http/jar.go  27
-rw-r--r--  src/net/http/main_test.go  175
-rw-r--r--  src/net/http/method.go  20
-rw-r--r--  src/net/http/omithttp2.go  79
-rw-r--r--  src/net/http/pprof/pprof.go  464
-rw-r--r--  src/net/http/pprof/pprof_test.go  263
-rw-r--r--  src/net/http/proxy_test.go  50
-rw-r--r--  src/net/http/range_test.go  79
-rw-r--r--  src/net/http/readrequest_test.go  475
-rw-r--r--  src/net/http/request.go  1488
-rw-r--r--  src/net/http/request_test.go  1397
-rw-r--r--  src/net/http/requestwrite_test.go  977
-rw-r--r--  src/net/http/response.go  371
-rw-r--r--  src/net/http/response_test.go  999
-rw-r--r--  src/net/http/responsecontroller.go  147
-rw-r--r--  src/net/http/responsecontroller_test.go  324
-rw-r--r--  src/net/http/responsewrite_test.go  290
-rw-r--r--  src/net/http/roundtrip.go  18
-rw-r--r--  src/net/http/roundtrip_js.go  360
-rw-r--r--  src/net/http/serve_test.go  6870
-rw-r--r--  src/net/http/server.go  3645
-rw-r--r--  src/net/http/server_test.go  98
-rw-r--r--  src/net/http/sniff.go  304
-rw-r--r--  src/net/http/sniff_test.go  282
-rw-r--r--  src/net/http/socks_bundle.go  473
-rw-r--r--  src/net/http/status.go  210
-rw-r--r--  src/net/http/testdata/file  1
-rw-r--r--  src/net/http/testdata/index.html  1
-rw-r--r--  src/net/http/testdata/style.css  1
-rw-r--r--  src/net/http/transfer.go  1124
-rw-r--r--  src/net/http/transfer_test.go  363
-rw-r--r--  src/net/http/transport.go  2942
-rw-r--r--  src/net/http/transport_default_other.go  16
-rw-r--r--  src/net/http/transport_default_wasm.go  16
-rw-r--r--  src/net/http/transport_internal_test.go  267
-rw-r--r--  src/net/http/transport_test.go  6752
-rw-r--r--  src/net/http/triv.go  140
-rw-r--r--  src/net/interface.go  259
-rw-r--r--  src/net/interface_aix.go  189
-rw-r--r--  src/net/interface_bsd.go  121
-rw-r--r--  src/net/interface_bsd_test.go  60
-rw-r--r--  src/net/interface_bsdvar.go  28
-rw-r--r--  src/net/interface_darwin.go  53
-rw-r--r--  src/net/interface_freebsd.go  53
-rw-r--r--  src/net/interface_linux.go  272
-rw-r--r--  src/net/interface_linux_test.go  133
-rw-r--r--  src/net/interface_plan9.go  200
-rw-r--r--  src/net/interface_solaris.go  92
-rw-r--r--  src/net/interface_stub.go  27
-rw-r--r--  src/net/interface_test.go  382
-rw-r--r--  src/net/interface_unix_test.go  215
-rw-r--r--  src/net/interface_windows.go  178
-rw-r--r--  src/net/internal/socktest/main_test.go  56
-rw-r--r--  src/net/internal/socktest/main_unix_test.go  24
-rw-r--r--  src/net/internal/socktest/main_windows_test.go  22
-rw-r--r--  src/net/internal/socktest/switch.go  169
-rw-r--r--  src/net/internal/socktest/switch_posix.go  58
-rw-r--r--  src/net/internal/socktest/switch_stub.go  16
-rw-r--r--  src/net/internal/socktest/switch_unix.go  29
-rw-r--r--  src/net/internal/socktest/switch_windows.go  29
-rw-r--r--  src/net/internal/socktest/sys_cloexec.go  42
-rw-r--r--  src/net/internal/socktest/sys_unix.go  193
-rw-r--r--  src/net/internal/socktest/sys_windows.go  221
-rw-r--r--  src/net/ip.go  542
-rw-r--r--  src/net/ip_test.go  784
-rw-r--r--  src/net/iprawsock.go  240
-rw-r--r--  src/net/iprawsock_plan9.go  34
-rw-r--r--  src/net/iprawsock_posix.go  159
-rw-r--r--  src/net/iprawsock_test.go  202
-rw-r--r--  src/net/ipsock.go  315
-rw-r--r--  src/net/ipsock_plan9.go  367
-rw-r--r--  src/net/ipsock_plan9_test.go  29
-rw-r--r--  src/net/ipsock_posix.go  232
-rw-r--r--  src/net/ipsock_test.go  282
-rw-r--r--  src/net/listen_test.go  750
-rw-r--r--  src/net/lookup.go  908
-rw-r--r--  src/net/lookup_fake.go  58
-rw-r--r--  src/net/lookup_plan9.go  389
-rw-r--r--  src/net/lookup_test.go  1464
-rw-r--r--  src/net/lookup_unix.go  149
-rw-r--r--  src/net/lookup_windows.go  455
-rw-r--r--  src/net/lookup_windows_test.go  340
-rw-r--r--  src/net/mac.go  86
-rw-r--r--  src/net/mac_test.go  109
-rw-r--r--  src/net/mail/example_test.go  77
-rw-r--r--  src/net/mail/message.go  915
-rw-r--r--  src/net/mail/message_test.go  1219
-rw-r--r--  src/net/main_cloexec_test.go  27
-rw-r--r--  src/net/main_conf_test.go  59
-rw-r--r--  src/net/main_noconf_test.go  22
-rw-r--r--  src/net/main_plan9_test.go  16
-rw-r--r--  src/net/main_posix_test.go  50
-rw-r--r--  src/net/main_test.go  209
-rw-r--r--  src/net/main_unix_test.go  55
-rw-r--r--  src/net/main_windows_test.go  45
-rw-r--r--  src/net/mockserver_test.go  510
-rw-r--r--  src/net/mptcpsock_linux.go  127
-rw-r--r--  src/net/mptcpsock_linux_test.go  192
-rw-r--r--  src/net/mptcpsock_stub.go  23
-rw-r--r--  src/net/net.go  767
-rw-r--r--  src/net/net_fake.go  406
-rw-r--r--  src/net/net_fake_js.go  36
-rw-r--r--  src/net/net_fake_test.go  203
-rw-r--r--  src/net/net_test.go  593
-rw-r--r--  src/net/net_windows_test.go  631
-rw-r--r--  src/net/netcgo_off.go  9
-rw-r--r--  src/net/netcgo_on.go  9
-rw-r--r--  src/net/netgo_netcgo.go  14
-rw-r--r--  src/net/netgo_off.go  9
-rw-r--r--  src/net/netgo_on.go  9
-rw-r--r--  src/net/netip/export_test.go  30
-rw-r--r--  src/net/netip/fuzz_test.go  351
-rw-r--r--  src/net/netip/inlining_test.go  102
-rw-r--r--  src/net/netip/leaf_alts.go  54
-rw-r--r--  src/net/netip/netip.go  1482
-rw-r--r--  src/net/netip/netip_pkg_test.go  365
-rw-r--r--  src/net/netip/netip_test.go  2029
-rw-r--r--  src/net/netip/slow_test.go  190
-rw-r--r--  src/net/netip/uint128.go  81
-rw-r--r--  src/net/netip/uint128_test.go  89
-rw-r--r--  src/net/nss.go  249
-rw-r--r--  src/net/nss_test.go  172
-rw-r--r--  src/net/packetconn_test.go  151
-rw-r--r--  src/net/parse.go  319
-rw-r--r--  src/net/parse_test.go  74
-rw-r--r--  src/net/pipe.go  238
-rw-r--r--  src/net/pipe_test.go  49
-rw-r--r--  src/net/platform_test.go  178
-rw-r--r--  src/net/port.go  62
-rw-r--r--  src/net/port_test.go  52
-rw-r--r--  src/net/port_unix.go  57
-rw-r--r--  src/net/protoconn_test.go  350
-rw-r--r--  src/net/rawconn.go  96
-rw-r--r--  src/net/rawconn_stub_test.go  28
-rw-r--r--  src/net/rawconn_test.go  211
-rw-r--r--  src/net/rawconn_unix_test.go  115
-rw-r--r--  src/net/rawconn_windows_test.go  116
-rw-r--r--  src/net/resolverdialfunc_test.go  327
-rw-r--r--  src/net/rpc/client.go  323
-rw-r--r--  src/net/rpc/client_test.go  87
-rw-r--r--  src/net/rpc/debug.go  90
-rw-r--r--  src/net/rpc/jsonrpc/all_test.go  352
-rw-r--r--  src/net/rpc/jsonrpc/client.go  124
-rw-r--r--  src/net/rpc/jsonrpc/server.go  134
-rw-r--r--  src/net/rpc/server.go  725
-rw-r--r--  src/net/rpc/server_test.go  839
-rw-r--r--  src/net/sendfile_linux.go  53
-rw-r--r--  src/net/sendfile_linux_test.go  77
-rw-r--r--  src/net/sendfile_stub.go  13
-rw-r--r--  src/net/sendfile_test.go  364
-rw-r--r--  src/net/sendfile_unix_alt.go  85
-rw-r--r--  src/net/sendfile_windows.go  47
-rw-r--r--  src/net/server_test.go  383
-rw-r--r--  src/net/smtp/auth.go  109
-rw-r--r--  src/net/smtp/example_test.go  83
-rw-r--r--  src/net/smtp/smtp.go  432
-rw-r--r--  src/net/smtp/smtp_test.go  1144
-rw-r--r--  src/net/sock_bsd.go  39
-rw-r--r--  src/net/sock_cloexec.go  48
-rw-r--r--  src/net/sock_linux.go  54
-rw-r--r--  src/net/sock_linux_test.go  23
-rw-r--r--  src/net/sock_plan9.go  10
-rw-r--r--  src/net/sock_posix.go  259
-rw-r--r--  src/net/sock_stub.go  15
-rw-r--r--  src/net/sock_windows.go  41
-rw-r--r--  src/net/sockaddr_posix.go  34
-rw-r--r--  src/net/sockopt_aix.go  39
-rw-r--r--  src/net/sockopt_bsd.go  57
-rw-r--r--  src/net/sockopt_linux.go  35
-rw-r--r--  src/net/sockopt_plan9.go  19
-rw-r--r--  src/net/sockopt_posix.go  134
-rw-r--r--  src/net/sockopt_solaris.go  35
-rw-r--r--  src/net/sockopt_stub.go  37
-rw-r--r--  src/net/sockopt_windows.go  40
-rw-r--r--  src/net/sockoptip_bsdvar.go  30
-rw-r--r--  src/net/sockoptip_linux.go  27
-rw-r--r--  src/net/sockoptip_posix.go  49
-rw-r--r--  src/net/sockoptip_stub.go  33
-rw-r--r--  src/net/sockoptip_windows.go  30
-rw-r--r--  src/net/splice_linux.go  44
-rw-r--r--  src/net/splice_stub.go  13
-rw-r--r--  src/net/splice_test.go  532
-rw-r--r--  src/net/sys_cloexec.go  36
-rw-r--r--  src/net/tcpsock.go  398
-rw-r--r--  src/net/tcpsock_plan9.go  86
-rw-r--r--  src/net/tcpsock_posix.go  187
-rw-r--r--  src/net/tcpsock_test.go  785
-rw-r--r--  src/net/tcpsock_unix_test.go  112
-rw-r--r--  src/net/tcpsockopt_darwin.go  25
-rw-r--r--  src/net/tcpsockopt_dragonfly.go  23
-rw-r--r--  src/net/tcpsockopt_openbsd.go  16
-rw-r--r--  src/net/tcpsockopt_plan9.go  24
-rw-r--r--  src/net/tcpsockopt_posix.go  18
-rw-r--r--  src/net/tcpsockopt_solaris.go  32
-rw-r--r--  src/net/tcpsockopt_stub.go  20
-rw-r--r--  src/net/tcpsockopt_unix.go  24
-rw-r--r--  src/net/tcpsockopt_windows.go  29
-rw-r--r--  src/net/testdata/aliases  8
-rw-r--r--  src/net/testdata/case-hosts  2
-rw-r--r--  src/net/testdata/domain-resolv.conf  5
-rw-r--r--  src/net/testdata/empty-resolv.conf  1
-rw-r--r--  src/net/testdata/freebsd-usevc-resolv.conf  1
-rw-r--r--  src/net/testdata/hosts  11
-rw-r--r--  src/net/testdata/igmp  24
-rw-r--r--  src/net/testdata/igmp6  18
-rw-r--r--  src/net/testdata/invalid-ndots-resolv.conf  1
-rw-r--r--  src/net/testdata/ipv4-hosts  8
-rw-r--r--  src/net/testdata/ipv6-hosts  11
-rw-r--r--  src/net/testdata/large-ndots-resolv.conf  1
-rw-r--r--  src/net/testdata/linux-use-vc-resolv.conf  1
-rw-r--r--  src/net/testdata/negative-ndots-resolv.conf  1
-rw-r--r--  src/net/testdata/openbsd-resolv.conf  5
-rw-r--r--  src/net/testdata/openbsd-tcp-resolv.conf  1
-rw-r--r--  src/net/testdata/resolv.conf  8
-rw-r--r--  src/net/testdata/search-resolv.conf  5
-rw-r--r--  src/net/testdata/search-single-dot-resolv.conf  5
-rw-r--r--  src/net/testdata/single-request-reopen-resolv.conf  1
-rw-r--r--  src/net/testdata/single-request-resolv.conf  1
-rw-r--r--  src/net/testdata/singleline-hosts  1
-rw-r--r--  src/net/textproto/header.go  56
-rw-r--r--  src/net/textproto/header_test.go  54
-rw-r--r--  src/net/textproto/pipeline.go  118
-rw-r--r--  src/net/textproto/reader.go  840
-rw-r--r--  src/net/textproto/reader_test.go  537
-rw-r--r--  src/net/textproto/textproto.go  152
-rw-r--r--  src/net/textproto/writer.go  119
-rw-r--r--  src/net/textproto/writer_test.go  61
-rw-r--r--  src/net/timeout_test.go  1161
-rw-r--r--  src/net/udpsock.go  368
-rw-r--r--  src/net/udpsock_plan9.go  182
-rw-r--r--  src/net/udpsock_plan9_test.go  69
-rw-r--r--  src/net/udpsock_posix.go  287
-rw-r--r--  src/net/udpsock_test.go  666
-rw-r--r--  src/net/unixsock.go  352
-rw-r--r--  src/net/unixsock_linux_test.go  104
-rw-r--r--  src/net/unixsock_plan9.go  51
-rw-r--r--  src/net/unixsock_posix.go  245
-rw-r--r--  src/net/unixsock_readmsg_cloexec.go  30
-rw-r--r--  src/net/unixsock_readmsg_cmsg_cloexec.go  13
-rw-r--r--  src/net/unixsock_readmsg_other.go  11
-rw-r--r--  src/net/unixsock_readmsg_test.go  105
-rw-r--r--  src/net/unixsock_test.go  463
-rw-r--r--  src/net/unixsock_windows_test.go  97
-rw-r--r--  src/net/url/example_test.go  374
-rw-r--r--  src/net/url/url.go  1265
-rw-r--r--  src/net/url/url_test.go  2210
-rw-r--r--  src/net/write_unix_test.go  66
-rw-r--r--  src/net/writev_test.go  224
-rw-r--r--  src/net/writev_unix.go  29
379 files changed, 123388 insertions, 0 deletions
diff --git a/src/net/addrselect.go b/src/net/addrselect.go
new file mode 100644
index 0000000..4f07032
--- /dev/null
+++ b/src/net/addrselect.go
@@ -0,0 +1,376 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Minimal RFC 6724 address selection.
+
+package net
+
+import (
+ "net/netip"
+ "sort"
+)
+
+func sortByRFC6724(addrs []IPAddr) {
+ if len(addrs) < 2 {
+ return
+ }
+ sortByRFC6724withSrcs(addrs, srcAddrs(addrs))
+}
+
+func sortByRFC6724withSrcs(addrs []IPAddr, srcs []netip.Addr) {
+ if len(addrs) != len(srcs) {
+ panic("internal error")
+ }
+ addrAttr := make([]ipAttr, len(addrs))
+ srcAttr := make([]ipAttr, len(srcs))
+ for i, v := range addrs {
+ addrAttrIP, _ := netip.AddrFromSlice(v.IP)
+ addrAttr[i] = ipAttrOf(addrAttrIP)
+ srcAttr[i] = ipAttrOf(srcs[i])
+ }
+ sort.Stable(&byRFC6724{
+ addrs: addrs,
+ addrAttr: addrAttr,
+ srcs: srcs,
+ srcAttr: srcAttr,
+ })
+}
+
+// srcAddrs tries to UDP-connect to each address to see if it has a
+// route. (This doesn't send any packets). The destination port
+// number is irrelevant.
+func srcAddrs(addrs []IPAddr) []netip.Addr {
+ srcs := make([]netip.Addr, len(addrs))
+ dst := UDPAddr{Port: 9}
+ for i := range addrs {
+ dst.IP = addrs[i].IP
+ dst.Zone = addrs[i].Zone
+ c, err := DialUDP("udp", nil, &dst)
+ if err == nil {
+ if src, ok := c.LocalAddr().(*UDPAddr); ok {
+ srcs[i], _ = netip.AddrFromSlice(src.IP)
+ }
+ c.Close()
+ }
+ }
+ return srcs
+}
+
+type ipAttr struct {
+ Scope scope
+ Precedence uint8
+ Label uint8
+}
+
+func ipAttrOf(ip netip.Addr) ipAttr {
+ if !ip.IsValid() {
+ return ipAttr{}
+ }
+ match := rfc6724policyTable.Classify(ip)
+ return ipAttr{
+ Scope: classifyScope(ip),
+ Precedence: match.Precedence,
+ Label: match.Label,
+ }
+}
+
+type byRFC6724 struct {
+ addrs []IPAddr // addrs to sort
+ addrAttr []ipAttr
+ srcs []netip.Addr // or not valid addr if unreachable
+ srcAttr []ipAttr
+}
+
+func (s *byRFC6724) Len() int { return len(s.addrs) }
+
+func (s *byRFC6724) Swap(i, j int) {
+ s.addrs[i], s.addrs[j] = s.addrs[j], s.addrs[i]
+ s.srcs[i], s.srcs[j] = s.srcs[j], s.srcs[i]
+ s.addrAttr[i], s.addrAttr[j] = s.addrAttr[j], s.addrAttr[i]
+ s.srcAttr[i], s.srcAttr[j] = s.srcAttr[j], s.srcAttr[i]
+}
+
+// Less reports whether i is a better destination address for this
+// host than j.
+//
+// The algorithm and variable names come from RFC 6724 section 6.
+func (s *byRFC6724) Less(i, j int) bool {
+ DA := s.addrs[i].IP
+ DB := s.addrs[j].IP
+ SourceDA := s.srcs[i]
+ SourceDB := s.srcs[j]
+ attrDA := &s.addrAttr[i]
+ attrDB := &s.addrAttr[j]
+ attrSourceDA := &s.srcAttr[i]
+ attrSourceDB := &s.srcAttr[j]
+
+ const preferDA = true
+ const preferDB = false
+
+ // Rule 1: Avoid unusable destinations.
+ // If DB is known to be unreachable or if Source(DB) is undefined, then
+ // prefer DA. Similarly, if DA is known to be unreachable or if
+ // Source(DA) is undefined, then prefer DB.
+ if !SourceDA.IsValid() && !SourceDB.IsValid() {
+ return false // "equal"
+ }
+ if !SourceDB.IsValid() {
+ return preferDA
+ }
+ if !SourceDA.IsValid() {
+ return preferDB
+ }
+
+ // Rule 2: Prefer matching scope.
+ // If Scope(DA) = Scope(Source(DA)) and Scope(DB) <> Scope(Source(DB)),
+ // then prefer DA. Similarly, if Scope(DA) <> Scope(Source(DA)) and
+ // Scope(DB) = Scope(Source(DB)), then prefer DB.
+ if attrDA.Scope == attrSourceDA.Scope && attrDB.Scope != attrSourceDB.Scope {
+ return preferDA
+ }
+ if attrDA.Scope != attrSourceDA.Scope && attrDB.Scope == attrSourceDB.Scope {
+ return preferDB
+ }
+
+ // Rule 3: Avoid deprecated addresses.
+ // If Source(DA) is deprecated and Source(DB) is not, then prefer DB.
+ // Similarly, if Source(DA) is not deprecated and Source(DB) is
+ // deprecated, then prefer DA.
+
+ // TODO(bradfitz): implement? low priority for now.
+
+ // Rule 4: Prefer home addresses.
+ // If Source(DA) is simultaneously a home address and care-of address
+ // and Source(DB) is not, then prefer DA. Similarly, if Source(DB) is
+ // simultaneously a home address and care-of address and Source(DA) is
+ // not, then prefer DB.
+
+ // TODO(bradfitz): implement? low priority for now.
+
+ // Rule 5: Prefer matching label.
+ // If Label(Source(DA)) = Label(DA) and Label(Source(DB)) <> Label(DB),
+ // then prefer DA. Similarly, if Label(Source(DA)) <> Label(DA) and
+ // Label(Source(DB)) = Label(DB), then prefer DB.
+ if attrSourceDA.Label == attrDA.Label &&
+ attrSourceDB.Label != attrDB.Label {
+ return preferDA
+ }
+ if attrSourceDA.Label != attrDA.Label &&
+ attrSourceDB.Label == attrDB.Label {
+ return preferDB
+ }
+
+ // Rule 6: Prefer higher precedence.
+ // If Precedence(DA) > Precedence(DB), then prefer DA. Similarly, if
+ // Precedence(DA) < Precedence(DB), then prefer DB.
+ if attrDA.Precedence > attrDB.Precedence {
+ return preferDA
+ }
+ if attrDA.Precedence < attrDB.Precedence {
+ return preferDB
+ }
+
+ // Rule 7: Prefer native transport.
+ // If DA is reached via an encapsulating transition mechanism (e.g.,
+ // IPv6 in IPv4) and DB is not, then prefer DB. Similarly, if DB is
+ // reached via encapsulation and DA is not, then prefer DA.
+
+ // TODO(bradfitz): implement? low priority for now.
+
+ // Rule 8: Prefer smaller scope.
+ // If Scope(DA) < Scope(DB), then prefer DA. Similarly, if Scope(DA) >
+ // Scope(DB), then prefer DB.
+ if attrDA.Scope < attrDB.Scope {
+ return preferDA
+ }
+ if attrDA.Scope > attrDB.Scope {
+ return preferDB
+ }
+
+ // Rule 9: Use the longest matching prefix.
+ // When DA and DB belong to the same address family (both are IPv6 or
+ // both are IPv4 [but see below]): If CommonPrefixLen(Source(DA), DA) >
+ // CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if
+ // CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB),
+ // then prefer DB.
+ //
+ // However, applying this rule to IPv4 addresses causes
+ // problems (see issues 13283 and 18518), so limit to IPv6.
+ if DA.To4() == nil && DB.To4() == nil {
+ commonA := commonPrefixLen(SourceDA, DA)
+ commonB := commonPrefixLen(SourceDB, DB)
+
+ if commonA > commonB {
+ return preferDA
+ }
+ if commonA < commonB {
+ return preferDB
+ }
+ }
+
+ // Rule 10: Otherwise, leave the order unchanged.
+ // If DA preceded DB in the original list, prefer DA.
+ // Otherwise, prefer DB.
+ return false // "equal"
+}
+
+type policyTableEntry struct {
+ Prefix netip.Prefix
+ Precedence uint8
+ Label uint8
+}
+
+type policyTable []policyTableEntry
+
+// RFC 6724 section 2.1.
+// Items are sorted from most specific (largest Prefix.Bits) to least specific.
+var rfc6724policyTable = policyTable{
+ {
+ // "::1/128"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}), 128),
+ Precedence: 50,
+ Label: 0,
+ },
+ {
+ // "::ffff:0:0/96"
+ // IPv4-compatible, etc.
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}), 96),
+ Precedence: 35,
+ Label: 4,
+ },
+ {
+ // "::/96"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 96),
+ Precedence: 1,
+ Label: 3,
+ },
+ {
+ // "2001::/32"
+ // Teredo
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x01}), 32),
+ Precedence: 5,
+ Label: 5,
+ },
+ {
+ // "2002::/16"
+ // 6to4
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x20, 0x02}), 16),
+ Precedence: 30,
+ Label: 2,
+ },
+ {
+ // "3ffe::/16"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0x3f, 0xfe}), 16),
+ Precedence: 1,
+ Label: 12,
+ },
+ {
+ // "fec0::/10"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfe, 0xc0}), 10),
+ Precedence: 1,
+ Label: 11,
+ },
+ {
+ // "fc00::/7"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{0xfc}), 7),
+ Precedence: 3,
+ Label: 13,
+ },
+ {
+ // "::/0"
+ Prefix: netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 0),
+ Precedence: 40,
+ Label: 1,
+ },
+}
+
+// Classify returns the policyTableEntry of the entry with the longest
+// matching prefix that contains ip.
+// The table t must be sorted from largest mask size to smallest.
+func (t policyTable) Classify(ip netip.Addr) policyTableEntry {
+ // Prefix.Contains() will not match an IPv6 prefix for an IPv4 address.
+ if ip.Is4() {
+ ip = netip.AddrFrom16(ip.As16())
+ }
+ for _, ent := range t {
+ if ent.Prefix.Contains(ip) {
+ return ent
+ }
+ }
+ return policyTableEntry{}
+}
+
+// RFC 6724 section 3.1.
+type scope uint8
+
+const (
+ scopeInterfaceLocal scope = 0x1
+ scopeLinkLocal scope = 0x2
+ scopeAdminLocal scope = 0x4
+ scopeSiteLocal scope = 0x5
+ scopeOrgLocal scope = 0x8
+ scopeGlobal scope = 0xe
+)
+
+func classifyScope(ip netip.Addr) scope {
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() {
+ return scopeLinkLocal
+ }
+ ipv6 := ip.Is6() && !ip.Is4In6()
+ ipv6AsBytes := ip.As16()
+ if ipv6 && ip.IsMulticast() {
+ return scope(ipv6AsBytes[1] & 0xf)
+ }
+ // Site-local addresses are defined in RFC 3513 section 2.5.6
+ // (and deprecated in RFC 3879).
+ if ipv6 && ipv6AsBytes[0] == 0xfe && ipv6AsBytes[1]&0xc0 == 0xc0 {
+ return scopeSiteLocal
+ }
+ return scopeGlobal
+}
+
+// commonPrefixLen reports the length of the longest prefix (looking
+// at the most significant, or leftmost, bits) that the
+// two addresses have in common, up to the length of a's prefix (i.e.,
+// the portion of the address not including the interface ID).
+//
+// If a or b is an IPv4 address as an IPv6 address, the IPv4 addresses
+// are compared (with max common prefix length of 32).
+// If a and b are different IP versions, 0 is returned.
+//
+// See https://tools.ietf.org/html/rfc6724#section-2.2
+func commonPrefixLen(a netip.Addr, b IP) (cpl int) {
+ if b4 := b.To4(); b4 != nil {
+ b = b4
+ }
+ aAsSlice := a.AsSlice()
+ if len(aAsSlice) != len(b) {
+ return 0
+ }
+ // If IPv6, only up to the prefix (first 64 bits)
+ if len(aAsSlice) > 8 {
+ aAsSlice = aAsSlice[:8]
+ b = b[:8]
+ }
+ for len(aAsSlice) > 0 {
+ if aAsSlice[0] == b[0] {
+ cpl += 8
+ aAsSlice = aAsSlice[1:]
+ b = b[1:]
+ continue
+ }
+ bits := 8
+ ab, bb := aAsSlice[0], b[0]
+ for {
+ ab >>= 1
+ bb >>= 1
+ bits--
+ if ab == bb {
+ cpl += bits
+ return
+ }
+ }
+ }
+ return
+}
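
The table above is what Classify consults: entries run from most specific to least specific prefix, IPv4 addresses are first mapped into ::ffff:0:0/96, and the first containing prefix wins. A minimal standalone sketch of that longest-prefix classification (not part of the diff; the cut-down table and names here are illustrative only):

package main

import (
    "fmt"
    "net/netip"
)

// A cut-down policy table, ordered from most to least specific prefix.
var policy = []struct {
    Prefix     netip.Prefix
    Precedence uint8
    Label      uint8
}{
    {netip.MustParsePrefix("::1/128"), 50, 0},
    {netip.MustParsePrefix("::ffff:0:0/96"), 35, 4},
    {netip.MustParsePrefix("2002::/16"), 30, 2},
    {netip.MustParsePrefix("::/0"), 40, 1},
}

// classify returns the attributes of the first (longest) matching prefix.
func classify(ip netip.Addr) (precedence, label uint8) {
    if ip.Is4() {
        ip = netip.AddrFrom16(ip.As16()) // the 4-byte form would not match the IPv6 prefixes
    }
    for _, e := range policy {
        if e.Prefix.Contains(ip) {
            return e.Precedence, e.Label
        }
    }
    return 0, 0
}

func main() {
    for _, s := range []string{"127.0.0.1", "2002::ab12", "2001:db8::1"} {
        p, l := classify(netip.MustParseAddr(s))
        fmt.Printf("%-12s precedence=%d label=%d\n", s, p, l)
    }
}

With the full table from addrselect.go, the same lookups produce the values checked in TestRFC6724PolicyTableClassify below.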
diff --git a/src/net/addrselect_test.go b/src/net/addrselect_test.go
new file mode 100644
index 0000000..7e8134d
--- /dev/null
+++ b/src/net/addrselect_test.go
@@ -0,0 +1,312 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package net
+
+import (
+ "net/netip"
+ "reflect"
+ "testing"
+)
+
+func TestSortByRFC6724(t *testing.T) {
+ tests := []struct {
+ in []IPAddr
+ srcs []netip.Addr
+ want []IPAddr
+ reverse bool // also test it starting backwards
+ }{
+ // Examples from RFC 6724 section 10.2:
+
+ // Prefer matching scope.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("198.51.100.121")},
+ },
+ srcs: []netip.Addr{
+ netip.MustParseAddr("2001:db8:1::2"),
+ netip.MustParseAddr("169.254.13.78"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("198.51.100.121")},
+ },
+ reverse: true,
+ },
+
+ // Prefer matching scope.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("198.51.100.121")},
+ },
+ srcs: []netip.Addr{
+ netip.MustParseAddr("fe80::1"),
+ netip.MustParseAddr("198.51.100.117"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("198.51.100.121")},
+ {IP: ParseIP("2001:db8:1::1")},
+ },
+ reverse: true,
+ },
+
+ // Prefer higher precedence.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("10.1.2.3")},
+ },
+ srcs: []netip.Addr{
+ netip.MustParseAddr("2001:db8:1::2"),
+ netip.MustParseAddr("10.1.2.4"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("10.1.2.3")},
+ },
+ reverse: true,
+ },
+
+ // Prefer smaller scope.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("2001:db8:1::1")},
+ {IP: ParseIP("fe80::1")},
+ },
+ srcs: []netip.Addr{
+ netip.MustParseAddr("2001:db8:1::2"),
+ netip.MustParseAddr("fe80::2"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("fe80::1")},
+ {IP: ParseIP("2001:db8:1::1")},
+ },
+ reverse: true,
+ },
+
+ // Issue 13283. Having a 10/8 source address does not
+ // mean we should prefer 23/8 destination addresses.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("54.83.193.112")},
+ {IP: ParseIP("184.72.238.214")},
+ {IP: ParseIP("23.23.172.185")},
+ {IP: ParseIP("75.101.148.21")},
+ {IP: ParseIP("23.23.134.56")},
+ {IP: ParseIP("23.21.50.150")},
+ },
+ srcs: []netip.Addr{
+ netip.MustParseAddr("10.2.3.4"),
+ netip.MustParseAddr("10.2.3.4"),
+ netip.MustParseAddr("10.2.3.4"),
+ netip.MustParseAddr("10.2.3.4"),
+ netip.MustParseAddr("10.2.3.4"),
+ netip.MustParseAddr("10.2.3.4"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("54.83.193.112")},
+ {IP: ParseIP("184.72.238.214")},
+ {IP: ParseIP("23.23.172.185")},
+ {IP: ParseIP("75.101.148.21")},
+ {IP: ParseIP("23.23.134.56")},
+ {IP: ParseIP("23.21.50.150")},
+ },
+ reverse: false,
+ },
+ }
+ for i, tt := range tests {
+ inCopy := make([]IPAddr, len(tt.in))
+ copy(inCopy, tt.in)
+ srcCopy := make([]netip.Addr, len(tt.in))
+ copy(srcCopy, tt.srcs)
+ sortByRFC6724withSrcs(inCopy, srcCopy)
+ if !reflect.DeepEqual(inCopy, tt.want) {
+ t.Errorf("test %d:\nin = %s\ngot: %s\nwant: %s\n", i, tt.in, inCopy, tt.want)
+ }
+ if tt.reverse {
+ copy(inCopy, tt.in)
+ copy(srcCopy, tt.srcs)
+ for j := 0; j < len(inCopy)/2; j++ {
+ k := len(inCopy) - j - 1
+ inCopy[j], inCopy[k] = inCopy[k], inCopy[j]
+ srcCopy[j], srcCopy[k] = srcCopy[k], srcCopy[j]
+ }
+ sortByRFC6724withSrcs(inCopy, srcCopy)
+ if !reflect.DeepEqual(inCopy, tt.want) {
+ t.Errorf("test %d, starting backwards:\nin = %s\ngot: %s\nwant: %s\n", i, tt.in, inCopy, tt.want)
+ }
+ }
+
+ }
+
+}
+
+func TestRFC6724PolicyTableOrder(t *testing.T) {
+ for i := 0; i < len(rfc6724policyTable)-1; i++ {
+ if !(rfc6724policyTable[i].Prefix.Bits() >= rfc6724policyTable[i+1].Prefix.Bits()) {
+ t.Errorf("rfc6724policyTable item number %d sorted in wrong order = %d bits, next item = %d bits;", i, rfc6724policyTable[i].Prefix.Bits(), rfc6724policyTable[i+1].Prefix.Bits())
+ }
+ }
+}
+
+func TestRFC6724PolicyTableContent(t *testing.T) {
+ expectedRfc6724policyTable := policyTable{
+ {
+ Prefix: netip.MustParsePrefix("::1/128"),
+ Precedence: 50,
+ Label: 0,
+ },
+ {
+ Prefix: netip.MustParsePrefix("::ffff:0:0/96"),
+ Precedence: 35,
+ Label: 4,
+ },
+ {
+ Prefix: netip.MustParsePrefix("::/96"),
+ Precedence: 1,
+ Label: 3,
+ },
+ {
+ Prefix: netip.MustParsePrefix("2001::/32"),
+ Precedence: 5,
+ Label: 5,
+ },
+ {
+ Prefix: netip.MustParsePrefix("2002::/16"),
+ Precedence: 30,
+ Label: 2,
+ },
+ {
+ Prefix: netip.MustParsePrefix("3ffe::/16"),
+ Precedence: 1,
+ Label: 12,
+ },
+ {
+ Prefix: netip.MustParsePrefix("fec0::/10"),
+ Precedence: 1,
+ Label: 11,
+ },
+ {
+ Prefix: netip.MustParsePrefix("fc00::/7"),
+ Precedence: 3,
+ Label: 13,
+ },
+ {
+ Prefix: netip.MustParsePrefix("::/0"),
+ Precedence: 40,
+ Label: 1,
+ },
+ }
+ if !reflect.DeepEqual(rfc6724policyTable, expectedRfc6724policyTable) {
+ t.Errorf("rfc6724policyTable has wrong contend = %v; want %v", rfc6724policyTable, expectedRfc6724policyTable)
+ }
+}
+
+func TestRFC6724PolicyTableClassify(t *testing.T) {
+ tests := []struct {
+ ip netip.Addr
+ want policyTableEntry
+ }{
+ {
+ ip: netip.MustParseAddr("127.0.0.1"),
+ want: policyTableEntry{
+ Prefix: netip.MustParsePrefix("::ffff:0:0/96"),
+ Precedence: 35,
+ Label: 4,
+ },
+ },
+ {
+ ip: netip.MustParseAddr("2601:645:8002:a500:986f:1db8:c836:bd65"),
+ want: policyTableEntry{
+ Prefix: netip.MustParsePrefix("::/0"),
+ Precedence: 40,
+ Label: 1,
+ },
+ },
+ {
+ ip: netip.MustParseAddr("::1"),
+ want: policyTableEntry{
+ Prefix: netip.MustParsePrefix("::1/128"),
+ Precedence: 50,
+ Label: 0,
+ },
+ },
+ {
+ ip: netip.MustParseAddr("2002::ab12"),
+ want: policyTableEntry{
+ Prefix: netip.MustParsePrefix("2002::/16"),
+ Precedence: 30,
+ Label: 2,
+ },
+ },
+ }
+ for i, tt := range tests {
+ got := rfc6724policyTable.Classify(tt.ip)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("%d. Classify(%s) = %v; want %v", i, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestRFC6724ClassifyScope(t *testing.T) {
+ tests := []struct {
+ ip netip.Addr
+ want scope
+ }{
+ {netip.MustParseAddr("127.0.0.1"), scopeLinkLocal}, // rfc6724#section-3.2
+ {netip.MustParseAddr("::1"), scopeLinkLocal}, // rfc4007#section-4
+ {netip.MustParseAddr("169.254.1.2"), scopeLinkLocal}, // rfc6724#section-3.2
+ {netip.MustParseAddr("fec0::1"), scopeSiteLocal},
+ {netip.MustParseAddr("8.8.8.8"), scopeGlobal},
+
+ {netip.MustParseAddr("ff02::"), scopeLinkLocal}, // IPv6 multicast
+ {netip.MustParseAddr("ff05::"), scopeSiteLocal}, // IPv6 multicast
+ {netip.MustParseAddr("ff04::"), scopeAdminLocal}, // IPv6 multicast
+ {netip.MustParseAddr("ff0e::"), scopeGlobal}, // IPv6 multicast
+
+ {netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xe0, 0, 0, 0}), scopeGlobal}, // IPv4 link-local multicast as 16 bytes
+ {netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xe0, 2, 2, 2}), scopeGlobal}, // IPv4 global multicast as 16 bytes
+ {netip.AddrFrom4([4]byte{0xe0, 0, 0, 0}), scopeGlobal}, // IPv4 link-local multicast as 4 bytes
+ {netip.AddrFrom4([4]byte{0xe0, 2, 2, 2}), scopeGlobal}, // IPv4 global multicast as 4 bytes
+ }
+ for i, tt := range tests {
+ got := classifyScope(tt.ip)
+ if got != tt.want {
+ t.Errorf("%d. classifyScope(%s) = %x; want %x", i, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestRFC6724CommonPrefixLength(t *testing.T) {
+ tests := []struct {
+ a netip.Addr
+ b IP
+ want int
+ }{
+ {netip.MustParseAddr("fe80::1"), ParseIP("fe80::2"), 64},
+ {netip.MustParseAddr("fe81::1"), ParseIP("fe80::2"), 15},
+ {netip.MustParseAddr("127.0.0.1"), ParseIP("fe80::1"), 0}, // diff size
+ {netip.AddrFrom4([4]byte{1, 2, 3, 4}), IP{1, 2, 3, 4}, 32},
+ {netip.AddrFrom4([4]byte{1, 2, 255, 255}), IP{1, 2, 0, 0}, 16},
+ {netip.AddrFrom4([4]byte{1, 2, 127, 255}), IP{1, 2, 0, 0}, 17},
+ {netip.AddrFrom4([4]byte{1, 2, 63, 255}), IP{1, 2, 0, 0}, 18},
+ {netip.AddrFrom4([4]byte{1, 2, 31, 255}), IP{1, 2, 0, 0}, 19},
+ {netip.AddrFrom4([4]byte{1, 2, 15, 255}), IP{1, 2, 0, 0}, 20},
+ {netip.AddrFrom4([4]byte{1, 2, 7, 255}), IP{1, 2, 0, 0}, 21},
+ {netip.AddrFrom4([4]byte{1, 2, 3, 255}), IP{1, 2, 0, 0}, 22},
+ {netip.AddrFrom4([4]byte{1, 2, 1, 255}), IP{1, 2, 0, 0}, 23},
+ {netip.AddrFrom4([4]byte{1, 2, 0, 255}), IP{1, 2, 0, 0}, 24},
+ }
+ for i, tt := range tests {
+ got := commonPrefixLen(tt.a, tt.b)
+ if got != tt.want {
+ t.Errorf("%d. commonPrefixLen(%s, %s) = %d; want %d", i, tt.a, tt.b, got, tt.want)
+ }
+ }
+
+}
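
The commonPrefixLen cases in the last test can be reproduced with a few lines of bit counting. A small sketch (not part of the diff; commonBits is a made-up helper) that counts shared leading bits without the /64 cap the production code applies to IPv6:

package main

import (
    "fmt"
    "net/netip"
)

// commonBits counts the leading bits two same-length addresses share.
func commonBits(a, b netip.Addr) int {
    ab, bb := a.AsSlice(), b.AsSlice()
    if len(ab) != len(bb) {
        return 0 // different IP versions share no prefix, as in the tests
    }
    n := 0
    for i := range ab {
        x := ab[i] ^ bb[i]
        if x == 0 {
            n += 8
            continue
        }
        for x&0x80 == 0 { // count leading zero bits of the first differing byte
            n++
            x <<= 1
        }
        break
    }
    return n
}

func main() {
    fmt.Println(commonBits(netip.MustParseAddr("fe81::1"), netip.MustParseAddr("fe80::2")))     // 15
    fmt.Println(commonBits(netip.MustParseAddr("1.2.127.255"), netip.MustParseAddr("1.2.0.0"))) // 17
}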
diff --git a/src/net/cgo_aix.go b/src/net/cgo_aix.go
new file mode 100644
index 0000000..f347814
--- /dev/null
+++ b/src/net/cgo_aix.go
@@ -0,0 +1,24 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+
+#include <netdb.h>
+*/
+import "C"
+
+import "unsafe"
+
+const cgoAddrInfoFlags = C.AI_CANONNAME
+
+func cgoNameinfoPTR(b []byte, sa *C.struct_sockaddr, salen C.socklen_t) (int, error) {
+ gerrno, err := C.getnameinfo(sa, C.size_t(salen), (*C.char)(unsafe.Pointer(&b[0])), C.size_t(len(b)), nil, 0, C.NI_NAMEREQD)
+ return int(gerrno), err
+}
diff --git a/src/net/cgo_android.go b/src/net/cgo_android.go
new file mode 100644
index 0000000..5ab8b5f
--- /dev/null
+++ b/src/net/cgo_android.go
@@ -0,0 +1,12 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo
+
+package net
+
+//#include <netdb.h>
+import "C"
+
+const cgoAddrInfoFlags = C.AI_CANONNAME
diff --git a/src/net/cgo_bsd.go b/src/net/cgo_bsd.go
new file mode 100644
index 0000000..082e91f
--- /dev/null
+++ b/src/net/cgo_bsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && (dragonfly || freebsd)
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+const cgoAddrInfoFlags = (C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL) & C.AI_MASK
diff --git a/src/net/cgo_darwin.go b/src/net/cgo_darwin.go
new file mode 100644
index 0000000..129dd93
--- /dev/null
+++ b/src/net/cgo_darwin.go
@@ -0,0 +1,9 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "internal/syscall/unix"
+
+const cgoAddrInfoFlags = (unix.AI_CANONNAME | unix.AI_V4MAPPED | unix.AI_ALL) & unix.AI_MASK
diff --git a/src/net/cgo_linux.go b/src/net/cgo_linux.go
new file mode 100644
index 0000000..de6e87f
--- /dev/null
+++ b/src/net/cgo_linux.go
@@ -0,0 +1,20 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !android && cgo && !netgo
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+// NOTE(rsc): In theory there are approximately balanced
+// arguments for and against including AI_ADDRCONFIG
+// in the flags (it includes IPv4 results only on IPv4 systems,
+// and similarly for IPv6), but in practice setting it causes
+// getaddrinfo to return the wrong canonical name on Linux.
+// So definitely leave it out.
+const cgoAddrInfoFlags = C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL
diff --git a/src/net/cgo_netbsd.go b/src/net/cgo_netbsd.go
new file mode 100644
index 0000000..03392e8
--- /dev/null
+++ b/src/net/cgo_netbsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+const cgoAddrInfoFlags = C.AI_CANONNAME
diff --git a/src/net/cgo_openbsd.go b/src/net/cgo_openbsd.go
new file mode 100644
index 0000000..03392e8
--- /dev/null
+++ b/src/net/cgo_openbsd.go
@@ -0,0 +1,14 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo
+
+package net
+
+/*
+#include <netdb.h>
+*/
+import "C"
+
+const cgoAddrInfoFlags = C.AI_CANONNAME
diff --git a/src/net/cgo_resnew.go b/src/net/cgo_resnew.go
new file mode 100644
index 0000000..3f21c5c
--- /dev/null
+++ b/src/net/cgo_resnew.go
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && ((linux && !android) || netbsd || solaris)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+
+#include <netdb.h>
+*/
+import "C"
+
+import "unsafe"
+
+func cgoNameinfoPTR(b []byte, sa *C.struct_sockaddr, salen C.socklen_t) (int, error) {
+ gerrno, err := C.getnameinfo(sa, salen, (*C.char)(unsafe.Pointer(&b[0])), C.socklen_t(len(b)), nil, 0, C.NI_NAMEREQD)
+ return int(gerrno), err
+}
diff --git a/src/net/cgo_resold.go b/src/net/cgo_resold.go
new file mode 100644
index 0000000..37c7552
--- /dev/null
+++ b/src/net/cgo_resold.go
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && (android || freebsd || dragonfly || openbsd)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+
+#include <netdb.h>
+*/
+import "C"
+
+import "unsafe"
+
+func cgoNameinfoPTR(b []byte, sa *C.struct_sockaddr, salen C.socklen_t) (int, error) {
+ gerrno, err := C.getnameinfo(sa, salen, (*C.char)(unsafe.Pointer(&b[0])), C.size_t(len(b)), nil, 0, C.NI_NAMEREQD)
+ return int(gerrno), err
+}
diff --git a/src/net/cgo_socknew.go b/src/net/cgo_socknew.go
new file mode 100644
index 0000000..fbb9e10
--- /dev/null
+++ b/src/net/cgo_socknew.go
@@ -0,0 +1,32 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && (android || linux || solaris)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+
+#include <netinet/in.h>
+*/
+import "C"
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func cgoSockaddrInet4(ip IP) *C.struct_sockaddr {
+ sa := syscall.RawSockaddrInet4{Family: syscall.AF_INET}
+ copy(sa.Addr[:], ip)
+ return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
+}
+
+func cgoSockaddrInet6(ip IP, zone int) *C.struct_sockaddr {
+ sa := syscall.RawSockaddrInet6{Family: syscall.AF_INET6, Scope_id: uint32(zone)}
+ copy(sa.Addr[:], ip)
+ return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
+}
diff --git a/src/net/cgo_sockold.go b/src/net/cgo_sockold.go
new file mode 100644
index 0000000..d0a99e0
--- /dev/null
+++ b/src/net/cgo_sockold.go
@@ -0,0 +1,32 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && (aix || dragonfly || freebsd || netbsd || openbsd)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+
+#include <netinet/in.h>
+*/
+import "C"
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func cgoSockaddrInet4(ip IP) *C.struct_sockaddr {
+ sa := syscall.RawSockaddrInet4{Len: syscall.SizeofSockaddrInet4, Family: syscall.AF_INET}
+ copy(sa.Addr[:], ip)
+ return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
+}
+
+func cgoSockaddrInet6(ip IP, zone int) *C.struct_sockaddr {
+ sa := syscall.RawSockaddrInet6{Len: syscall.SizeofSockaddrInet6, Family: syscall.AF_INET6, Scope_id: uint32(zone)}
+ copy(sa.Addr[:], ip)
+ return (*C.struct_sockaddr)(unsafe.Pointer(&sa))
+}
diff --git a/src/net/cgo_solaris.go b/src/net/cgo_solaris.go
new file mode 100644
index 0000000..cde9c95
--- /dev/null
+++ b/src/net/cgo_solaris.go
@@ -0,0 +1,15 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo
+
+package net
+
+/*
+#cgo LDFLAGS: -lsocket -lnsl -lsendfile
+#include <netdb.h>
+*/
+import "C"
+
+const cgoAddrInfoFlags = C.AI_CANONNAME | C.AI_V4MAPPED | C.AI_ALL
diff --git a/src/net/cgo_stub.go b/src/net/cgo_stub.go
new file mode 100644
index 0000000..b26b11a
--- /dev/null
+++ b/src/net/cgo_stub.go
@@ -0,0 +1,40 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file holds stub versions of the cgo functions called on Unix systems.
+// We build this file:
+// - if using the netgo build tag on a Unix system
+// - on a Unix system without the cgo resolver functions
+// (Darwin always provides the cgo functions, in cgo_unix_syscall.go)
+// - on wasip1, where cgo is never available
+
+//go:build (netgo && unix) || (unix && !cgo && !darwin) || wasip1
+
+package net
+
+import "context"
+
+// cgoAvailable set to false to indicate that the cgo resolver
+// is not available on this system.
+const cgoAvailable = false
+
+func cgoLookupHost(ctx context.Context, name string) (addrs []string, err error) {
+ panic("cgo stub: cgo not available")
+}
+
+func cgoLookupPort(ctx context.Context, network, service string) (port int, err error) {
+ panic("cgo stub: cgo not available")
+}
+
+func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error) {
+ panic("cgo stub: cgo not available")
+}
+
+func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
+ panic("cgo stub: cgo not available")
+}
+
+func cgoLookupPTR(ctx context.Context, addr string) (ptrs []string, err error) {
+ panic("cgo stub: cgo not available")
+}
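
cgo_stub.go and cgo_unix.go define the same names under mutually exclusive build constraints, so any given build compiles exactly one of them; cgoAvailable tells the rest of the package which one it got, and the panicking stubs are only a safety net for paths that should be unreachable. A rough sketch of that compile-time dispatch idea, collapsed into one runnable file with invented names:

package main

import "fmt"

// In the real package this constant comes from whichever file the build
// constraints selected (cgo_unix.go or cgo_stub.go). It is hard-coded here
// only so the sketch compiles on its own.
const cgoResolverAvailable = false

func lookupWithCgo(name string) ([]string, error) {
    panic("stub: cgo resolver not compiled in") // mirrors the stubs above
}

func lookupPureGo(name string) ([]string, error) {
    return []string{"192.0.2.1"}, nil // placeholder answer
}

// lookup branches on a constant, so the compiler can drop the dead arm
// (and its panic) entirely, much like code that consults cgoAvailable.
func lookup(name string) ([]string, error) {
    if cgoResolverAvailable {
        return lookupWithCgo(name)
    }
    return lookupPureGo(name)
}

func main() {
    addrs, _ := lookup("example.invalid")
    fmt.Println(addrs)
}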
diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go
new file mode 100644
index 0000000..f10f3ea
--- /dev/null
+++ b/src/net/cgo_unix.go
@@ -0,0 +1,370 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file is called cgo_unix.go, but to allow syscalls-to-libc-based
+// implementations to share the code, it does not use cgo directly.
+// Instead of C.foo it uses _C_foo, which is defined in either
+// cgo_unix_cgo.go or cgo_unix_syscall.go
+
+//go:build !netgo && ((cgo && unix) || darwin)
+
+package net
+
+import (
+ "context"
+ "errors"
+ "net/netip"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// cgoAvailable set to true to indicate that the cgo resolver
+// is available on this system.
+const cgoAvailable = true
+
+// An addrinfoErrno represents a getaddrinfo, getnameinfo-specific
+// error number. It's a signed number and a zero value is a non-error
+// by convention.
+type addrinfoErrno int
+
+func (eai addrinfoErrno) Error() string { return _C_gai_strerror(_C_int(eai)) }
+func (eai addrinfoErrno) Temporary() bool { return eai == _C_EAI_AGAIN }
+func (eai addrinfoErrno) Timeout() bool { return false }
+
+// isAddrinfoErrno is just for testing purposes.
+func (eai addrinfoErrno) isAddrinfoErrno() {}
+
+// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided
+// context is cancellable. It is intended for use with calls that don't support context
+// cancellation (cgo, syscalls). blocking func may still be running after this function finishes.
+func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) {
+ if ctx.Done() == nil {
+ return blocking()
+ }
+
+ type result struct {
+ res T
+ err error
+ }
+
+ res := make(chan result, 1)
+ go func() {
+ var r result
+ r.res, r.err = blocking()
+ res <- r
+ }()
+
+ select {
+ case r := <-res:
+ return r.res, r.err
+ case <-ctx.Done():
+ var zero T
+ return zero, mapErr(ctx.Err())
+ }
+}
+
+func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error) {
+ addrs, err := cgoLookupIP(ctx, "ip", name)
+ if err != nil {
+ return nil, err
+ }
+ for _, addr := range addrs {
+ hosts = append(hosts, addr.String())
+ }
+ return hosts, nil
+}
+
+func cgoLookupPort(ctx context.Context, network, service string) (port int, err error) {
+ var hints _C_struct_addrinfo
+ switch network {
+ case "": // no hints
+ case "tcp", "tcp4", "tcp6":
+ *_C_ai_socktype(&hints) = _C_SOCK_STREAM
+ *_C_ai_protocol(&hints) = _C_IPPROTO_TCP
+ case "udp", "udp4", "udp6":
+ *_C_ai_socktype(&hints) = _C_SOCK_DGRAM
+ *_C_ai_protocol(&hints) = _C_IPPROTO_UDP
+ default:
+ return 0, &DNSError{Err: "unknown network", Name: network + "/" + service}
+ }
+ switch ipVersion(network) {
+ case '4':
+ *_C_ai_family(&hints) = _C_AF_INET
+ case '6':
+ *_C_ai_family(&hints) = _C_AF_INET6
+ }
+
+ return doBlockingWithCtx(ctx, func() (int, error) {
+ return cgoLookupServicePort(&hints, network, service)
+ })
+}
+
+func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (port int, err error) {
+ cservice, err := syscall.ByteSliceFromString(service)
+ if err != nil {
+ return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}
+ }
+ // Lowercase the C service name.
+ for i, b := range cservice[:len(service)] {
+ cservice[i] = lowerASCII(b)
+ }
+ var res *_C_struct_addrinfo
+ gerrno, err := _C_getaddrinfo(nil, (*_C_char)(unsafe.Pointer(&cservice[0])), hints, &res)
+ if gerrno != 0 {
+ isTemporary := false
+ switch gerrno {
+ case _C_EAI_SYSTEM:
+ if err == nil { // see golang.org/issue/6232
+ err = syscall.EMFILE
+ }
+ default:
+ err = addrinfoErrno(gerrno)
+ isTemporary = addrinfoErrno(gerrno).Temporary()
+ }
+ return 0, &DNSError{Err: err.Error(), Name: network + "/" + service, IsTemporary: isTemporary}
+ }
+ defer _C_freeaddrinfo(res)
+
+ for r := res; r != nil; r = *_C_ai_next(r) {
+ switch *_C_ai_family(r) {
+ case _C_AF_INET:
+ sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(*_C_ai_addr(r)))
+ p := (*[2]byte)(unsafe.Pointer(&sa.Port))
+ return int(p[0])<<8 | int(p[1]), nil
+ case _C_AF_INET6:
+ sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(*_C_ai_addr(r)))
+ p := (*[2]byte)(unsafe.Pointer(&sa.Port))
+ return int(p[0])<<8 | int(p[1]), nil
+ }
+ }
+ return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}
+}
+
+func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
+ acquireThread()
+ defer releaseThread()
+
+ var hints _C_struct_addrinfo
+ *_C_ai_flags(&hints) = cgoAddrInfoFlags
+ *_C_ai_socktype(&hints) = _C_SOCK_STREAM
+ *_C_ai_family(&hints) = _C_AF_UNSPEC
+ switch ipVersion(network) {
+ case '4':
+ *_C_ai_family(&hints) = _C_AF_INET
+ case '6':
+ *_C_ai_family(&hints) = _C_AF_INET6
+ }
+
+ h, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return nil, &DNSError{Err: err.Error(), Name: name}
+ }
+ var res *_C_struct_addrinfo
+ gerrno, err := _C_getaddrinfo((*_C_char)(unsafe.Pointer(h)), nil, &hints, &res)
+ if gerrno != 0 {
+ isErrorNoSuchHost := false
+ isTemporary := false
+ switch gerrno {
+ case _C_EAI_SYSTEM:
+ if err == nil {
+ // err should not be nil, but sometimes getaddrinfo returns
+ // gerrno == _C_EAI_SYSTEM with err == nil on Linux.
+ // The report claims that it happens when we have too many
+ // open files, so use syscall.EMFILE (too many open files in system).
+ // Most system calls would return ENFILE (too many open files),
+ // so at the least EMFILE should be easy to recognize if this
+ // comes up again. golang.org/issue/6232.
+ err = syscall.EMFILE
+ }
+ case _C_EAI_NONAME, _C_EAI_NODATA:
+ err = errNoSuchHost
+ isErrorNoSuchHost = true
+ default:
+ err = addrinfoErrno(gerrno)
+ isTemporary = addrinfoErrno(gerrno).Temporary()
+ }
+
+ return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: isErrorNoSuchHost, IsTemporary: isTemporary}
+ }
+ defer _C_freeaddrinfo(res)
+
+ for r := res; r != nil; r = *_C_ai_next(r) {
+ // We only asked for SOCK_STREAM, but check anyhow.
+ if *_C_ai_socktype(r) != _C_SOCK_STREAM {
+ continue
+ }
+ switch *_C_ai_family(r) {
+ case _C_AF_INET:
+ sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(*_C_ai_addr(r)))
+ addr := IPAddr{IP: copyIP(sa.Addr[:])}
+ addrs = append(addrs, addr)
+ case _C_AF_INET6:
+ sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(*_C_ai_addr(r)))
+ addr := IPAddr{IP: copyIP(sa.Addr[:]), Zone: zoneCache.name(int(sa.Scope_id))}
+ addrs = append(addrs, addr)
+ }
+ }
+ return addrs, nil
+}
+
+func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error) {
+ return doBlockingWithCtx(ctx, func() ([]IPAddr, error) {
+ return cgoLookupHostIP(network, name)
+ })
+}
+
+// These are roughly enough for the following:
+//
+//   Source           Encoding                        Maximum length of single name entry
+//   Unicast DNS      ASCII or                        <=253 + a NUL terminator
+//                    Unicode in RFC 5892             252 * total number of labels + delimiters + a NUL terminator
+//   Multicast DNS    UTF-8 in RFC 5198 or            <=253 + a NUL terminator
+//                    the same as unicast DNS ASCII   <=253 + a NUL terminator
+//   Local database   various                         depends on implementation
+const (
+ nameinfoLen = 64
+ maxNameinfoLen = 4096
+)
+
+func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error) {
+ ip, err := netip.ParseAddr(addr)
+ if err != nil {
+ return nil, &DNSError{Err: "invalid address", Name: addr}
+ }
+ sa, salen := cgoSockaddr(IP(ip.AsSlice()), ip.Zone())
+ if sa == nil {
+ return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}
+ }
+
+ return doBlockingWithCtx(ctx, func() ([]string, error) {
+ return cgoLookupAddrPTR(addr, sa, salen)
+ })
+}
+
+func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) {
+ acquireThread()
+ defer releaseThread()
+
+ var gerrno int
+ var b []byte
+ for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 {
+ b = make([]byte, l)
+ gerrno, err = cgoNameinfoPTR(b, sa, salen)
+ if gerrno == 0 || gerrno != _C_EAI_OVERFLOW {
+ break
+ }
+ }
+ if gerrno != 0 {
+ isErrorNoSuchHost := false
+ isTemporary := false
+ switch gerrno {
+ case _C_EAI_SYSTEM:
+ if err == nil { // see golang.org/issue/6232
+ err = syscall.EMFILE
+ }
+ case _C_EAI_NONAME:
+ err = errNoSuchHost
+ isErrorNoSuchHost = true
+ default:
+ err = addrinfoErrno(gerrno)
+ isTemporary = addrinfoErrno(gerrno).Temporary()
+ }
+ return nil, &DNSError{Err: err.Error(), Name: addr, IsTemporary: isTemporary, IsNotFound: isErrorNoSuchHost}
+ }
+ for i := 0; i < len(b); i++ {
+ if b[i] == 0 {
+ b = b[:i]
+ break
+ }
+ }
+ return []string{absDomainName(string(b))}, nil
+}
+
+func cgoSockaddr(ip IP, zone string) (*_C_struct_sockaddr, _C_socklen_t) {
+ if ip4 := ip.To4(); ip4 != nil {
+ return cgoSockaddrInet4(ip4), _C_socklen_t(syscall.SizeofSockaddrInet4)
+ }
+ if ip6 := ip.To16(); ip6 != nil {
+ return cgoSockaddrInet6(ip6, zoneCache.index(zone)), _C_socklen_t(syscall.SizeofSockaddrInet6)
+ }
+ return nil, 0
+}
+
+func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
+ resources, err := resSearch(ctx, name, int(dnsmessage.TypeCNAME), int(dnsmessage.ClassINET))
+ if err != nil {
+ return
+ }
+ cname, err = parseCNAMEFromResources(resources)
+ if err != nil {
+ return "", err, false
+ }
+ return cname, nil, true
+}
+
+// resSearch will make a call to the 'res_nsearch' routine in the C library
+// and parse the output as a slice of DNS resources.
+func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
+ return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) {
+ return cgoResSearch(hostname, rtype, class)
+ })
+}
+
+func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
+ acquireThread()
+ defer releaseThread()
+
+ state := (*_C_struct___res_state)(_C_malloc(unsafe.Sizeof(_C_struct___res_state{})))
+ defer _C_free(unsafe.Pointer(state))
+ if err := _C_res_ninit(state); err != nil {
+ return nil, errors.New("res_ninit failure: " + err.Error())
+ }
+ defer _C_res_nclose(state)
+
+ // Some res_nsearch implementations (like macOS) do not set errno.
+ // They set h_errno, which is not per-thread and useless to us.
+ // res_nsearch returns the size of the DNS response packet.
+ // But if the DNS response packet contains failure-like response codes,
+ // res_search returns -1 even though it has copied the packet into buf,
+ // giving us no way to find out how big the packet is.
+ // For now, we are willing to take res_search's word that there's nothing
+ // useful in the response, even though there *is* a response.
+ bufSize := maxDNSPacketSize
+ buf := (*_C_uchar)(_C_malloc(uintptr(bufSize)))
+ defer _C_free(unsafe.Pointer(buf))
+
+ s, err := syscall.BytePtrFromString(hostname)
+ if err != nil {
+ return nil, err
+ }
+
+ var size int
+ for {
+ size, _ = _C_res_nsearch(state, (*_C_char)(unsafe.Pointer(s)), class, rtype, buf, bufSize)
+ if size <= 0 || size > 0xffff {
+ return nil, errors.New("res_nsearch failure")
+ }
+ if size <= bufSize {
+ break
+ }
+
+ // Allocate a bigger buffer to fit the entire msg.
+ _C_free(unsafe.Pointer(buf))
+ bufSize = size
+ buf = (*_C_uchar)(_C_malloc(uintptr(bufSize)))
+ }
+
+ var p dnsmessage.Parser
+ if _, err := p.Start(unsafe.Slice((*byte)(unsafe.Pointer(buf)), size)); err != nil {
+ return nil, err
+ }
+ p.SkipAllQuestions()
+ resources, err := p.AllAnswers()
+ if err != nil {
+ return nil, err
+ }
+ return resources, nil
+}
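
doBlockingWithCtx above is what gives the cgo resolver its context support: getaddrinfo and friends cannot be interrupted, so the blocking call runs in its own goroutine and is simply abandoned when the context fires. A self-contained sketch of the same shape (hypothetical names; the real helper also passes the context error through mapErr):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// runWithCtx runs blocking() directly when the context can never be
// canceled, and otherwise races it against ctx.Done() from a goroutine.
func runWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) {
    if ctx.Done() == nil {
        return blocking() // e.g. context.Background(): no goroutine needed
    }
    type result struct {
        res T
        err error
    }
    ch := make(chan result, 1) // buffered so the goroutine can always finish
    go func() {
        var r result
        r.res, r.err = blocking()
        ch <- r
    }()
    select {
    case r := <-ch:
        return r.res, r.err
    case <-ctx.Done():
        var zero T
        return zero, ctx.Err() // the blocking call keeps running, abandoned
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()
    _, err := runWithCtx(ctx, func() (string, error) {
        time.Sleep(time.Second) // stand-in for a blocking getaddrinfo call
        return "done", nil
    })
    fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true
}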
diff --git a/src/net/cgo_unix_cgo.go b/src/net/cgo_unix_cgo.go
new file mode 100644
index 0000000..d11f3e3
--- /dev/null
+++ b/src/net/cgo_unix_cgo.go
@@ -0,0 +1,80 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build cgo && !netgo && unix && !darwin
+
+package net
+
+/*
+#define _GNU_SOURCE
+
+#cgo CFLAGS: -fno-stack-protector
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+
+#ifndef EAI_NODATA
+#define EAI_NODATA -5
+#endif
+
+// If nothing else defined EAI_OVERFLOW, make sure it has a value.
+#ifndef EAI_OVERFLOW
+#define EAI_OVERFLOW -12
+#endif
+*/
+import "C"
+import "unsafe"
+
+const (
+ _C_AF_INET = C.AF_INET
+ _C_AF_INET6 = C.AF_INET6
+ _C_AF_UNSPEC = C.AF_UNSPEC
+ _C_EAI_AGAIN = C.EAI_AGAIN
+ _C_EAI_NODATA = C.EAI_NODATA
+ _C_EAI_NONAME = C.EAI_NONAME
+ _C_EAI_OVERFLOW = C.EAI_OVERFLOW
+ _C_EAI_SYSTEM = C.EAI_SYSTEM
+ _C_IPPROTO_TCP = C.IPPROTO_TCP
+ _C_IPPROTO_UDP = C.IPPROTO_UDP
+ _C_SOCK_DGRAM = C.SOCK_DGRAM
+ _C_SOCK_STREAM = C.SOCK_STREAM
+)
+
+type (
+ _C_char = C.char
+ _C_uchar = C.uchar
+ _C_int = C.int
+ _C_uint = C.uint
+ _C_socklen_t = C.socklen_t
+ _C_struct_addrinfo = C.struct_addrinfo
+ _C_struct_sockaddr = C.struct_sockaddr
+)
+
+func _C_GoString(p *_C_char) string { return C.GoString(p) }
+func _C_malloc(n uintptr) unsafe.Pointer { return C.malloc(C.size_t(n)) }
+func _C_free(p unsafe.Pointer) { C.free(p) }
+
+func _C_ai_addr(ai *_C_struct_addrinfo) **_C_struct_sockaddr { return &ai.ai_addr }
+func _C_ai_family(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_family }
+func _C_ai_flags(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_flags }
+func _C_ai_next(ai *_C_struct_addrinfo) **_C_struct_addrinfo { return &ai.ai_next }
+func _C_ai_protocol(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_protocol }
+func _C_ai_socktype(ai *_C_struct_addrinfo) *_C_int { return &ai.ai_socktype }
+
+func _C_freeaddrinfo(ai *_C_struct_addrinfo) {
+ C.freeaddrinfo(ai)
+}
+
+func _C_gai_strerror(eai _C_int) string {
+ return C.GoString(C.gai_strerror(eai))
+}
+
+func _C_getaddrinfo(hostname, servname *_C_char, hints *_C_struct_addrinfo, res **_C_struct_addrinfo) (int, error) {
+ x, err := C.getaddrinfo(hostname, servname, hints, res)
+ return int(x), err
+}
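Everything in this file is a thin alias over libc. For readers who have not used the underlying calls, here is a minimal standalone cgo sketch (illustration only, not part of this change) of the getaddrinfo / gai_strerror / freeaddrinfo trio these wrappers forward to; "localhost" is a placeholder and the program assumes a Unix toolchain with CGO_ENABLED=1.

package main

/*
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	host := C.CString("localhost")
	defer C.free(unsafe.Pointer(host))

	var hints C.struct_addrinfo
	hints.ai_family = C.AF_UNSPEC
	hints.ai_socktype = C.SOCK_STREAM

	var res *C.struct_addrinfo
	if rc := C.getaddrinfo(host, nil, &hints, &res); rc != 0 {
		fmt.Println("getaddrinfo:", C.GoString(C.gai_strerror(rc)))
		return
	}
	defer C.freeaddrinfo(res)

	// Walk the returned linked list, as the cgo resolver does.
	for ai := res; ai != nil; ai = ai.ai_next {
		fmt.Println("family:", ai.ai_family, "socktype:", ai.ai_socktype)
	}
}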
diff --git a/src/net/cgo_unix_cgo_darwin.go b/src/net/cgo_unix_cgo_darwin.go
new file mode 100644
index 0000000..40d5e42
--- /dev/null
+++ b/src/net/cgo_unix_cgo_darwin.go
@@ -0,0 +1,21 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !netgo && cgo && darwin
+
+package net
+
+/*
+#include <resolv.h>
+*/
+import "C"
+
+import (
+ "internal/syscall/unix"
+ "unsafe"
+)
+
+// This will cause a compile error when the size of
+// unix.ResState is too small.
+type _ [unsafe.Sizeof(unix.ResState{}) - unsafe.Sizeof(C.struct___res_state{})]byte
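The blank array type above is the standard compile-time size assertion. A self-contained sketch of the same trick with two hypothetical types (nothing here is from this change):

package main

import "unsafe"

type A struct{ _ [64]byte }
type B struct{ _ [32]byte }

// The array length is a constant expression; it underflows, and the build
// fails, as soon as unsafe.Sizeof(B{}) grows past unsafe.Sizeof(A{}).
type _ [unsafe.Sizeof(A{}) - unsafe.Sizeof(B{})]byte

func main() {}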
diff --git a/src/net/cgo_unix_cgo_res.go b/src/net/cgo_unix_cgo_res.go
new file mode 100644
index 0000000..37bbc9a
--- /dev/null
+++ b/src/net/cgo_unix_cgo_res.go
@@ -0,0 +1,38 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// res_search, for cgo systems where that is thread-safe.
+
+//go:build cgo && !netgo && (linux || openbsd)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <string.h>
+#include <arpa/nameser.h>
+#include <resolv.h>
+
+#cgo !android,!openbsd LDFLAGS: -lresolv
+*/
+import "C"
+
+type _C_struct___res_state = struct{}
+
+func _C_res_ninit(state *_C_struct___res_state) error {
+ return nil
+}
+
+func _C_res_nclose(state *_C_struct___res_state) {
+ return
+}
+
+func _C_res_nsearch(state *_C_struct___res_state, dname *_C_char, class, typ int, ans *_C_uchar, anslen int) (int, error) {
+ x, err := C.res_search(dname, C.int(class), C.int(typ), ans, C.int(anslen))
+ return int(x), err
+}
diff --git a/src/net/cgo_unix_cgo_resn.go b/src/net/cgo_unix_cgo_resn.go
new file mode 100644
index 0000000..4a5ff16
--- /dev/null
+++ b/src/net/cgo_unix_cgo_resn.go
@@ -0,0 +1,39 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// res_nsearch, for cgo systems where that's available.
+
+//go:build cgo && !netgo && unix && !(darwin || linux || openbsd)
+
+package net
+
+/*
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <string.h>
+#include <arpa/nameser.h>
+#include <resolv.h>
+
+#cgo !aix,!dragonfly,!freebsd LDFLAGS: -lresolv
+*/
+import "C"
+
+type _C_struct___res_state = C.struct___res_state
+
+func _C_res_ninit(state *_C_struct___res_state) error {
+ _, err := C.res_ninit(state)
+ return err
+}
+
+func _C_res_nclose(state *_C_struct___res_state) {
+ C.res_nclose(state)
+}
+
+func _C_res_nsearch(state *_C_struct___res_state, dname *_C_char, class, typ int, ans *_C_uchar, anslen int) (int, error) {
+ x, err := C.res_nsearch(state, dname, C.int(class), C.int(typ), ans, C.int(anslen))
+ return int(x), err
+}
diff --git a/src/net/cgo_unix_syscall.go b/src/net/cgo_unix_syscall.go
new file mode 100644
index 0000000..2eb8df1
--- /dev/null
+++ b/src/net/cgo_unix_syscall.go
@@ -0,0 +1,102 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !netgo && darwin
+
+package net
+
+import (
+ "internal/syscall/unix"
+ "runtime"
+ "syscall"
+ "unsafe"
+)
+
+const (
+ _C_AF_INET = syscall.AF_INET
+ _C_AF_INET6 = syscall.AF_INET6
+ _C_AF_UNSPEC = syscall.AF_UNSPEC
+ _C_EAI_AGAIN = unix.EAI_AGAIN
+ _C_EAI_NONAME = unix.EAI_NONAME
+ _C_EAI_NODATA = unix.EAI_NODATA
+ _C_EAI_OVERFLOW = unix.EAI_OVERFLOW
+ _C_EAI_SYSTEM = unix.EAI_SYSTEM
+ _C_IPPROTO_TCP = syscall.IPPROTO_TCP
+ _C_IPPROTO_UDP = syscall.IPPROTO_UDP
+ _C_SOCK_DGRAM = syscall.SOCK_DGRAM
+ _C_SOCK_STREAM = syscall.SOCK_STREAM
+)
+
+type (
+ _C_char = byte
+ _C_int = int32
+ _C_uchar = byte
+ _C_uint = uint32
+ _C_socklen_t = int
+ _C_struct___res_state = unix.ResState
+ _C_struct_addrinfo = unix.Addrinfo
+ _C_struct_sockaddr = syscall.RawSockaddr
+)
+
+func _C_GoString(p *_C_char) string {
+ return unix.GoString(p)
+}
+
+func _C_free(p unsafe.Pointer) { runtime.KeepAlive(p) }
+
+func _C_malloc(n uintptr) unsafe.Pointer {
+ if n <= 0 {
+ n = 1
+ }
+ return unsafe.Pointer(&make([]byte, n)[0])
+}
+
+func _C_ai_addr(ai *_C_struct_addrinfo) **_C_struct_sockaddr { return &ai.Addr }
+func _C_ai_family(ai *_C_struct_addrinfo) *_C_int { return &ai.Family }
+func _C_ai_flags(ai *_C_struct_addrinfo) *_C_int { return &ai.Flags }
+func _C_ai_next(ai *_C_struct_addrinfo) **_C_struct_addrinfo { return &ai.Next }
+func _C_ai_protocol(ai *_C_struct_addrinfo) *_C_int { return &ai.Protocol }
+func _C_ai_socktype(ai *_C_struct_addrinfo) *_C_int { return &ai.Socktype }
+
+func _C_freeaddrinfo(ai *_C_struct_addrinfo) {
+ unix.Freeaddrinfo(ai)
+}
+
+func _C_gai_strerror(eai _C_int) string {
+ return unix.GaiStrerror(int(eai))
+}
+
+func _C_getaddrinfo(hostname, servname *byte, hints *_C_struct_addrinfo, res **_C_struct_addrinfo) (int, error) {
+ return unix.Getaddrinfo(hostname, servname, hints, res)
+}
+
+func _C_res_ninit(state *_C_struct___res_state) error {
+ unix.ResNinit(state)
+ return nil
+}
+
+func _C_res_nsearch(state *_C_struct___res_state, dname *_C_char, class, typ int, ans *_C_char, anslen int) (int, error) {
+ return unix.ResNsearch(state, dname, class, typ, ans, anslen)
+}
+
+func _C_res_nclose(state *_C_struct___res_state) {
+ unix.ResNclose(state)
+}
+
+func cgoNameinfoPTR(b []byte, sa *syscall.RawSockaddr, salen int) (int, error) {
+ gerrno, err := unix.Getnameinfo(sa, salen, &b[0], len(b), nil, 0, unix.NI_NAMEREQD)
+ return int(gerrno), err
+}
+
+func cgoSockaddrInet4(ip IP) *syscall.RawSockaddr {
+ sa := syscall.RawSockaddrInet4{Len: syscall.SizeofSockaddrInet4, Family: syscall.AF_INET}
+ copy(sa.Addr[:], ip)
+ return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
+}
+
+func cgoSockaddrInet6(ip IP, zone int) *syscall.RawSockaddr {
+ sa := syscall.RawSockaddrInet6{Len: syscall.SizeofSockaddrInet6, Family: syscall.AF_INET6, Scope_id: uint32(zone)}
+ copy(sa.Addr[:], ip)
+ return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
+}
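cgoSockaddrInet4 and cgoSockaddrInet6 simply lay out a concrete RawSockaddrInet4/6 and reinterpret the pointer as the generic sockaddr header. A hedged standalone sketch of the IPv4 case, valid only on Darwin/BSD where the raw sockaddr structs carry a Len field; the address is a placeholder:

package main

import (
	"fmt"
	"net"
	"syscall"
	"unsafe"
)

func main() {
	ip := net.ParseIP("192.0.2.1").To4()

	sa := syscall.RawSockaddrInet4{Len: syscall.SizeofSockaddrInet4, Family: syscall.AF_INET}
	copy(sa.Addr[:], ip)

	// View the concrete sockaddr through the generic header, as the cgo
	// resolver does before handing it to getnameinfo.
	raw := (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
	fmt.Println("family:", raw.Family, "len:", raw.Len)
}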
diff --git a/src/net/cgo_unix_test.go b/src/net/cgo_unix_test.go
new file mode 100644
index 0000000..d8233df
--- /dev/null
+++ b/src/net/cgo_unix_test.go
@@ -0,0 +1,69 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !netgo && ((cgo && unix) || darwin)
+
+package net
+
+import (
+ "context"
+ "testing"
+)
+
+func TestCgoLookupIP(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx := context.Background()
+ _, err := cgoLookupIP(ctx, "ip", "localhost")
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestCgoLookupIPWithCancel(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := cgoLookupIP(ctx, "ip", "localhost")
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestCgoLookupPort(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx := context.Background()
+ _, err := cgoLookupPort(ctx, "tcp", "smtp")
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestCgoLookupPortWithCancel(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := cgoLookupPort(ctx, "tcp", "smtp")
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestCgoLookupPTR(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx := context.Background()
+ _, err := cgoLookupPTR(ctx, "127.0.0.1")
+ if err != nil {
+ t.Error(err)
+ }
+}
+
+func TestCgoLookupPTRWithCancel(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, err := cgoLookupPTR(ctx, "127.0.0.1")
+ if err != nil {
+ t.Error(err)
+ }
+}
diff --git a/src/net/conf.go b/src/net/conf.go
new file mode 100644
index 0000000..77cc635
--- /dev/null
+++ b/src/net/conf.go
@@ -0,0 +1,523 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js
+
+package net
+
+import (
+ "errors"
+ "internal/bytealg"
+ "internal/godebug"
+ "io/fs"
+ "os"
+ "runtime"
+ "sync"
+ "syscall"
+)
+
+// The net package's name resolution is rather complicated.
+// There are two main approaches, go and cgo.
+// The cgo resolver uses C functions like getaddrinfo.
+// The go resolver reads system files directly and
+// sends DNS packets directly to servers.
+//
+// The netgo build tag prefers the go resolver.
+// The netcgo build tag prefers the cgo resolver.
+//
+// The netgo build tag also prohibits the use of the cgo tool.
+// However, on Darwin, Plan 9, and Windows the cgo resolver is still available.
+// On those systems the cgo resolver does not require the cgo tool.
+// (The term "cgo resolver" was locked in by GODEBUG settings
+// at a time when the cgo resolver did require the cgo tool.)
+//
+// Adding netdns=go to GODEBUG will prefer the go resolver.
+// Adding netdns=cgo to GODEBUG will prefer the cgo resolver.
+//
+// The Resolver struct has a PreferGo field that user code
+// may set to prefer the go resolver. It is documented as being
+// equivalent to adding netdns=go to GODEBUG.
+//
+// When deciding which resolver to use, we first check the PreferGo field.
+// If that is not set, we check the GODEBUG setting.
+// If that is not set, we check the netgo or netcgo build tag.
+// If none of those are set, we normally prefer the go resolver by default.
+// However, if the cgo resolver is available,
+// there is a complex set of conditions for which we prefer the cgo resolver.
+//
+// Other files define the netGoBuildTag, netCgoBuildTag, and cgoAvailable
+// constants.
+
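From user code the two knobs described above are Resolver.PreferGo and the GODEBUG=netdns setting. A minimal usage sketch (illustration only; "localhost" is a placeholder and a working resolver is assumed):

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	// Per-Resolver preference, documented as equivalent to GODEBUG=netdns=go.
	r := &net.Resolver{PreferGo: true}
	addrs, err := r.LookupHost(context.Background(), "localhost")
	fmt.Println(addrs, err)

	// The package-level functions use DefaultResolver, which follows the
	// build-tag and GODEBUG logic laid out in the comment above.
	addrs, err = net.LookupHost("localhost")
	fmt.Println(addrs, err)
}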
+// conf is used to determine name resolution configuration.
+type conf struct {
+ netGo bool // prefer go approach, based on build tag and GODEBUG
+ netCgo bool // prefer cgo approach, based on build tag and GODEBUG
+
+ dnsDebugLevel int // from GODEBUG
+
+ preferCgo bool // if no explicit preference, use cgo
+
+ goos string // copy of runtime.GOOS, used for testing
+ mdnsTest mdnsTest // assume /etc/mdns.allow exists, for testing
+}
+
+// mdnsTest is for testing only.
+type mdnsTest int
+
+const (
+ mdnsFromSystem mdnsTest = iota
+ mdnsAssumeExists
+ mdnsAssumeDoesNotExist
+)
+
+var (
+ confOnce sync.Once // guards init of confVal via initConfVal
+ confVal = &conf{goos: runtime.GOOS}
+)
+
+// systemConf returns the machine's network configuration.
+func systemConf() *conf {
+ confOnce.Do(initConfVal)
+ return confVal
+}
+
+// initConfVal initializes confVal based on the environment
+// that will not change during program execution.
+func initConfVal() {
+ dnsMode, debugLevel := goDebugNetDNS()
+ confVal.netGo = netGoBuildTag || dnsMode == "go"
+ confVal.netCgo = netCgoBuildTag || dnsMode == "cgo"
+ confVal.dnsDebugLevel = debugLevel
+
+ if confVal.dnsDebugLevel > 0 {
+ defer func() {
+ if confVal.dnsDebugLevel > 1 {
+ println("go package net: confVal.netCgo =", confVal.netCgo, " netGo =", confVal.netGo)
+ }
+ switch {
+ case confVal.netGo:
+ if netGoBuildTag {
+ println("go package net: built with netgo build tag; using Go's DNS resolver")
+ } else {
+ println("go package net: GODEBUG setting forcing use of Go's resolver")
+ }
+ case !cgoAvailable:
+ println("go package net: cgo resolver not supported; using Go's DNS resolver")
+ case confVal.netCgo || confVal.preferCgo:
+ println("go package net: using cgo DNS resolver")
+ default:
+ println("go package net: dynamic selection of DNS resolver")
+ }
+ }()
+ }
+
+ // The remainder of this function sets preferCgo based on
+ // conditions that will not change during program execution.
+
+ // By default, prefer the go resolver.
+ confVal.preferCgo = false
+
+ // If the cgo resolver is not available, we can't prefer it.
+ if !cgoAvailable {
+ return
+ }
+
+ // Some operating systems always prefer the cgo resolver.
+ if goosPrefersCgo() {
+ confVal.preferCgo = true
+ return
+ }
+
+ // The remaining checks are specific to Unix systems.
+ switch runtime.GOOS {
+ case "plan9", "windows", "js", "wasip1":
+ return
+ }
+
+ // If any resolver options are specified in the environment,
+ // prefer the cgo resolver.
+ // Note that LOCALDOMAIN can change behavior merely by being
+ // specified with the empty string.
+ _, localDomainDefined := syscall.Getenv("LOCALDOMAIN")
+ if localDomainDefined || os.Getenv("RES_OPTIONS") != "" || os.Getenv("HOSTALIASES") != "" {
+ confVal.preferCgo = true
+ return
+ }
+
+ // OpenBSD apparently lets you override the location of resolv.conf
+ // with ASR_CONFIG. If we notice that, defer to libc.
+ if runtime.GOOS == "openbsd" && os.Getenv("ASR_CONFIG") != "" {
+ confVal.preferCgo = true
+ return
+ }
+}
+
+// goosPrefersCgo reports whether the current GOOS prefers
+// the cgo resolver.
+func goosPrefersCgo() bool {
+ switch runtime.GOOS {
+ // Historically on Windows and Plan 9 we prefer the
+ // cgo resolver (which doesn't use the cgo tool) rather than
+ // the go resolver. This is because originally these
+ // systems did not support the go resolver.
+ // Keep it this way for better compatibility.
+ // Perhaps we can revisit this some day.
+ case "windows", "plan9":
+ return true
+
+ // Darwin pops up annoying dialog boxes if programs try to
+ // do their own DNS requests, so prefer cgo.
+ case "darwin", "ios":
+ return true
+
+ // DNS requests don't work on Android, so prefer the cgo resolver.
+ // Issue #10714.
+ case "android":
+ return true
+
+ default:
+ return false
+ }
+}
+
+// mustUseGoResolver reports whether the go resolver must be used for
+// any DNS lookup. The provided Resolver is optional.
+// This will report true if the cgo resolver is not available.
+func (c *conf) mustUseGoResolver(r *Resolver) bool {
+ return c.netGo || r.preferGo() || !cgoAvailable
+}
+
+// addrLookupOrder determines which strategy to use to resolve addresses.
+// The provided Resolver is optional. nil means to not consider its options.
+// It also returns dnsConfig when it was used to determine the lookup order.
+func (c *conf) addrLookupOrder(r *Resolver, addr string) (ret hostLookupOrder, dnsConf *dnsConfig) {
+ if c.dnsDebugLevel > 1 {
+ defer func() {
+ print("go package net: addrLookupOrder(", addr, ") = ", ret.String(), "\n")
+ }()
+ }
+ return c.lookupOrder(r, "")
+}
+
+// hostLookupOrder determines which strategy to use to resolve hostname.
+// The provided Resolver is optional. nil means to not consider its options.
+// It also returns dnsConfig when it was used to determine the lookup order.
+func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrder, dnsConf *dnsConfig) {
+ if c.dnsDebugLevel > 1 {
+ defer func() {
+ print("go package net: hostLookupOrder(", hostname, ") = ", ret.String(), "\n")
+ }()
+ }
+ return c.lookupOrder(r, hostname)
+}
+
+func (c *conf) lookupOrder(r *Resolver, hostname string) (ret hostLookupOrder, dnsConf *dnsConfig) {
+ // fallbackOrder is the order we return if we can't figure it out.
+ var fallbackOrder hostLookupOrder
+
+ var canUseCgo bool
+ if c.mustUseGoResolver(r) {
+ // Go resolver was explicitly requested
+ // or cgo resolver is not available.
+ // Figure out the order below.
+ switch c.goos {
+ case "windows":
+ // TODO(bradfitz): implement files-based
+ // lookup on Windows too? I guess /etc/hosts
+ // kinda exists on Windows. But for now, only
+ // do DNS.
+ fallbackOrder = hostLookupDNS
+ default:
+ fallbackOrder = hostLookupFilesDNS
+ }
+ canUseCgo = false
+ } else if c.netCgo {
+ // Cgo resolver was explicitly requested.
+ return hostLookupCgo, nil
+ } else if c.preferCgo {
+ // Given a choice, we prefer the cgo resolver.
+ return hostLookupCgo, nil
+ } else {
+ // Neither resolver was explicitly requested
+ // and we have no preference.
+
+ if bytealg.IndexByteString(hostname, '\\') != -1 || bytealg.IndexByteString(hostname, '%') != -1 {
+ // Don't deal with special form hostnames
+ // with backslashes or '%'.
+ return hostLookupCgo, nil
+ }
+
+ // If something is unrecognized, use cgo.
+ fallbackOrder = hostLookupCgo
+ canUseCgo = true
+ }
+
+ // On systems that don't use /etc/resolv.conf or /etc/nsswitch.conf, we are done.
+ switch c.goos {
+ case "windows", "plan9", "android", "ios":
+ return fallbackOrder, nil
+ }
+
+ // Try to figure out the order to use for searches.
+ // If we don't recognize something, use fallbackOrder.
+ // That will use cgo unless the Go resolver was explicitly requested.
+ // If we do figure out the order, return something other
+ // than fallbackOrder to use the Go resolver with that order.
+
+ dnsConf = getSystemDNSConfig()
+
+ if canUseCgo && dnsConf.err != nil && !errors.Is(dnsConf.err, fs.ErrNotExist) && !errors.Is(dnsConf.err, fs.ErrPermission) {
+ // We can't read the resolv.conf file, so use cgo if we can.
+ return hostLookupCgo, dnsConf
+ }
+
+ if canUseCgo && dnsConf.unknownOpt {
+ // We didn't recognize something in resolv.conf,
+ // so use cgo if we can.
+ return hostLookupCgo, dnsConf
+ }
+
+ // OpenBSD is unique and doesn't use nsswitch.conf.
+ // It also doesn't support mDNS.
+ if c.goos == "openbsd" {
+ // OpenBSD's resolv.conf manpage says that a
+ // non-existent resolv.conf means "lookup" defaults
+ // to only "files", without DNS lookups.
+ if errors.Is(dnsConf.err, fs.ErrNotExist) {
+ return hostLookupFiles, dnsConf
+ }
+
+ lookup := dnsConf.lookup
+ if len(lookup) == 0 {
+ // https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
+ // "If the lookup keyword is not used in the
+ // system's resolv.conf file then the assumed
+ // order is 'bind file'"
+ return hostLookupDNSFiles, dnsConf
+ }
+ if len(lookup) < 1 || len(lookup) > 2 {
+ // We don't recognize this format.
+ return fallbackOrder, dnsConf
+ }
+ switch lookup[0] {
+ case "bind":
+ if len(lookup) == 2 {
+ if lookup[1] == "file" {
+ return hostLookupDNSFiles, dnsConf
+ }
+ // Unrecognized.
+ return fallbackOrder, dnsConf
+ }
+ return hostLookupDNS, dnsConf
+ case "file":
+ if len(lookup) == 2 {
+ if lookup[1] == "bind" {
+ return hostLookupFilesDNS, dnsConf
+ }
+ // Unrecognized.
+ return fallbackOrder, dnsConf
+ }
+ return hostLookupFiles, dnsConf
+ default:
+ // Unrecognized.
+ return fallbackOrder, dnsConf
+ }
+
+ // We always return before this point.
+ // The code below is for non-OpenBSD.
+ }
+
+ // Canonicalize the hostname by removing any trailing dot.
+ if stringsHasSuffix(hostname, ".") {
+ hostname = hostname[:len(hostname)-1]
+ }
+ if canUseCgo && stringsHasSuffixFold(hostname, ".local") {
+ // Per RFC 6762, the ".local" TLD is special. And
+ // because Go's native resolver doesn't do mDNS or
+ // similar local resolution mechanisms, assume that
+ // libc might (via Avahi, etc) and use cgo.
+ return hostLookupCgo, dnsConf
+ }
+
+ nss := getSystemNSS()
+ srcs := nss.sources["hosts"]
+ // If /etc/nsswitch.conf doesn't exist or doesn't specify any
+ // sources for "hosts", assume Go's DNS will work fine.
+ if errors.Is(nss.err, fs.ErrNotExist) || (nss.err == nil && len(srcs) == 0) {
+ if canUseCgo && c.goos == "solaris" {
+ // illumos defaults to
+ // "nis [NOTFOUND=return] files",
+ // which the go resolver doesn't support.
+ return hostLookupCgo, dnsConf
+ }
+
+ return hostLookupFilesDNS, dnsConf
+ }
+ if nss.err != nil {
+ // We failed to parse or open nsswitch.conf, so
+ // we have nothing to base an order on.
+ return fallbackOrder, dnsConf
+ }
+
+ var hasDNSSource bool
+ var hasDNSSourceChecked bool
+
+ var filesSource, dnsSource bool
+ var first string
+ for i, src := range srcs {
+ if src.source == "files" || src.source == "dns" {
+ if canUseCgo && !src.standardCriteria() {
+ // non-standard; let libc deal with it.
+ return hostLookupCgo, dnsConf
+ }
+ if src.source == "files" {
+ filesSource = true
+ } else {
+ hasDNSSource = true
+ hasDNSSourceChecked = true
+ dnsSource = true
+ }
+ if first == "" {
+ first = src.source
+ }
+ continue
+ }
+
+ if canUseCgo {
+ switch {
+ case hostname != "" && src.source == "myhostname":
+ // Let the cgo resolver handle myhostname
+ // if we are looking up the local hostname.
+ if isLocalhost(hostname) || isGateway(hostname) || isOutbound(hostname) {
+ return hostLookupCgo, dnsConf
+ }
+ hn, err := getHostname()
+ if err != nil || stringsEqualFold(hostname, hn) {
+ return hostLookupCgo, dnsConf
+ }
+ continue
+ case hostname != "" && stringsHasPrefix(src.source, "mdns"):
+ // e.g. "mdns4", "mdns4_minimal"
+ // We already returned true before if it was *.local.
+ // libc wouldn't have found a hit on this anyway.
+
+ // We don't parse mdns.allow files. They're rare. If one
+ // exists, it might list other TLDs (besides .local) or even
+ // '*', so just let libc deal with it.
+ var haveMDNSAllow bool
+ switch c.mdnsTest {
+ case mdnsFromSystem:
+ _, err := os.Stat("/etc/mdns.allow")
+ if err != nil && !errors.Is(err, fs.ErrNotExist) {
+ // Let libc figure out what is going on.
+ return hostLookupCgo, dnsConf
+ }
+ haveMDNSAllow = err == nil
+ case mdnsAssumeExists:
+ haveMDNSAllow = true
+ case mdnsAssumeDoesNotExist:
+ haveMDNSAllow = false
+ }
+ if haveMDNSAllow {
+ return hostLookupCgo, dnsConf
+ }
+ continue
+ default:
+ // Some source we don't know how to deal with.
+ return hostLookupCgo, dnsConf
+ }
+ }
+
+ if !hasDNSSourceChecked {
+ hasDNSSourceChecked = true
+ for _, v := range srcs[i+1:] {
+ if v.source == "dns" {
+ hasDNSSource = true
+ break
+ }
+ }
+ }
+
+ // If we saw a source we don't recognize, which can only
+ // happen if we can't use the cgo resolver, treat it as DNS,
+ // but only when there is no dns in all other sources.
+ if !hasDNSSource {
+ dnsSource = true
+ if first == "" {
+ first = "dns"
+ }
+ }
+ }
+
+ // Cases where Go can handle it without cgo and C thread overhead,
+ // or where the Go resolver has been forced.
+ switch {
+ case filesSource && dnsSource:
+ if first == "files" {
+ return hostLookupFilesDNS, dnsConf
+ } else {
+ return hostLookupDNSFiles, dnsConf
+ }
+ case filesSource:
+ return hostLookupFiles, dnsConf
+ case dnsSource:
+ return hostLookupDNS, dnsConf
+ }
+
+ // Something weird. Fallback to the default.
+ return fallbackOrder, dnsConf
+}
+
+var netdns = godebug.New("netdns")
+
+// goDebugNetDNS parses the value of the GODEBUG "netdns" value.
+// The netdns value can be of the form:
+//
+// 1 // debug level 1
+// 2 // debug level 2
+// cgo // use cgo for DNS lookups
+// go // use go for DNS lookups
+// cgo+1 // use cgo for DNS lookups + debug level 1
+// 1+cgo // same
+// cgo+2 // same, but debug level 2
+//
+// etc.
+func goDebugNetDNS() (dnsMode string, debugLevel int) {
+ goDebug := netdns.Value()
+ parsePart := func(s string) {
+ if s == "" {
+ return
+ }
+ if '0' <= s[0] && s[0] <= '9' {
+ debugLevel, _, _ = dtoi(s)
+ } else {
+ dnsMode = s
+ }
+ }
+ if i := bytealg.IndexByteString(goDebug, '+'); i != -1 {
+ parsePart(goDebug[:i])
+ parsePart(goDebug[i+1:])
+ return
+ }
+ parsePart(goDebug)
+ return
+}
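A quick worked example of the accepted forms: GODEBUG=netdns=2 leaves the resolver choice alone and sets debug level 2; GODEBUG=netdns=go picks the Go resolver with no debug output; GODEBUG=netdns=cgo+1 (equivalently 1+cgo) picks the cgo resolver and has the package print its resolver decision at startup, while level 2 additionally logs the lookup order chosen for each host.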
+
+// isLocalhost reports whether h should be considered a "localhost"
+// name for the myhostname NSS module.
+func isLocalhost(h string) bool {
+ return stringsEqualFold(h, "localhost") || stringsEqualFold(h, "localhost.localdomain") || stringsHasSuffixFold(h, ".localhost") || stringsHasSuffixFold(h, ".localhost.localdomain")
+}
+
+// isGateway reports whether h should be considered a "gateway"
+// name for the myhostname NSS module.
+func isGateway(h string) bool {
+ return stringsEqualFold(h, "_gateway")
+}
+
+// isOutbound reports whether h should be considered an "outbound"
+// name for the myhostname NSS module.
+func isOutbound(h string) bool {
+ return stringsEqualFold(h, "_outbound")
+}
diff --git a/src/net/conf_test.go b/src/net/conf_test.go
new file mode 100644
index 0000000..0f324b2
--- /dev/null
+++ b/src/net/conf_test.go
@@ -0,0 +1,461 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "io/fs"
+ "os"
+ "testing"
+ "time"
+)
+
+type nssHostTest struct {
+ host string
+ localhost string
+ want hostLookupOrder
+}
+
+func nssStr(t *testing.T, s string) *nssConf {
+ f, err := os.CreateTemp(t.TempDir(), "nss")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := f.WriteString(s); err != nil {
+ t.Fatal(err)
+ }
+ if err := f.Close(); err != nil {
+ t.Fatal(err)
+ }
+ return parseNSSConfFile(f.Name())
+}
+
+// represents a dnsConfig returned by parsing a nonexistent resolv.conf
+var defaultResolvConf = &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ timeout: 5,
+ attempts: 2,
+ err: fs.ErrNotExist,
+}
+
+func TestConfHostLookupOrder(t *testing.T) {
+ // These tests are written for a system with cgo available,
+ // without using the netgo tag.
+ if netGoBuildTag {
+ t.Skip("skipping test because net package built with netgo tag")
+ }
+ if !cgoAvailable {
+ t.Skip("skipping test because cgo resolver not available")
+ }
+
+ tests := []struct {
+ name string
+ c *conf
+ nss *nssConf
+ resolver *Resolver
+ resolv *dnsConfig
+ hostTests []nssHostTest
+ }{
+ {
+ name: "force",
+ c: &conf{
+ preferCgo: true,
+ netCgo: true,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "foo: bar"),
+ hostTests: []nssHostTest{
+ {"foo.local", "myhostname", hostLookupCgo},
+ {"google.com", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "netgo_dns_before_files",
+ c: &conf{
+ netGo: true,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns files"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ {
+ name: "netgo_fallback_on_cgo",
+ c: &conf{
+ netGo: true,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns files something_custom"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ {
+ name: "ubuntu_trusty_avahi",
+ c: &conf{
+ mdnsTest: mdnsAssumeDoesNotExist,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files mdns4_minimal [NOTFOUND=return] dns mdns4"),
+ hostTests: []nssHostTest{
+ {"foo.local", "myhostname", hostLookupCgo},
+ {"foo.local.", "myhostname", hostLookupCgo},
+ {"foo.LOCAL", "myhostname", hostLookupCgo},
+ {"foo.LOCAL.", "myhostname", hostLookupCgo},
+ {"google.com", "myhostname", hostLookupFilesDNS},
+ },
+ },
+ {
+ name: "freebsdlinux_no_resolv_conf",
+ c: &conf{
+ goos: "freebsd",
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "foo: bar"),
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
+ },
+ // On OpenBSD, no resolv.conf means no DNS.
+ {
+ name: "openbsd_no_resolv_conf",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: defaultResolvConf,
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}},
+ },
+ {
+ name: "solaris_no_nsswitch",
+ c: &conf{
+ goos: "solaris",
+ },
+ resolv: defaultResolvConf,
+ nss: &nssConf{err: fs.ErrNotExist},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
+ },
+ {
+ name: "openbsd_lookup_bind_file",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"bind", "file"}},
+ hostTests: []nssHostTest{
+ {"google.com", "myhostname", hostLookupDNSFiles},
+ {"foo.local", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ {
+ name: "openbsd_lookup_file_bind",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"file", "bind"}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
+ },
+ {
+ name: "openbsd_lookup_bind",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"bind"}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNS}},
+ },
+ {
+ name: "openbsd_lookup_file",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"file"}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}},
+ },
+ {
+ name: "openbsd_lookup_yp",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"file", "bind", "yp"}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
+ },
+ {
+ name: "openbsd_lookup_two",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: []string{"file", "foo"}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
+ },
+ {
+ name: "openbsd_lookup_empty",
+ c: &conf{
+ goos: "openbsd",
+ },
+ resolv: &dnsConfig{lookup: nil},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}},
+ },
+ {
+ name: "linux_no_nsswitch.conf",
+ c: &conf{
+ goos: "linux",
+ },
+ resolv: defaultResolvConf,
+ nss: &nssConf{err: fs.ErrNotExist},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
+ },
+ {
+ name: "linux_empty_nsswitch.conf",
+ c: &conf{
+ goos: "linux",
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, ""),
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
+ },
+ {
+ name: "files_mdns_dns",
+ c: &conf{
+ mdnsTest: mdnsAssumeDoesNotExist,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files mdns dns"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"x.local", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "dns_special_hostnames",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNS},
+ {"x\\.com", "myhostname", hostLookupCgo}, // punt on weird glibc escape
+ {"foo.com%en0", "myhostname", hostLookupCgo}, // and IPv6 zones
+ },
+ },
+ {
+ name: "mdns_allow",
+ c: &conf{
+ mdnsTest: mdnsAssumeExists,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files mdns dns"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupCgo},
+ {"x.local", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "files_dns",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files dns"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"x", "myhostname", hostLookupFilesDNS},
+ {"x.local", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "dns_files",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns files"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ {"x", "myhostname", hostLookupDNSFiles},
+ {"x.local", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "something_custom",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns files something_custom"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupCgo},
+ },
+ },
+ {
+ name: "myhostname",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files dns myhostname"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"myhostname", "myhostname", hostLookupCgo},
+ {"myHostname", "myhostname", hostLookupCgo},
+ {"myhostname.dot", "myhostname.dot", hostLookupCgo},
+ {"myHostname.dot", "myhostname.dot", hostLookupCgo},
+ {"_gateway", "myhostname", hostLookupCgo},
+ {"_Gateway", "myhostname", hostLookupCgo},
+ {"_outbound", "myhostname", hostLookupCgo},
+ {"_Outbound", "myhostname", hostLookupCgo},
+ {"localhost", "myhostname", hostLookupCgo},
+ {"Localhost", "myhostname", hostLookupCgo},
+ {"anything.localhost", "myhostname", hostLookupCgo},
+ {"Anything.localhost", "myhostname", hostLookupCgo},
+ {"localhost.localdomain", "myhostname", hostLookupCgo},
+ {"Localhost.Localdomain", "myhostname", hostLookupCgo},
+ {"anything.localhost.localdomain", "myhostname", hostLookupCgo},
+ {"Anything.Localhost.Localdomain", "myhostname", hostLookupCgo},
+ {"somehostname", "myhostname", hostLookupFilesDNS},
+ },
+ },
+ {
+ name: "ubuntu14.04.02",
+ c: &conf{
+ mdnsTest: mdnsAssumeDoesNotExist,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: files myhostname mdns4_minimal [NOTFOUND=return] dns mdns4"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"somehostname", "myhostname", hostLookupFilesDNS},
+ {"myhostname", "myhostname", hostLookupCgo},
+ },
+ },
+ // Debian Squeeze is just "dns,files", but lists all
+ // the default criteria for dns, but then has a
+ // non-standard but redundant notfound=return for the
+ // files.
+ {
+ name: "debian_squeeze",
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns [success=return notfound=continue unavail=continue tryagain=continue] files [notfound=return]"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ {"somehostname", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ {
+ name: "resolv.conf-unknown",
+ c: &conf{},
+ resolv: &dnsConfig{servers: defaultNS, ndots: 1, timeout: 5, attempts: 2, unknownOpt: true},
+ nss: nssStr(t, "foo: bar"),
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
+ },
+ // Issue 24393: make sure "Resolver.PreferGo = true" acts like netgo.
+ {
+ name: "resolver-prefergo",
+ resolver: &Resolver{PreferGo: true},
+ c: &conf{
+ preferCgo: true,
+ netCgo: true,
+ },
+ resolv: defaultResolvConf,
+ nss: nssStr(t, ""),
+ hostTests: []nssHostTest{
+ {"localhost", "myhostname", hostLookupFilesDNS},
+ },
+ },
+ {
+ name: "unknown-source",
+ resolver: &Resolver{PreferGo: true},
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: resolve files"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ {
+ name: "dns-among-unknown-sources",
+ resolver: &Resolver{PreferGo: true},
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: mymachines files dns"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ },
+ },
+ {
+ name: "dns-among-unknown-sources-2",
+ resolver: &Resolver{PreferGo: true},
+ c: &conf{},
+ resolv: defaultResolvConf,
+ nss: nssStr(t, "hosts: dns mymachines files"),
+ hostTests: []nssHostTest{
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ },
+ },
+ }
+
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+ defer setSystemNSS(getSystemNSS(), 0)
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ for _, tt := range tests {
+ if !conf.forceUpdateConf(tt.resolv, time.Now().Add(time.Hour)) {
+ t.Errorf("%s: failed to change resolv config", tt.name)
+ }
+ for _, ht := range tt.hostTests {
+ getHostname = func() (string, error) { return ht.localhost, nil }
+ setSystemNSS(tt.nss, time.Hour)
+
+ gotOrder, _ := tt.c.hostLookupOrder(tt.resolver, ht.host)
+ if gotOrder != ht.want {
+ t.Errorf("%s: hostLookupOrder(%q) = %v; want %v", tt.name, ht.host, gotOrder, ht.want)
+ }
+ }
+ }
+}
+
+func TestAddrLookupOrder(t *testing.T) {
+ // This test is written for a system with cgo available,
+ // without using the netgo tag.
+ if netGoBuildTag {
+ t.Skip("skipping test because net package built with netgo tag")
+ }
+ if !cgoAvailable {
+ t.Skip("skipping test because cgo resolver not available")
+ }
+
+ defer setSystemNSS(getSystemNSS(), 0)
+ c, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.teardown()
+
+ if !c.forceUpdateConf(defaultResolvConf, time.Now().Add(time.Hour)) {
+ t.Fatal("failed to change resolv config")
+ }
+
+ setSystemNSS(nssStr(t, "hosts: files myhostname dns"), time.Hour)
+ cnf := &conf{}
+ order, _ := cnf.addrLookupOrder(nil, "192.0.2.1")
+ if order != hostLookupCgo {
+ t.Errorf("addrLookupOrder returned: %v, want cgo", order)
+ }
+
+ setSystemNSS(nssStr(t, "hosts: files mdns4 dns"), time.Hour)
+ order, _ = cnf.addrLookupOrder(nil, "192.0.2.1")
+ if order != hostLookupCgo {
+ t.Errorf("addrLookupOrder returned: %v, want cgo", order)
+ }
+
+}
+
+func setSystemNSS(nss *nssConf, addDur time.Duration) {
+ nssConfig.mu.Lock()
+ nssConfig.nssConf = nss
+ nssConfig.mu.Unlock()
+ nssConfig.acquireSema()
+ nssConfig.lastChecked = time.Now().Add(addDur)
+ nssConfig.releaseSema()
+}
+
+func TestSystemConf(t *testing.T) {
+ systemConf()
+}
diff --git a/src/net/conn_test.go b/src/net/conn_test.go
new file mode 100644
index 0000000..4f391b0
--- /dev/null
+++ b/src/net/conn_test.go
@@ -0,0 +1,64 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "testing"
+ "time"
+)
+
+// someTimeout is used just to test that net.Conn implementations
+// don't explode when their SetFooDeadline methods are called.
+// It isn't actually used for testing timeouts.
+const someTimeout = 1 * time.Hour
+
+func TestConnAndListener(t *testing.T) {
+ for i, network := range []string{"tcp", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ t.Logf("skipping %s test", network)
+ continue
+ }
+
+ ls := newLocalServer(t, network)
+ defer ls.teardown()
+ ch := make(chan error, 1)
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ if ls.Listener.Addr().Network() != network {
+ t.Fatalf("got %s; want %s", ls.Listener.Addr().Network(), network)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
+ t.Fatalf("got %s->%s; want %s->%s", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
+ }
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ if _, err := c.Write([]byte("CONN AND LISTENER TEST")); err != nil {
+ t.Fatal(err)
+ }
+ rb := make([]byte, 128)
+ if _, err := c.Read(rb); err != nil {
+ t.Fatal(err)
+ }
+
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
diff --git a/src/net/dial.go b/src/net/dial.go
new file mode 100644
index 0000000..79bc495
--- /dev/null
+++ b/src/net/dial.go
@@ -0,0 +1,837 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/godebug"
+ "internal/nettrace"
+ "syscall"
+ "time"
+)
+
+const (
+ // defaultTCPKeepAlive is a default constant value for TCPKeepAlive times
+ // See go.dev/issue/31510
+ defaultTCPKeepAlive = 15 * time.Second
+
+ // For the moment, MultiPath TCP is not used by default
+ // See go.dev/issue/56539
+ defaultMPTCPEnabled = false
+)
+
+var multipathtcp = godebug.New("multipathtcp")
+
+// mptcpStatus is a tristate for Multipath TCP, see go.dev/issue/56539
+type mptcpStatus uint8
+
+const (
+ // The value 0 is the system default, linked to defaultMPTCPEnabled
+ mptcpUseDefault mptcpStatus = iota
+ mptcpEnabled
+ mptcpDisabled
+)
+
+func (m *mptcpStatus) get() bool {
+ switch *m {
+ case mptcpEnabled:
+ return true
+ case mptcpDisabled:
+ return false
+ }
+
+ // If MPTCP is forced via GODEBUG=multipathtcp=1
+ if multipathtcp.Value() == "1" {
+ multipathtcp.IncNonDefault()
+
+ return true
+ }
+
+ return defaultMPTCPEnabled
+}
+
+func (m *mptcpStatus) set(use bool) {
+ if use {
+ *m = mptcpEnabled
+ } else {
+ *m = mptcpDisabled
+ }
+}
+
+// A Dialer contains options for connecting to an address.
+//
+// The zero value for each field is equivalent to dialing
+// without that option. Dialing with the zero value of Dialer
+// is therefore equivalent to just calling the Dial function.
+//
+// It is safe to call Dialer's methods concurrently.
+type Dialer struct {
+ // Timeout is the maximum amount of time a dial will wait for
+ // a connect to complete. If Deadline is also set, it may fail
+ // earlier.
+ //
+ // The default is no timeout.
+ //
+ // When using TCP and dialing a host name with multiple IP
+ // addresses, the timeout may be divided between them.
+ //
+ // With or without a timeout, the operating system may impose
+ // its own earlier timeout. For instance, TCP timeouts are
+ // often around 3 minutes.
+ Timeout time.Duration
+
+ // Deadline is the absolute point in time after which dials
+ // will fail. If Timeout is set, it may fail earlier.
+ // Zero means no deadline, or dependent on the operating system
+ // as with the Timeout option.
+ Deadline time.Time
+
+ // LocalAddr is the local address to use when dialing an
+ // address. The address must be of a compatible type for the
+ // network being dialed.
+ // If nil, a local address is automatically chosen.
+ LocalAddr Addr
+
+ // DualStack previously enabled RFC 6555 Fast Fallback
+ // support, also known as "Happy Eyeballs", in which IPv4 is
+ // tried soon if IPv6 appears to be misconfigured and
+ // hanging.
+ //
+ // Deprecated: Fast Fallback is enabled by default. To
+ // disable, set FallbackDelay to a negative value.
+ DualStack bool
+
+ // FallbackDelay specifies the length of time to wait before
+ // spawning a RFC 6555 Fast Fallback connection. That is, this
+ // is the amount of time to wait for IPv6 to succeed before
+ // assuming that IPv6 is misconfigured and falling back to
+ // IPv4.
+ //
+ // If zero, a default delay of 300ms is used.
+ // A negative value disables Fast Fallback support.
+ FallbackDelay time.Duration
+
+ // KeepAlive specifies the interval between keep-alive
+ // probes for an active network connection.
+ // If zero, keep-alive probes are sent with a default value
+ // (currently 15 seconds), if supported by the protocol and operating
+ // system. Network protocols or operating systems that do
+ // not support keep-alives ignore this field.
+ // If negative, keep-alive probes are disabled.
+ KeepAlive time.Duration
+
+ // Resolver optionally specifies an alternate resolver to use.
+ Resolver *Resolver
+
+ // Cancel is an optional channel whose closure indicates that
+ // the dial should be canceled. Not all types of dials support
+ // cancellation.
+ //
+ // Deprecated: Use DialContext instead.
+ Cancel <-chan struct{}
+
+ // If Control is not nil, it is called after creating the network
+ // connection but before actually dialing.
+ //
+ // Network and address parameters passed to Control function are not
+ // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+ // will cause the Control function to be called with "tcp4" or "tcp6".
+ //
+ // Control is ignored if ControlContext is not nil.
+ Control func(network, address string, c syscall.RawConn) error
+
+ // If ControlContext is not nil, it is called after creating the network
+ // connection but before actually dialing.
+ //
+ // Network and address parameters passed to ControlContext function are not
+ // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+ // will cause the ControlContext function to be called with "tcp4" or "tcp6".
+ //
+ // If ControlContext is not nil, Control is ignored.
+ ControlContext func(ctx context.Context, network, address string, c syscall.RawConn) error
+
+ // If mptcpStatus is set to a value allowing Multipath TCP (MPTCP) to be
+ // used, any call to Dial with "tcp(4|6)" as network will use MPTCP if
+ // supported by the operating system.
+ mptcpStatus mptcpStatus
+}
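A hedged usage sketch of the exported fields above; the target address is a placeholder and the values merely restate the documented defaults:

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	d := &net.Dialer{
		Timeout:       30 * time.Second,       // overall connect budget
		KeepAlive:     15 * time.Second,       // keep-alive probe interval
		FallbackDelay: 300 * time.Millisecond, // explicit Happy Eyeballs delay (the default)
	}
	c, err := d.Dial("tcp", "example.org:443")
	if err != nil {
		fmt.Println("dial:", err)
		return
	}
	defer c.Close()
	fmt.Println("local", c.LocalAddr(), "remote", c.RemoteAddr())
}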
+
+func (d *Dialer) dualStack() bool { return d.FallbackDelay >= 0 }
+
+func minNonzeroTime(a, b time.Time) time.Time {
+ if a.IsZero() {
+ return b
+ }
+ if b.IsZero() || a.Before(b) {
+ return a
+ }
+ return b
+}
+
+// deadline returns the earliest of:
+// - now+Timeout
+// - d.Deadline
+// - the context's deadline
+//
+// Or zero, if none of Timeout, Deadline, or context's deadline is set.
+func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
+ if d.Timeout != 0 { // including negative, for historical reasons
+ earliest = now.Add(d.Timeout)
+ }
+ if d, ok := ctx.Deadline(); ok {
+ earliest = minNonzeroTime(earliest, d)
+ }
+ return minNonzeroTime(earliest, d.Deadline)
+}
+
+func (d *Dialer) resolver() *Resolver {
+ if d.Resolver != nil {
+ return d.Resolver
+ }
+ return DefaultResolver
+}
+
+// partialDeadline returns the deadline to use for a single address,
+// when multiple addresses are pending.
+func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
+ if deadline.IsZero() {
+ return deadline, nil
+ }
+ timeRemaining := deadline.Sub(now)
+ if timeRemaining <= 0 {
+ return time.Time{}, errTimeout
+ }
+ // Tentatively allocate equal time to each remaining address.
+ timeout := timeRemaining / time.Duration(addrsRemaining)
+ // If the time per address is too short, steal from the end of the list.
+ const saneMinimum = 2 * time.Second
+ if timeout < saneMinimum {
+ if timeRemaining < saneMinimum {
+ timeout = timeRemaining
+ } else {
+ timeout = saneMinimum
+ }
+ }
+ return now.Add(timeout), nil
+}
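A worked example of the arithmetic above: with 9s remaining and 3 addresses left, each address gets a 3s sub-deadline. With only 3s remaining and 3 addresses, the even split (1s) falls below the 2s floor, so the first address is granted the full 2s and later addresses make do with what remains; once the remaining time itself drops under 2s, the next address simply gets everything that is left.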
+
+func (d *Dialer) fallbackDelay() time.Duration {
+ if d.FallbackDelay > 0 {
+ return d.FallbackDelay
+ } else {
+ return 300 * time.Millisecond
+ }
+}
+
+func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) {
+ i := last(network, ':')
+ if i < 0 { // no colon
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ case "udp", "udp4", "udp6":
+ case "ip", "ip4", "ip6":
+ if needsProto {
+ return "", 0, UnknownNetworkError(network)
+ }
+ case "unix", "unixgram", "unixpacket":
+ default:
+ return "", 0, UnknownNetworkError(network)
+ }
+ return network, 0, nil
+ }
+ afnet = network[:i]
+ switch afnet {
+ case "ip", "ip4", "ip6":
+ protostr := network[i+1:]
+ proto, i, ok := dtoi(protostr)
+ if !ok || i != len(protostr) {
+ proto, err = lookupProtocol(ctx, protostr)
+ if err != nil {
+ return "", 0, err
+ }
+ }
+ return afnet, proto, nil
+ }
+ return "", 0, UnknownNetworkError(network)
+}
+
+// resolveAddrList resolves addr using hint and returns a list of
+// addresses. The result contains at least one address when error is
+// nil.
+func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
+ afnet, _, err := parseNetwork(ctx, network, true)
+ if err != nil {
+ return nil, err
+ }
+ if op == "dial" && addr == "" {
+ return nil, errMissingAddress
+ }
+ switch afnet {
+ case "unix", "unixgram", "unixpacket":
+ addr, err := ResolveUnixAddr(afnet, addr)
+ if err != nil {
+ return nil, err
+ }
+ if op == "dial" && hint != nil && addr.Network() != hint.Network() {
+ return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+ }
+ return addrList{addr}, nil
+ }
+ addrs, err := r.internetAddrList(ctx, afnet, addr)
+ if err != nil || op != "dial" || hint == nil {
+ return addrs, err
+ }
+ var (
+ tcp *TCPAddr
+ udp *UDPAddr
+ ip *IPAddr
+ wildcard bool
+ )
+ switch hint := hint.(type) {
+ case *TCPAddr:
+ tcp = hint
+ wildcard = tcp.isWildcard()
+ case *UDPAddr:
+ udp = hint
+ wildcard = udp.isWildcard()
+ case *IPAddr:
+ ip = hint
+ wildcard = ip.isWildcard()
+ }
+ naddrs := addrs[:0]
+ for _, addr := range addrs {
+ if addr.Network() != hint.Network() {
+ return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ case *UDPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ case *IPAddr:
+ if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
+ continue
+ }
+ naddrs = append(naddrs, addr)
+ }
+ }
+ if len(naddrs) == 0 {
+ return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
+ }
+ return naddrs, nil
+}
+
+// MultipathTCP reports whether MPTCP will be used.
+//
+// This method doesn't check whether MPTCP is supported by the
+// operating system.
+func (d *Dialer) MultipathTCP() bool {
+ return d.mptcpStatus.get()
+}
+
+// SetMultipathTCP directs the Dial methods to use, or not use, MPTCP,
+// if supported by the operating system. This method overrides the
+// system default and the GODEBUG=multipathtcp=... setting if any.
+//
+// If MPTCP is not available on the host or not supported by the server,
+// the Dial methods will fall back to TCP.
+func (d *Dialer) SetMultipathTCP(use bool) {
+ d.mptcpStatus.set(use)
+}
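A hedged usage sketch of the two methods above; example.org is a placeholder, and TCPConn.MultipathTCP (which reports whether MPTCP is actually in use on the established connection) is assumed to be available as in current releases:

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func main() {
	d := &net.Dialer{Timeout: 5 * time.Second}
	d.SetMultipathTCP(true)
	fmt.Println("MPTCP requested:", d.MultipathTCP())

	c, err := d.DialContext(context.Background(), "tcp", "example.org:80")
	if err != nil {
		fmt.Println("dial:", err)
		return
	}
	defer c.Close()
	if tc, ok := c.(*net.TCPConn); ok {
		used, _ := tc.MultipathTCP()
		fmt.Println("MPTCP in use:", used)
	}
}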
+
+// Dial connects to the address on the named network.
+//
+// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
+// "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
+// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and
+// "unixpacket".
+//
+// For TCP and UDP networks, the address has the form "host:port".
+// The host must be a literal IP address, or a host name that can be
+// resolved to IP addresses.
+// The port must be a literal port number or a service name.
+// If the host is a literal IPv6 address it must be enclosed in square
+// brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80".
+// The zone specifies the scope of the literal IPv6 address as defined
+// in RFC 4007.
+// The functions JoinHostPort and SplitHostPort manipulate a pair of
+// host and port in this form.
+// When using TCP, and the host resolves to multiple IP addresses,
+// Dial will try each IP address in order until one succeeds.
+//
+// Examples:
+//
+// Dial("tcp", "golang.org:http")
+// Dial("tcp", "192.0.2.1:http")
+// Dial("tcp", "198.51.100.1:80")
+// Dial("udp", "[2001:db8::1]:domain")
+// Dial("udp", "[fe80::1%lo0]:53")
+// Dial("tcp", ":80")
+//
+// For IP networks, the network must be "ip", "ip4" or "ip6" followed
+// by a colon and a literal protocol number or a protocol name, and
+// the address has the form "host". The host must be a literal IP
+// address or a literal IPv6 address with zone.
+// It depends on each operating system how the operating system
+// behaves with a non-well known protocol number such as "0" or "255".
+//
+// Examples:
+//
+// Dial("ip4:1", "192.0.2.1")
+// Dial("ip6:ipv6-icmp", "2001:db8::1")
+// Dial("ip6:58", "fe80::1%lo0")
+//
+// For TCP, UDP and IP networks, if the host is empty or a literal
+// unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for
+// TCP and UDP, "", "0.0.0.0" or "::" for IP, the local system is
+// assumed.
+//
+// For Unix networks, the address must be a file system path.
+func Dial(network, address string) (Conn, error) {
+ var d Dialer
+ return d.Dial(network, address)
+}
+
+// DialTimeout acts like Dial but takes a timeout.
+//
+// The timeout includes name resolution, if required.
+// When using TCP, and the host in the address parameter resolves to
+// multiple IP addresses, the timeout is spread over each consecutive
+// dial, such that each is given an appropriate fraction of the time
+// to connect.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
+ d := Dialer{Timeout: timeout}
+ return d.Dial(network, address)
+}
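A short usage sketch: the 3-second budget below covers name resolution and, for a multi-homed host, is split across the candidate addresses as described; example.org is a placeholder.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	c, err := net.DialTimeout("tcp", "example.org:80", 3*time.Second)
	if err != nil {
		fmt.Println("dial:", err)
		return
	}
	defer c.Close()
	fmt.Fprint(c, "HEAD / HTTP/1.0\r\nHost: example.org\r\n\r\n")
	buf := make([]byte, 256)
	n, _ := c.Read(buf)
	fmt.Printf("%s", buf[:n])
}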
+
+// sysDialer contains a Dial's parameters and configuration.
+type sysDialer struct {
+ Dialer
+ network, address string
+ testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
+}
+
+// Dial connects to the address on the named network.
+//
+// See func Dial for a description of the network and address
+// parameters.
+//
+// Dial uses context.Background internally; to specify the context, use
+// DialContext.
+func (d *Dialer) Dial(network, address string) (Conn, error) {
+ return d.DialContext(context.Background(), network, address)
+}
+
+// DialContext connects to the address on the named network using
+// the provided context.
+//
+// The provided Context must be non-nil. If the context expires before
+// the connection is complete, an error is returned. Once successfully
+// connected, any expiration of the context will not affect the
+// connection.
+//
+// When using TCP, and the host in the address parameter resolves to multiple
+// network addresses, any dial timeout (from d.Timeout or ctx) is spread
+// over each consecutive dial, such that each is given an appropriate
+// fraction of the time to connect.
+// For example, if a host has 4 IP addresses and the timeout is 1 minute,
+// the connect to each single address will be given 15 seconds to complete
+// before trying the next one.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ deadline := d.deadline(ctx, time.Now())
+ if !deadline.IsZero() {
+ if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
+ subCtx, cancel := context.WithDeadline(ctx, deadline)
+ defer cancel()
+ ctx = subCtx
+ }
+ }
+ if oldCancel := d.Cancel; oldCancel != nil {
+ subCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ select {
+ case <-oldCancel:
+ cancel()
+ case <-subCtx.Done():
+ }
+ }()
+ ctx = subCtx
+ }
+
+ // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
+ resolveCtx := ctx
+ if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
+ shadow := *trace
+ shadow.ConnectStart = nil
+ shadow.ConnectDone = nil
+ resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
+ }
+
+ addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+
+ sd := &sysDialer{
+ Dialer: *d,
+ network: network,
+ address: address,
+ }
+
+ var primaries, fallbacks addrList
+ if d.dualStack() && network == "tcp" {
+ primaries, fallbacks = addrs.partition(isIPv4)
+ } else {
+ primaries = addrs
+ }
+
+ return sd.dialParallel(ctx, primaries, fallbacks)
+}
+
+// dialParallel races two copies of dialSerial, giving the first a
+// head start. It returns the first established connection and
+// closes the others. Otherwise it returns an error from the first
+// primary address.
+func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
+ if len(fallbacks) == 0 {
+ return sd.dialSerial(ctx, primaries)
+ }
+
+ returned := make(chan struct{})
+ defer close(returned)
+
+ type dialResult struct {
+ Conn
+ error
+ primary bool
+ done bool
+ }
+ results := make(chan dialResult) // unbuffered
+
+ startRacer := func(ctx context.Context, primary bool) {
+ ras := primaries
+ if !primary {
+ ras = fallbacks
+ }
+ c, err := sd.dialSerial(ctx, ras)
+ select {
+ case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
+ case <-returned:
+ if c != nil {
+ c.Close()
+ }
+ }
+ }
+
+ var primary, fallback dialResult
+
+ // Start the main racer.
+ primaryCtx, primaryCancel := context.WithCancel(ctx)
+ defer primaryCancel()
+ go startRacer(primaryCtx, true)
+
+ // Start the timer for the fallback racer.
+ fallbackTimer := time.NewTimer(sd.fallbackDelay())
+ defer fallbackTimer.Stop()
+
+ for {
+ select {
+ case <-fallbackTimer.C:
+ fallbackCtx, fallbackCancel := context.WithCancel(ctx)
+ defer fallbackCancel()
+ go startRacer(fallbackCtx, false)
+
+ case res := <-results:
+ if res.error == nil {
+ return res.Conn, nil
+ }
+ if res.primary {
+ primary = res
+ } else {
+ fallback = res
+ }
+ if primary.done && fallback.done {
+ return nil, primary.error
+ }
+ if res.primary && fallbackTimer.Stop() {
+ // If we were able to stop the timer, that means it
+ // was running (hadn't yet started the fallback), but
+ // we just got an error on the primary path, so start
+ // the fallback immediately (in 0 nanoseconds).
+ fallbackTimer.Reset(0)
+ }
+ }
+ }
+}
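The fallback racer above is what Dialer.FallbackDelay tunes from user code. A small sketch of the knobs (values are illustrative, not recommendations):

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	eager := &net.Dialer{FallbackDelay: 50 * time.Millisecond} // start the fallback racer after 50ms
	off := &net.Dialer{FallbackDelay: -1}                      // negative delay: serial dialing, no race
	fmt.Println(eager.FallbackDelay, off.FallbackDelay)
}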
+
+// dialSerial connects to a list of addresses in sequence, returning
+// either the first successful connection, or the first error.
+func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
+ var firstErr error // The error from the first address is most relevant.
+
+ for i, ra := range ras {
+ select {
+ case <-ctx.Done():
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
+ default:
+ }
+
+ dialCtx := ctx
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
+ if err != nil {
+ // Ran out of time.
+ if firstErr == nil {
+ firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
+ }
+ break
+ }
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
+ }
+ }
+
+ c, err := sd.dialSingle(dialCtx, ra)
+ if err == nil {
+ return c, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+
+ if firstErr == nil {
+ firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
+ }
+ return nil, firstErr
+}
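
dialSerial budgets the overall deadline by splitting the remaining time evenly across the remaining addresses, with a floor that the tests later in this patch call the 2-second sane minimum. partialDeadline is unexported, so the following is a standalone re-implementation of that division rule, for illustration only:

package main

import (
	"fmt"
	"time"
)

// partialBudget mirrors the deadline-splitting rule used by dialSerial:
// divide the time remaining until the overall deadline evenly across the
// remaining addresses, but never hand out less than a 2-second slice
// (unless less than that remains in total). Illustrative re-implementation,
// not the package's own partialDeadline.
func partialBudget(remaining time.Duration, addrsRemaining int) time.Duration {
	timeout := remaining / time.Duration(addrsRemaining)
	const saneMinimum = 2 * time.Second
	if timeout < saneMinimum {
		if remaining < saneMinimum {
			timeout = remaining
		} else {
			timeout = saneMinimum
		}
	}
	return timeout
}

func main() {
	fmt.Println(partialBudget(12*time.Second, 3))         // 4s per address
	fmt.Println(partialBudget(12*time.Second, 999))       // clamped to the 2s minimum
	fmt.Println(partialBudget(1500*time.Millisecond, 10)) // only 1.5s left in total
}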
+
+// dialSingle attempts to establish and returns a single connection to
+// the destination address.
+func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
+ trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
+ if trace != nil {
+ raStr := ra.String()
+ if trace.ConnectStart != nil {
+ trace.ConnectStart(sd.network, raStr)
+ }
+ if trace.ConnectDone != nil {
+ defer func() { trace.ConnectDone(sd.network, raStr, err) }()
+ }
+ }
+ la := sd.LocalAddr
+ switch ra := ra.(type) {
+ case *TCPAddr:
+ la, _ := la.(*TCPAddr)
+ if sd.MultipathTCP() {
+ c, err = sd.dialMPTCP(ctx, la, ra)
+ } else {
+ c, err = sd.dialTCP(ctx, la, ra)
+ }
+ case *UDPAddr:
+ la, _ := la.(*UDPAddr)
+ c, err = sd.dialUDP(ctx, la, ra)
+ case *IPAddr:
+ la, _ := la.(*IPAddr)
+ c, err = sd.dialIP(ctx, la, ra)
+ case *UnixAddr:
+ la, _ := la.(*UnixAddr)
+ c, err = sd.dialUnix(ctx, la, ra)
+ default:
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
+ }
+ return c, nil
+}
+
+// ListenConfig contains options for listening to an address.
+type ListenConfig struct {
+ // If Control is not nil, it is called after creating the network
+ // connection but before binding it to the operating system.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Listen. For example, passing "tcp" to
+ // Listen will cause the Control function to be called with "tcp4" or "tcp6".
+ Control func(network, address string, c syscall.RawConn) error
+
+ // KeepAlive specifies the keep-alive period for network
+ // connections accepted by this listener.
+ // If zero, keep-alives are enabled if supported by the protocol
+ // and operating system. Network protocols or operating systems
+ // that do not support keep-alives ignore this field.
+ // If negative, keep-alives are disabled.
+ KeepAlive time.Duration
+
+ // If mptcpStatus is set to a value allowing Multipath TCP (MPTCP) to be
+ // used, any call to Listen with "tcp(4|6)" as network will use MPTCP if
+ // supported by the operating system.
+ mptcpStatus mptcpStatus
+}
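
The Control hook runs after the socket is created but before it is bound. A minimal, portable sketch that only inspects the resolved network and address; a real Control would typically set socket options through the RawConn:

package main

import (
	"context"
	"log"
	"net"
	"syscall"
)

func main() {
	lc := net.ListenConfig{
		// Called with the resolved network ("tcp4" or "tcp6") and address,
		// after the socket exists but before it is bound.
		Control: func(network, address string, c syscall.RawConn) error {
			log.Printf("about to bind %s %s", network, address)
			return nil
		},
	}
	ln, err := lc.Listen(context.Background(), "tcp", "localhost:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	log.Println("listening on", ln.Addr())
}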
+
+// MultipathTCP reports whether MPTCP will be used.
+//
+// This method doesn't check whether MPTCP is supported by the
+// operating system.
+func (lc *ListenConfig) MultipathTCP() bool {
+ return lc.mptcpStatus.get()
+}
+
+// SetMultipathTCP directs the Listen method to use, or not use, MPTCP,
+// if supported by the operating system. This method overrides the
+// system default and the GODEBUG=multipathtcp=... setting if any.
+//
+// If MPTCP is not available on the host or not supported by the client,
+// the Listen method will fall back to TCP.
+func (lc *ListenConfig) SetMultipathTCP(use bool) {
+ lc.mptcpStatus.set(use)
+}
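
A sketch of opting a listener in to MPTCP through ListenConfig; on hosts without MPTCP support it quietly falls back to plain TCP, as documented above:

package main

import (
	"context"
	"log"
	"net"
)

func main() {
	var lc net.ListenConfig
	lc.SetMultipathTCP(true) // request MPTCP; falls back to TCP if unavailable

	ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	log.Println("listening on", ln.Addr())
}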
+
+// Listen announces on the local network address.
+//
+// See func Listen for a description of the network and address
+// parameters.
+func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var l Listener
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *TCPAddr:
+ if sl.MultipathTCP() {
+ l, err = sl.listenMPTCP(ctx, la)
+ } else {
+ l, err = sl.listenTCP(ctx, la)
+ }
+ case *UnixAddr:
+ l, err = sl.listenUnix(ctx, la)
+ default:
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
+ }
+ return l, nil
+}
+
+// ListenPacket announces on the local network address.
+//
+// See func ListenPacket for a description of the network and address
+// parameters.
+func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var c PacketConn
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *UDPAddr:
+ c, err = sl.listenUDP(ctx, la)
+ case *IPAddr:
+ c, err = sl.listenIP(ctx, la)
+ case *UnixAddr:
+ c, err = sl.listenUnixgram(ctx, la)
+ default:
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
+ }
+ return c, nil
+}
+
+// sysListener contains a Listen's parameters and configuration.
+type sysListener struct {
+ ListenConfig
+ network, address string
+}
+
+// Listen announces on the local network address.
+//
+// The network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
+//
+// For TCP networks, if the host in the address parameter is empty or
+// a literal unspecified IP address, Listen listens on all available
+// unicast and anycast IP addresses of the local system.
+// To only use IPv4, use network "tcp4".
+// The address can use a host name, but this is not recommended,
+// because it will create a listener for at most one of the host's IP
+// addresses.
+// If the port in the address parameter is empty or "0", as in
+// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
+// The Addr method of Listener can be used to discover the chosen
+// port.
+//
+// See func Dial for a description of the network and address
+// parameters.
+//
+// Listen uses context.Background internally; to specify the context, use
+// ListenConfig.Listen.
+func Listen(network, address string) (Listener, error) {
+ var lc ListenConfig
+ return lc.Listen(context.Background(), network, address)
+}
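
As the doc comment notes, a port of "0" lets the system choose a free port, which the caller recovers from Addr. A minimal sketch:

package main

import (
	"log"
	"net"
)

func main() {
	// Port "0" asks the operating system to pick a free port.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	// The chosen port is available from the listener's address.
	log.Println("listening on", ln.Addr())
}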
+
+// ListenPacket announces on the local network address.
+//
+// The network must be "udp", "udp4", "udp6", "unixgram", or an IP
+// transport. The IP transports are "ip", "ip4", or "ip6" followed by
+// a colon and a literal protocol number or a protocol name, as in
+// "ip:1" or "ip:icmp".
+//
+// For UDP and IP networks, if the host in the address parameter is
+// empty or a literal unspecified IP address, ListenPacket listens on
+// all available IP addresses of the local system except multicast IP
+// addresses.
+// To only use IPv4, use network "udp4" or "ip4:proto".
+// The address can use a host name, but this is not recommended,
+// because it will create a listener for at most one of the host's IP
+// addresses.
+// If the port in the address parameter is empty or "0", as in
+// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen.
+// The LocalAddr method of PacketConn can be used to discover the
+// chosen port.
+//
+// See func Dial for a description of the network and address
+// parameters.
+//
+// ListenPacket uses context.Background internally; to specify the context, use
+// ListenConfig.ListenPacket.
+func ListenPacket(network, address string) (PacketConn, error) {
+ var lc ListenConfig
+ return lc.ListenPacket(context.Background(), network, address)
+}
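
A minimal UDP sketch for ListenPacket as described above; it reads one datagram and echoes it back to the sender:

package main

import (
	"log"
	"net"
)

func main() {
	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer pc.Close()
	log.Println("listening on", pc.LocalAddr())

	buf := make([]byte, 1024)
	n, addr, err := pc.ReadFrom(buf) // blocks until a datagram arrives
	if err != nil {
		log.Fatal(err)
	}
	if _, err := pc.WriteTo(buf[:n], addr); err != nil { // echo it back
		log.Fatal(err)
	}
}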
diff --git a/src/net/dial_test.go b/src/net/dial_test.go
new file mode 100644
index 0000000..ca9f0da
--- /dev/null
+++ b/src/net/dial_test.go
@@ -0,0 +1,1088 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "os"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+)
+
+var prohibitionaryDialArgTests = []struct {
+ network string
+ address string
+}{
+ {"tcp6", "127.0.0.1"},
+ {"tcp6", "::ffff:127.0.0.1"},
+}
+
+func TestProhibitionaryDialArg(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv4map() {
+ t.Skip("mapping ipv4 address inside ipv6 address not supported")
+ }
+
+ ln, err := Listen("tcp", "[::]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for i, tt := range prohibitionaryDialArgTests {
+ c, err := Dial(tt.network, JoinHostPort(tt.address, port))
+ if err == nil {
+ c.Close()
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
+
+func TestDialLocal(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := Dial("tcp", JoinHostPort("", port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+}
+
+func TestDialerDualStackFDLeak(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ case "windows":
+ t.Skipf("not implemented a way to cancel dial racers in TCP SYN-SENT state on %s", runtime.GOOS)
+ case "openbsd":
+ testenv.SkipFlaky(t, 15157)
+ }
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ before := sw.Sockets()
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+ handler := func(dss *dualStackServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }
+ dss, err := newDualStackServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := dss.buildup(handler); err != nil {
+ dss.teardown()
+ t.Fatal(err)
+ }
+
+ const N = 10
+ var wg sync.WaitGroup
+ wg.Add(N)
+ d := &Dialer{DualStack: true, Timeout: 5 * time.Second}
+ for i := 0; i < N; i++ {
+ go func() {
+ defer wg.Done()
+ c, err := d.Dial("tcp", JoinHostPort("localhost", dss.port))
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ c.Close()
+ }()
+ }
+ wg.Wait()
+ dss.teardown()
+ after := sw.Sockets()
+ if len(after) != len(before) {
+ t.Errorf("got %d; want %d", len(after), len(before))
+ }
+}
+
+// Define a pair of blackholed (IPv4, IPv6) addresses, for which dialTCP is
+// expected to hang until the timeout elapses. These addresses are reserved
+// for benchmarking by RFC 6890.
+const (
+ slowDst4 = "198.18.0.254"
+ slowDst6 = "2001:2::254"
+)
+
+// In some environments, the slow IPs may be explicitly unreachable, and fail
+// more quickly than expected. This test hook prevents dialTCP from returning
+// before the deadline.
+func slowDialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.doDialTCP(ctx, laddr, raddr)
+ if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
+ // Wait for the deadline, or indefinitely if none exists.
+ <-ctx.Done()
+ }
+ return c, err
+}
+
+func dialClosedPort(t *testing.T) (dialLatency time.Duration) {
+ // On most platforms, dialing a closed port should be nearly instantaneous —
+ // less than a few hundred milliseconds. However, on some platforms it may be
+ // much slower: on Windows and OpenBSD, it has been observed to take up to a
+ // few seconds.
+
+ l, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("dialClosedPort: Listen failed: %v", err)
+ }
+ addr := l.Addr().String()
+ l.Close()
+
+ startTime := time.Now()
+ c, err := Dial("tcp", addr)
+ if err == nil {
+ c.Close()
+ }
+ elapsed := time.Since(startTime)
+ t.Logf("dialClosedPort: measured delay %v", elapsed)
+ return elapsed
+}
+
+func TestDialParallel(t *testing.T) {
+ const instant time.Duration = 0
+ const fallbackDelay = 200 * time.Millisecond
+
+ nCopies := func(s string, n int) []string {
+ out := make([]string, n)
+ for i := 0; i < n; i++ {
+ out[i] = s
+ }
+ return out
+ }
+
+ var testCases = []struct {
+ primaries []string
+ fallbacks []string
+ teardownNetwork string
+ expectOk bool
+ expectElapsed time.Duration
+ }{
+ // These should just work on the first try.
+ {[]string{"127.0.0.1"}, []string{}, "", true, instant},
+ {[]string{"::1"}, []string{}, "", true, instant},
+ {[]string{"127.0.0.1", "::1"}, []string{slowDst6}, "tcp6", true, instant},
+ {[]string{"::1", "127.0.0.1"}, []string{slowDst4}, "tcp4", true, instant},
+ // Primary is slow; fallback should kick in.
+ {[]string{slowDst4}, []string{"::1"}, "", true, fallbackDelay},
+ // Skip a "connection refused" in the primary thread.
+ {[]string{"127.0.0.1", "::1"}, []string{}, "tcp4", true, instant},
+ {[]string{"::1", "127.0.0.1"}, []string{}, "tcp6", true, instant},
+ // Skip a "connection refused" in the fallback thread.
+ {[]string{slowDst4, slowDst6}, []string{"::1", "127.0.0.1"}, "tcp6", true, fallbackDelay},
+ // Primary refused, fallback without delay.
+ {[]string{"127.0.0.1"}, []string{"::1"}, "tcp4", true, instant},
+ {[]string{"::1"}, []string{"127.0.0.1"}, "tcp6", true, instant},
+ // Everything is refused.
+ {[]string{"127.0.0.1"}, []string{}, "tcp4", false, instant},
+ // Nothing to do; fail instantly.
+ {[]string{}, []string{}, "", false, instant},
+ // Connecting to tons of addresses should not trip the deadline.
+ {nCopies("::1", 1000), []string{}, "", true, instant},
+ }
+
+ // Convert a list of IP strings into TCPAddrs.
+ makeAddrs := func(ips []string, port string) addrList {
+ var out addrList
+ for _, ip := range ips {
+ addr, err := ResolveTCPAddr("tcp", JoinHostPort(ip, port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ out = append(out, addr)
+ }
+ return out
+ }
+
+ for i, tt := range testCases {
+ i, tt := i, tt
+ t.Run(fmt.Sprint(i), func(t *testing.T) {
+ dialTCP := func(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ n := "tcp6"
+ if raddr.IP.To4() != nil {
+ n = "tcp4"
+ }
+ if n == tt.teardownNetwork {
+ return nil, errors.New("unreachable")
+ }
+ if r := raddr.IP.String(); r == slowDst4 || r == slowDst6 {
+ <-ctx.Done()
+ return nil, ctx.Err()
+ }
+ return &TCPConn{}, nil
+ }
+
+ primaries := makeAddrs(tt.primaries, "80")
+ fallbacks := makeAddrs(tt.fallbacks, "80")
+ d := Dialer{
+ FallbackDelay: fallbackDelay,
+ }
+ const forever = 60 * time.Minute
+ if tt.expectElapsed == instant {
+ d.FallbackDelay = forever
+ }
+ startTime := time.Now()
+ sd := &sysDialer{
+ Dialer: d,
+ network: "tcp",
+ address: "?",
+ testHookDialTCP: dialTCP,
+ }
+ c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
+ elapsed := time.Since(startTime)
+
+ if c != nil {
+ c.Close()
+ }
+
+ if tt.expectOk && err != nil {
+ t.Errorf("#%d: got %v; want nil", i, err)
+ } else if !tt.expectOk && err == nil {
+ t.Errorf("#%d: got nil; want non-nil", i)
+ }
+
+ if elapsed < tt.expectElapsed || elapsed >= forever {
+ t.Errorf("#%d: got %v; want >= %v, < forever", i, elapsed, tt.expectElapsed)
+ }
+
+ // Repeat each case, ensuring that it can be canceled.
+ ctx, cancel := context.WithCancel(context.Background())
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ time.Sleep(5 * time.Millisecond)
+ cancel()
+ wg.Done()
+ }()
+ // Ignore errors, since all we care about is that the
+ // call can be canceled.
+ c, _ = sd.dialParallel(ctx, primaries, fallbacks)
+ if c != nil {
+ c.Close()
+ }
+ wg.Wait()
+ })
+ }
+}
+
+func lookupSlowFast(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ switch host {
+ case "slow6loopback4":
+ // Returns a slow IPv6 address, and a local IPv4 address.
+ return []IPAddr{
+ {IP: ParseIP(slowDst6)},
+ {IP: ParseIP("127.0.0.1")},
+ }, nil
+ default:
+ return fn(ctx, network, host)
+ }
+}
+
+func TestDialerFallbackDelay(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupSlowFast
+
+ origTestHookDialTCP := testHookDialTCP
+ defer func() { testHookDialTCP = origTestHookDialTCP }()
+ testHookDialTCP = slowDialTCP
+
+ var testCases = []struct {
+ dualstack bool
+ delay time.Duration
+ expectElapsed time.Duration
+ }{
+ // Use a very brief delay, which should fall back immediately.
+ {true, 1 * time.Nanosecond, 0},
+ // Use a 200ms explicit timeout.
+ {true, 200 * time.Millisecond, 200 * time.Millisecond},
+ // The default is 300ms.
+ {true, 0, 300 * time.Millisecond},
+ }
+
+ handler := func(dss *dualStackServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }
+ dss, err := newDualStackServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer dss.teardown()
+ if err := dss.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ for i, tt := range testCases {
+ d := &Dialer{DualStack: tt.dualstack, FallbackDelay: tt.delay}
+
+ startTime := time.Now()
+ c, err := d.Dial("tcp", JoinHostPort("slow6loopback4", dss.port))
+ elapsed := time.Since(startTime)
+ if err == nil {
+ c.Close()
+ } else if tt.dualstack {
+ t.Error(err)
+ }
+ expectMin := tt.expectElapsed - 1*time.Millisecond
+ expectMax := tt.expectElapsed + 95*time.Millisecond
+ if elapsed < expectMin {
+ t.Errorf("#%d: got %v; want >= %v", i, elapsed, expectMin)
+ }
+ if elapsed > expectMax {
+ t.Errorf("#%d: got %v; want <= %v", i, elapsed, expectMax)
+ }
+ }
+}
+
+func TestDialParallelSpuriousConnection(t *testing.T) {
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ var readDeadline time.Time
+ if td, ok := t.Deadline(); ok {
+ const arbitraryCleanupMargin = 1 * time.Second
+ readDeadline = td.Add(-arbitraryCleanupMargin)
+ } else {
+ readDeadline = time.Now().Add(5 * time.Second)
+ }
+
+ var closed sync.WaitGroup
+ closed.Add(2)
+ handler := func(dss *dualStackServer, ln Listener) {
+ // Accept one connection per address.
+ c, err := ln.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Workaround for https://go.dev/issue/37795.
+ // On arm64 macOS (current as of macOS 12.4),
+ // reading from a socket at the same time as the client
+ // is closing it occasionally hangs for 60 seconds before
+ // returning ECONNRESET. Sleep for a bit to give the
+ // socket time to close before trying to read from it.
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ // The client should close itself, without sending data.
+ c.SetReadDeadline(readDeadline)
+ var b [1]byte
+ if _, err := c.Read(b[:]); err != io.EOF {
+ t.Errorf("got %v; want %v", err, io.EOF)
+ }
+ c.Close()
+ closed.Done()
+ }
+ dss, err := newDualStackServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer dss.teardown()
+ if err := dss.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ const fallbackDelay = 100 * time.Millisecond
+
+ var dialing sync.WaitGroup
+ dialing.Add(2)
+ origTestHookDialTCP := testHookDialTCP
+ defer func() { testHookDialTCP = origTestHookDialTCP }()
+ testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ // Wait until Happy Eyeballs kicks in and both connections are dialing,
+ // and inhibit cancellation.
+ // This forces dialParallel to juggle two successful connections.
+ dialing.Done()
+ dialing.Wait()
+
+ // Now ignore the provided context (which will be canceled) and use a
+ // different one to make sure this completes with a valid connection,
+ // which we hope to be closed below:
+ sd := &sysDialer{network: net, address: raddr.String()}
+ return sd.doDialTCP(context.Background(), laddr, raddr)
+ }
+
+ d := Dialer{
+ FallbackDelay: fallbackDelay,
+ }
+ sd := &sysDialer{
+ Dialer: d,
+ network: "tcp",
+ address: "?",
+ }
+
+ makeAddr := func(ip string) addrList {
+ addr, err := ResolveTCPAddr("tcp", JoinHostPort(ip, dss.port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ return addrList{addr}
+ }
+
+ // dialParallel returns one connection (and closes the other.)
+ c, err := sd.dialParallel(context.Background(), makeAddr("127.0.0.1"), makeAddr("::1"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+
+ // The server should've seen both connections.
+ closed.Wait()
+}
+
+func TestDialerPartialDeadline(t *testing.T) {
+ now := time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
+ var testCases = []struct {
+ now time.Time
+ deadline time.Time
+ addrs int
+ expectDeadline time.Time
+ expectErr error
+ }{
+ // Regular division.
+ {now, now.Add(12 * time.Second), 1, now.Add(12 * time.Second), nil},
+ {now, now.Add(12 * time.Second), 2, now.Add(6 * time.Second), nil},
+ {now, now.Add(12 * time.Second), 3, now.Add(4 * time.Second), nil},
+ // Bump against the 2-second sane minimum.
+ {now, now.Add(12 * time.Second), 999, now.Add(2 * time.Second), nil},
+ // Total available is now below the sane minimum.
+ {now, now.Add(1900 * time.Millisecond), 999, now.Add(1900 * time.Millisecond), nil},
+ // Null deadline.
+ {now, noDeadline, 1, noDeadline, nil},
+ // Step the clock forward and cross the deadline.
+ {now.Add(-1 * time.Millisecond), now, 1, now, nil},
+ {now.Add(0 * time.Millisecond), now, 1, noDeadline, errTimeout},
+ {now.Add(1 * time.Millisecond), now, 1, noDeadline, errTimeout},
+ }
+ for i, tt := range testCases {
+ deadline, err := partialDeadline(tt.now, tt.deadline, tt.addrs)
+ if err != tt.expectErr {
+ t.Errorf("#%d: got %v; want %v", i, err, tt.expectErr)
+ }
+ if !deadline.Equal(tt.expectDeadline) {
+ t.Errorf("#%d: got %v; want %v", i, deadline, tt.expectDeadline)
+ }
+ }
+}
+
+// isEADDRINUSE reports whether err is syscall.EADDRINUSE.
+var isEADDRINUSE = func(err error) bool { return false }
+
+func TestDialerLocalAddr(t *testing.T) {
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ type test struct {
+ network, raddr string
+ laddr Addr
+ error
+ }
+ var tests = []test{
+ {"tcp4", "127.0.0.1", nil, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{}, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"}},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+ {"tcp4", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+ {"tcp4", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+ {"tcp4", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+ {"tcp6", "::1", nil, nil},
+ {"tcp6", "::1", &TCPAddr{}, nil},
+ {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+ {"tcp6", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+ {"tcp6", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+ {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+ {"tcp6", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+ {"tcp6", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+ {"tcp6", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+ {"tcp6", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+ {"tcp", "127.0.0.1", nil, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{}, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, nil},
+ {"tcp", "127.0.0.1", &TCPAddr{IP: IPv6loopback}, errNoSuitableAddress},
+ {"tcp", "127.0.0.1", &UDPAddr{}, &AddrError{Err: "some error"}},
+ {"tcp", "127.0.0.1", &UnixAddr{}, &AddrError{Err: "some error"}},
+
+ {"tcp", "::1", nil, nil},
+ {"tcp", "::1", &TCPAddr{}, nil},
+ {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0")}, nil},
+ {"tcp", "::1", &TCPAddr{IP: ParseIP("0.0.0.0").To4()}, nil},
+ {"tcp", "::1", &TCPAddr{IP: ParseIP("::")}, nil},
+ {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To4()}, errNoSuitableAddress},
+ {"tcp", "::1", &TCPAddr{IP: ParseIP("127.0.0.1").To16()}, errNoSuitableAddress},
+ {"tcp", "::1", &TCPAddr{IP: IPv6loopback}, nil},
+ {"tcp", "::1", &UDPAddr{}, &AddrError{Err: "some error"}},
+ {"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}},
+ }
+
+ issue34264Index := -1
+ if supportsIPv4map() {
+ issue34264Index = len(tests)
+ tests = append(tests, test{
+ "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil,
+ })
+ } else {
+ tests = append(tests, test{
+ "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, &AddrError{Err: "some error"},
+ })
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+ handler := func(ls *localServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }
+ var lss [2]*localServer
+ for i, network := range []string{"tcp4", "tcp6"} {
+ lss[i] = newLocalServer(t, network)
+ defer lss[i].teardown()
+ if err := lss[i].buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ for i, tt := range tests {
+ d := &Dialer{LocalAddr: tt.laddr}
+ var addr string
+ ip := ParseIP(tt.raddr)
+ if ip.To4() != nil {
+ addr = lss[0].Listener.Addr().String()
+ }
+ if ip.To16() != nil && ip.To4() == nil {
+ addr = lss[1].Listener.Addr().String()
+ }
+ c, err := d.Dial(tt.network, addr)
+ if err == nil && tt.error != nil || err != nil && tt.error == nil {
+ if i == issue34264Index && runtime.GOOS == "freebsd" && isEADDRINUSE(err) {
+ // https://golang.org/issue/34264: FreeBSD through at least version 12.2
+ // has been observed to fail with EADDRINUSE when dialing from an IPv6
+ // local address to an IPv4 remote address.
+ t.Logf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+ t.Logf("(spurious EADDRINUSE ignored on freebsd: see https://golang.org/issue/34264)")
+ } else {
+ t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error)
+ }
+ }
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ continue
+ }
+ c.Close()
+ }
+}
+
+func TestDialerDualStack(t *testing.T) {
+ testenv.SkipFlaky(t, 13324)
+
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ closedPortDelay := dialClosedPort(t)
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+ handler := func(dss *dualStackServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }
+
+ var timeout = 150*time.Millisecond + closedPortDelay
+ for _, dualstack := range []bool{false, true} {
+ dss, err := newDualStackServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer dss.teardown()
+ if err := dss.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ d := &Dialer{DualStack: dualstack, Timeout: timeout}
+ for range dss.lns {
+ c, err := d.Dial("tcp", JoinHostPort("localhost", dss.port))
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ switch addr := c.LocalAddr().(*TCPAddr); {
+ case addr.IP.To4() != nil:
+ dss.teardownNetwork("tcp4")
+ case addr.IP.To16() != nil && addr.IP.To4() == nil:
+ dss.teardownNetwork("tcp6")
+ }
+ c.Close()
+ }
+ }
+}
+
+func TestDialerKeepAlive(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
+
+ tests := []struct {
+ ka time.Duration
+ expected time.Duration
+ }{
+ {-1, -1},
+ {0, 15 * time.Second},
+ {5 * time.Second, 5 * time.Second},
+ {30 * time.Second, 30 * time.Second},
+ }
+
+ for _, test := range tests {
+ var got time.Duration = -1
+ testHookSetKeepAlive = func(d time.Duration) { got = d }
+ d := Dialer{KeepAlive: test.ka}
+ c, err := d.Dial("tcp", ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ if got != test.expected {
+ t.Errorf("Dialer.KeepAlive = %v: SetKeepAlive set to %v, want %v", d.KeepAlive, got, test.expected)
+ }
+ }
+}
+
+func TestDialCancel(t *testing.T) {
+ mustHaveExternalNetwork(t)
+
+ blackholeIPPort := JoinHostPort(slowDst4, "1234")
+ if !supportsIPv4() {
+ blackholeIPPort = JoinHostPort(slowDst6, "1234")
+ }
+
+ ticker := time.NewTicker(10 * time.Millisecond)
+ defer ticker.Stop()
+
+ const cancelTick = 5 // the timer tick we cancel the dial at
+ const timeoutTick = 100
+
+ var d Dialer
+ cancel := make(chan struct{})
+ d.Cancel = cancel
+ errc := make(chan error, 1)
+ connc := make(chan Conn, 1)
+ go func() {
+ if c, err := d.Dial("tcp", blackholeIPPort); err != nil {
+ errc <- err
+ } else {
+ connc <- c
+ }
+ }()
+ ticks := 0
+ for {
+ select {
+ case <-ticker.C:
+ ticks++
+ if ticks == cancelTick {
+ close(cancel)
+ }
+ if ticks == timeoutTick {
+ t.Fatal("timeout waiting for dial to fail")
+ }
+ case c := <-connc:
+ c.Close()
+ t.Fatal("unexpected successful connection")
+ case err := <-errc:
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ if ticks < cancelTick {
+ // Using strings.Contains is ugly but
+ // may work on plan9 and windows.
+ ignorable := []string{
+ "connection refused",
+ "unreachable",
+ "no route to host",
+ }
+ e := err.Error()
+ for _, ignore := range ignorable {
+ if strings.Contains(e, ignore) {
+ t.Skipf("connection to %v failed fast with %v", blackholeIPPort, err)
+ }
+ }
+
+ t.Fatalf("dial error after %d ticks (%d before cancel sent): %v",
+ ticks, cancelTick-ticks, err)
+ }
+ if oe, ok := err.(*OpError); !ok || oe.Err != errCanceled {
+ t.Fatalf("dial error = %v (%T); want OpError with Err == errCanceled", err, err)
+ }
+ return // success.
+ }
+ }
+}
+
+func TestCancelAfterDial(t *testing.T) {
+ if testing.Short() {
+ t.Skip("avoiding time.Sleep")
+ }
+
+ ln := newLocalListener(t, "tcp")
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ defer func() {
+ ln.Close()
+ wg.Wait()
+ }()
+
+ // Echo back the first line of each incoming connection.
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ rb := bufio.NewReader(c)
+ line, err := rb.ReadString('\n')
+ if err != nil {
+ t.Error(err)
+ c.Close()
+ continue
+ }
+ if _, err := c.Write([]byte(line)); err != nil {
+ t.Error(err)
+ }
+ c.Close()
+ }
+ wg.Done()
+ }()
+
+ try := func() {
+ cancel := make(chan struct{})
+ d := &Dialer{Cancel: cancel}
+ c, err := d.Dial("tcp", ln.Addr().String())
+
+ // Immediately after dialing, request cancellation and sleep.
+ // Before Issue 15078 was fixed, this would cause subsequent operations
+ // to fail with an i/o timeout roughly 50% of the time.
+ close(cancel)
+ time.Sleep(10 * time.Millisecond)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ // Send some data to confirm that the connection is still alive.
+ const message = "echo!\n"
+ if _, err := c.Write([]byte(message)); err != nil {
+ t.Fatal(err)
+ }
+
+ // The server should echo the line, and close the connection.
+ rb := bufio.NewReader(c)
+ line, err := rb.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ if line != message {
+ t.Errorf("got %q; want %q", line, message)
+ }
+ if _, err := rb.ReadByte(); err != io.EOF {
+ t.Errorf("got %v; want %v", err, io.EOF)
+ }
+ }
+
+ // This bug manifested about 50% of the time, so try it a few times.
+ for i := 0; i < 10; i++ {
+ try()
+ }
+}
+
+func TestDialClosedPortFailFast(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ // Reported by go.dev/issue/23366.
+ t.Skip("skipping windows only test")
+ }
+ for _, network := range []string{"tcp", "tcp4", "tcp6"} {
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("skipping: can't listen on %s", network)
+ }
+ // Reserve a local port till the end of the
+ // test by opening a listener and connecting to
+ // it using Dial.
+ ln := newLocalListener(t, network)
+ addr := ln.Addr().String()
+ conn1, err := Dial(network, addr)
+ if err != nil {
+ ln.Close()
+ t.Fatal(err)
+ }
+ defer conn1.Close()
+ // Now close the listener so the next Dial fails
+ // keeping conn1 alive so the port is not made
+ // available.
+ ln.Close()
+
+ maxElapsed := time.Second
+ // The host can be heavy-loaded and take
+ // longer than configured. Retry until
+ // Dial takes less than maxElapsed or
+ // the test times out.
+ for {
+ startTime := time.Now()
+ conn2, err := Dial(network, addr)
+ if err == nil {
+ conn2.Close()
+ t.Fatal("error expected")
+ }
+ elapsed := time.Since(startTime)
+ if elapsed < maxElapsed {
+ break
+ }
+ t.Logf("got %v; want < %v", elapsed, maxElapsed)
+ }
+ })
+ }
+}
+
+// Issue 18806: it should always be possible to net.Dial a
+// net.Listener().Addr().String when the listen address was ":n", even
+// if the machine has halfway configured IPv6 such that it can bind on
+// "::" not connect back to that same address.
+func TestDialListenerAddr(t *testing.T) {
+ if !testableNetwork("tcp4") {
+ t.Skipf("skipping: can't listen on tcp4")
+ }
+
+ // The original issue report was for listening on just ":0" on a system that
+ // supports both tcp4 and tcp6 for external traffic but only tcp4 for loopback
+ // traffic. However, the port opened by ":0" is externally-accessible, and may
+ // trigger firewall alerts or otherwise be mistaken for malicious activity
+ // (see https://go.dev/issue/59497). Moreover, it often does not reproduce
+ // the scenario in the issue, in which the port *cannot* be dialed as tcp6.
+ //
+ // To address both of those problems, we open a tcp4-only localhost port, but
+ // then dial the address string that the listener would have reported for a
+ // dual-stack port.
+ ln, err := Listen("tcp4", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ t.Logf("listening on %q", ln.Addr())
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If we had opened a dual-stack port without an explicit "localhost" address,
+ // the Listener would arbitrarily report an empty tcp6 address in its Addr
+ // string.
+ //
+ // The documentation for Dial says ‘if the host is empty or a literal
+ // unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for TCP and
+ // UDP, "", "0.0.0.0" or "::" for IP, the local system is assumed.’
+ // In #18806, it was decided that that should include the local tcp4 host
+ // even if the string is in the tcp6 format.
+ dialAddr := "[::]:" + port
+ c, err := Dial("tcp4", dialAddr)
+ if err != nil {
+ t.Fatalf(`Dial("tcp4", %q): %v`, dialAddr, err)
+ }
+ c.Close()
+ t.Logf(`Dial("tcp4", %q) succeeded`, dialAddr)
+}
+
+func TestDialerControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamDial", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln := newLocalListener(t, network)
+ defer ln.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c, err := d.Dial(network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c.Close()
+ }
+ })
+ t.Run("PacketDial", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c1 := newLocalPacketListener(t, network)
+ if network == "unixgram" {
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ defer c1.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c2, err := d.Dial(network, c1.LocalAddr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c2.Close()
+ }
+ })
+}
+
+func TestDialerControlContext(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+ t.Run("StreamDial", func(t *testing.T) {
+ for i, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln := newLocalListener(t, network)
+ defer ln.Close()
+ var id int
+ d := Dialer{ControlContext: func(ctx context.Context, network string, address string, c syscall.RawConn) error {
+ id = ctx.Value("id").(int)
+ return controlOnConnSetup(network, address, c)
+ }}
+ c, err := d.DialContext(context.WithValue(context.Background(), "id", i+1), network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if id != i+1 {
+ t.Errorf("got id %d, want %d", id, i+1)
+ }
+ c.Close()
+ }
+ })
+}
+
+// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
+// except that it won't skip testing on non-mobile builders.
+func mustHaveExternalNetwork(t *testing.T) {
+ t.Helper()
+ mobile := runtime.GOOS == "android" || runtime.GOOS == "ios"
+ if testenv.Builder() == "" || mobile {
+ testenv.MustHaveExternalNetwork(t)
+ }
+}
+
+type contextWithNonZeroDeadline struct {
+ context.Context
+}
+
+func (contextWithNonZeroDeadline) Deadline() (time.Time, bool) {
+ // Return non-zero time.Time value with false indicating that no deadline is set.
+ return time.Unix(0, 0), false
+}
+
+func TestDialWithNonZeroDeadline(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ctx := contextWithNonZeroDeadline{Context: context.Background()}
+ var dialer Dialer
+ c, err := dialer.DialContext(ctx, "tcp", JoinHostPort("", port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+}
diff --git a/src/net/dial_unix_test.go b/src/net/dial_unix_test.go
new file mode 100644
index 0000000..d0df0b7
--- /dev/null
+++ b/src/net/dial_unix_test.go
@@ -0,0 +1,113 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "context"
+ "errors"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func init() {
+ isEADDRINUSE = func(err error) bool {
+ return errors.Is(err, syscall.EADDRINUSE)
+ }
+}
+
+// Issue 16523
+func TestDialContextCancelRace(t *testing.T) {
+ oldConnectFunc := connectFunc
+ oldGetsockoptIntFunc := getsockoptIntFunc
+ oldTestHookCanceledDial := testHookCanceledDial
+ defer func() {
+ connectFunc = oldConnectFunc
+ getsockoptIntFunc = oldGetsockoptIntFunc
+ testHookCanceledDial = oldTestHookCanceledDial
+ }()
+
+ ln := newLocalListener(t, "tcp")
+ listenerDone := make(chan struct{})
+ go func() {
+ defer close(listenerDone)
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
+ }()
+ defer func() { <-listenerDone }()
+ defer ln.Close()
+
+ sawCancel := make(chan bool, 1)
+ testHookCanceledDial = func() {
+ sawCancel <- true
+ }
+
+ ctx, cancelCtx := context.WithCancel(context.Background())
+
+ connectFunc = func(fd int, addr syscall.Sockaddr) error {
+ err := oldConnectFunc(fd, addr)
+ t.Logf("connect(%d, addr) = %v", fd, err)
+ if err == nil {
+ // On some operating systems, connections to
+ // localhost _sometimes_ succeed immediately.
+ // Prevent that, so we exercise the code path
+ // we're interested in testing. This seems
+ // harmless. It makes FreeBSD 10.10 work when
+ // run with many iterations. It failed about
+ // half the time previously.
+ return syscall.EINPROGRESS
+ }
+ return err
+ }
+
+ getsockoptIntFunc = func(fd, level, opt int) (val int, err error) {
+ val, err = oldGetsockoptIntFunc(fd, level, opt)
+ t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err)
+ if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 {
+ t.Logf("canceling context")
+
+ // Cancel the context at just the moment which
+ // caused the race in issue 16523.
+ cancelCtx()
+
+ // And wait for the "interrupter" goroutine to
+ // cancel the dial by messing with its write
+ // timeout before returning.
+ select {
+ case <-sawCancel:
+ t.Logf("saw cancel")
+ case <-time.After(5 * time.Second):
+ t.Errorf("didn't see cancel after 5 seconds")
+ }
+ }
+ return
+ }
+
+ var d Dialer
+ c, err := d.DialContext(ctx, "tcp", ln.Addr().String())
+ if err == nil {
+ c.Close()
+ t.Fatal("unexpected successful dial; want context canceled error")
+ }
+
+ select {
+ case <-ctx.Done():
+ case <-time.After(5 * time.Second):
+ t.Fatal("expected context to be canceled")
+ }
+
+ oe, ok := err.(*OpError)
+ if !ok || oe.Op != "dial" {
+ t.Fatalf("Dial error = %#v; want dial *OpError", err)
+ }
+
+ if oe.Err != errCanceled {
+ t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, errCanceled)
+ }
+}
diff --git a/src/net/dnsclient.go b/src/net/dnsclient.go
new file mode 100644
index 0000000..b609dbd
--- /dev/null
+++ b/src/net/dnsclient.go
@@ -0,0 +1,228 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/bytealg"
+ "internal/itoa"
+ "sort"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// provided by runtime
+func fastrandu() uint
+
+func randInt() int {
+ return int(fastrandu() >> 1) // clear sign bit
+}
+
+func randIntn(n int) int {
+ return randInt() % n
+}
+
+// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
+// address addr suitable for rDNS (PTR) record lookup or an error if it fails
+// to parse the IP address.
+func reverseaddr(addr string) (arpa string, err error) {
+ ip := ParseIP(addr)
+ if ip == nil {
+ return "", &DNSError{Err: "unrecognized address", Name: addr}
+ }
+ if ip.To4() != nil {
+ return itoa.Uitoa(uint(ip[15])) + "." + itoa.Uitoa(uint(ip[14])) + "." + itoa.Uitoa(uint(ip[13])) + "." + itoa.Uitoa(uint(ip[12])) + ".in-addr.arpa.", nil
+ }
+ // Must be IPv6
+ buf := make([]byte, 0, len(ip)*4+len("ip6.arpa."))
+ // Add it, in reverse, to the buffer
+ for i := len(ip) - 1; i >= 0; i-- {
+ v := ip[i]
+ buf = append(buf, hexDigit[v&0xF],
+ '.',
+ hexDigit[v>>4],
+ '.')
+ }
+ // Append "ip6.arpa." and return (buf already has the final .)
+ buf = append(buf, "ip6.arpa."...)
+ return string(buf), nil
+}
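
reverseaddr is unexported, but the IPv4 transformation is easy to see in isolation. The helper below rebuilds the same in-addr.arpa name for an IPv4 literal; it is an illustration, not the package's own code path, and assumes a valid IPv4 literal:

package main

import (
	"fmt"
	"net"
)

// v4arpa mirrors what reverseaddr produces for an IPv4 address:
// the four octets in reverse order under in-addr.arpa.
func v4arpa(s string) string {
	ip := net.ParseIP(s).To4() // assumes s is a valid IPv4 literal
	return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", ip[3], ip[2], ip[1], ip[0])
}

func main() {
	fmt.Println(v4arpa("192.0.2.1")) // 1.2.0.192.in-addr.arpa.
}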
+
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
+ return false
+ }
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
+ if 'A' <= a && a <= 'Z' {
+ a += 0x20
+ }
+ if 'A' <= b && b <= 'Z' {
+ b += 0x20
+ }
+ if a != b {
+ return false
+ }
+ }
+ return true
+}
+
+// isDomainName checks if a string is a presentation-format domain name
+// (currently restricted to hostname-compatible "preferred name" LDH labels and
+// SRV-like "underscore labels"; see golang.org/issue/12421).
+func isDomainName(s string) bool {
+ // The root domain name is valid. See golang.org/issue/45715.
+ if s == "." {
+ return true
+ }
+
+ // See RFC 1035, RFC 3696.
+ // Presentation format has dots before every label except the first, and the
+ // terminal empty label is optional here because we assume fully-qualified
+ // (absolute) input. We must therefore reserve space for the first and last
+ // labels' length octets in wire format, where they are necessary and the
+ // maximum total length is 255.
+ // So our _effective_ maximum is 253, but 254 is not rejected if the last
+ // character is a dot.
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
+ return false
+ }
+
+ last := byte('.')
+ nonNumeric := false // true once we've seen a letter or hyphen
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
+ nonNumeric = true
+ partlen++
+ case '0' <= c && c <= '9':
+ // fine
+ partlen++
+ case c == '-':
+ // Byte before dash cannot be dot.
+ if last == '.' {
+ return false
+ }
+ partlen++
+ nonNumeric = true
+ case c == '.':
+ // Byte before dot cannot be dot, dash.
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+
+ return nonNumeric
+}
+
+// absDomainName returns an absolute domain name which ends with a
+// trailing dot to match pure Go reverse resolver and all other lookup
+// routines.
+// See golang.org/issue/12189.
+// But we don't want to add dots for local names from /etc/hosts.
+// It's hard to tell so we settle on the heuristic that names without dots
+// (like "localhost" or "myhost") do not get trailing dots, but any other
+// names do.
+func absDomainName(s string) string {
+ if bytealg.IndexByteString(s, '.') != -1 && s[len(s)-1] != '.' {
+ s += "."
+ }
+ return s
+}
+
+// An SRV represents a single DNS SRV record.
+type SRV struct {
+ Target string
+ Port uint16
+ Priority uint16
+ Weight uint16
+}
+
+// byPriorityWeight sorts SRV records by ascending priority and weight.
+type byPriorityWeight []*SRV
+
+func (s byPriorityWeight) Len() int { return len(s) }
+func (s byPriorityWeight) Less(i, j int) bool {
+ return s[i].Priority < s[j].Priority || (s[i].Priority == s[j].Priority && s[i].Weight < s[j].Weight)
+}
+func (s byPriorityWeight) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// shuffleByWeight shuffles SRV records by weight using the algorithm
+// described in RFC 2782.
+func (addrs byPriorityWeight) shuffleByWeight() {
+ sum := 0
+ for _, addr := range addrs {
+ sum += int(addr.Weight)
+ }
+ for sum > 0 && len(addrs) > 1 {
+ s := 0
+ n := randIntn(sum)
+ for i := range addrs {
+ s += int(addrs[i].Weight)
+ if s > n {
+ if i > 0 {
+ addrs[0], addrs[i] = addrs[i], addrs[0]
+ }
+ break
+ }
+ }
+ sum -= int(addrs[0].Weight)
+ addrs = addrs[1:]
+ }
+}
+
+// sort reorders SRV records as specified in RFC 2782.
+func (addrs byPriorityWeight) sort() {
+ sort.Sort(addrs)
+ i := 0
+ for j := 1; j < len(addrs); j++ {
+ if addrs[i].Priority != addrs[j].Priority {
+ addrs[i:j].shuffleByWeight()
+ i = j
+ }
+ }
+ addrs[i:].shuffleByWeight()
+}
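
Callers never invoke this sorting directly: LookupSRV hands back records already ordered by ascending priority and weight-shuffled within each priority, per RFC 2782. A sketch using a placeholder service and domain:

package main

import (
	"context"
	"fmt"
	"log"
	"net"
)

func main() {
	// "_xmpp-server._tcp.example.com" is a placeholder SRV name.
	_, srvs, err := net.DefaultResolver.LookupSRV(context.Background(), "xmpp-server", "tcp", "example.com")
	if err != nil {
		log.Fatal(err)
	}
	// Records arrive sorted by ascending priority; records of equal priority
	// are shuffled with probability proportional to their weight.
	for _, s := range srvs {
		fmt.Printf("%s:%d (priority %d, weight %d)\n", s.Target, s.Port, s.Priority, s.Weight)
	}
}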
+
+// An MX represents a single DNS MX record.
+type MX struct {
+ Host string
+ Pref uint16
+}
+
+// byPref implements sort.Interface to sort MX records by preference
+type byPref []*MX
+
+func (s byPref) Len() int { return len(s) }
+func (s byPref) Less(i, j int) bool { return s[i].Pref < s[j].Pref }
+func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// sort reorders MX records as specified in RFC 5321.
+func (s byPref) sort() {
+ for i := range s {
+ j := randIntn(i + 1)
+ s[i], s[j] = s[j], s[i]
+ }
+ sort.Sort(s)
+}
+
+// An NS represents a single DNS NS record.
+type NS struct {
+ Host string
+}
diff --git a/src/net/dnsclient_test.go b/src/net/dnsclient_test.go
new file mode 100644
index 0000000..24cd69e
--- /dev/null
+++ b/src/net/dnsclient_test.go
@@ -0,0 +1,66 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "testing"
+)
+
+func checkDistribution(t *testing.T, data []*SRV, margin float64) {
+ sum := 0
+ for _, srv := range data {
+ sum += int(srv.Weight)
+ }
+
+ results := make(map[string]int)
+
+ count := 10000
+ for j := 0; j < count; j++ {
+ d := make([]*SRV, len(data))
+ copy(d, data)
+ byPriorityWeight(d).shuffleByWeight()
+ key := d[0].Target
+ results[key] = results[key] + 1
+ }
+
+ actual := results[data[0].Target]
+ expected := float64(count) * float64(data[0].Weight) / float64(sum)
+ diff := float64(actual) - expected
+ t.Logf("actual: %v diff: %v e: %v m: %v", actual, diff, expected, margin)
+ if diff < 0 {
+ diff = -diff
+ }
+ if diff > (expected * margin) {
+ t.Errorf("missed target weight: expected %v, %v", expected, actual)
+ }
+}
+
+func testUniformity(t *testing.T, size int, margin float64) {
+ data := make([]*SRV, size)
+ for i := 0; i < size; i++ {
+ data[i] = &SRV{Target: string('a' + rune(i)), Weight: 1}
+ }
+ checkDistribution(t, data, margin)
+}
+
+func TestDNSSRVUniformity(t *testing.T) {
+ testUniformity(t, 2, 0.05)
+ testUniformity(t, 3, 0.10)
+ testUniformity(t, 10, 0.20)
+ testWeighting(t, 0.05)
+}
+
+func testWeighting(t *testing.T, margin float64) {
+ data := []*SRV{
+ {Target: "a", Weight: 60},
+ {Target: "b", Weight: 30},
+ {Target: "c", Weight: 10},
+ }
+ checkDistribution(t, data, margin)
+}
+
+func TestWeighting(t *testing.T) {
+ testWeighting(t, 0.05)
+}
diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go
new file mode 100644
index 0000000..dab5144
--- /dev/null
+++ b/src/net/dnsclient_unix.go
@@ -0,0 +1,879 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js
+
+// DNS client: see RFC 1035.
+// Has to be linked into package net for Dial.
+
+// TODO(rsc):
+// Could potentially handle many outstanding lookups faster.
+// Random UDP source port (net.Dial should do that for us).
+// Random request IDs.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "internal/itoa"
+ "io"
+ "os"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+const (
+ // to be used as a useTCP parameter to exchange
+ useTCPOnly = true
+ useUDPOrTCP = false
+
+ // Maximum DNS packet size.
+ // Value taken from https://dnsflagday.net/2020/.
+ maxDNSPacketSize = 1232
+)
+
+var (
+ errLameReferral = errors.New("lame referral")
+ errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
+ errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
+ errServerMisbehaving = errors.New("server misbehaving")
+ errInvalidDNSResponse = errors.New("invalid DNS response")
+ errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
+
+ // errServerTemporarilyMisbehaving is like errServerMisbehaving, except
+ // that when it gets translated to a DNSError, the IsTemporary field
+ // gets set to true.
+ errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+)
+
+func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = uint16(randInt())
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true, AuthenticData: ad})
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+
+ // Accept packets up to maxDNSPacketSize. RFC 6891.
+ if err := b.StartAdditionals(); err != nil {
+ return 0, nil, nil, err
+ }
+ var rh dnsmessage.ResourceHeader
+ if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.OPTResource(rh, dnsmessage.OPTResource{}); err != nil {
+ return 0, nil, nil, err
+ }
+
+ tcpReq, err = b.Finish()
+ if err != nil {
+ return 0, nil, nil, err
+ }
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, nil
+}
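
The final lines of newRequest implement the RFC 1035 section 4.2.2 framing for DNS over TCP: the message is preceded by its length as a two-byte big-endian prefix, while the UDP request is the same bytes without that prefix. A standalone sketch of the framing, for illustration only:

package main

import (
	"encoding/binary"
	"fmt"
)

// frameTCP prepends the 2-byte big-endian length that DNS-over-TCP requires.
func frameTCP(msg []byte) []byte {
	out := make([]byte, 2+len(msg))
	binary.BigEndian.PutUint16(out[:2], uint16(len(msg)))
	copy(out[2:], msg)
	return out
}

func main() {
	msg := []byte{0x12, 0x34, 0x01, 0x00} // start of a fake DNS message
	fmt.Printf("% x\n", frameTCP(msg))    // 00 04 12 34 01 00
}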
+
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
+}
+
+func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+
+ b = make([]byte, maxDNSPacketSize)
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
+ }
+}
+
+func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+
+ b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ return p, h, nil
+}
+
+// exchange sends a query on the connection and hopes for a response.
+func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP, ad bool) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q, ad)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
+ }
+ var networks []string
+ if useTCP {
+ networks = []string{"tcp"}
+ } else {
+ networks = []string{"udp", "tcp"}
+ }
+ for _, network := range networks {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ c, err := r.dial(ctx, network, server)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ c.SetDeadline(d)
+ }
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if _, ok := c.(PacketConn); ok {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ if h.Truncated { // see RFC 5966
+ continue
+ }
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
+}
+
+// checkHeader performs basic sanity checks on the header.
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
+ if h.RCode == dnsmessage.RCodeNameError {
+ return errNoSuchHost
+ }
+
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return errCannotUnmarshalDNSMessage
+ }
+
+ // libresolv continues to the next server when it receives
+ // an invalid referral response. See golang.org/issue/15434.
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return errLameReferral
+ }
+
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ // None of the error codes make sense
+ // for the query we sent. If we didn't get
+ // a name error and we didn't get success,
+ // the server is behaving incorrectly or
+ // having temporary trouble.
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ return errServerTemporarilyMisbehaving
+ }
+ return errServerMisbehaving
+ }
+
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return errNoSuchHost
+ }
+ if err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ }
+}
+
+// Do a lookup for a single name, which must be rooted
+// (otherwise the final answer will not be found).
+func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+ var lastErr error
+ serverOffset := cfg.serverOffset()
+ sLen := uint32(len(cfg.servers))
+
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
+ for i := 0; i < cfg.attempts; i++ {
+ for j := uint32(0); j < sLen; j++ {
+ server := cfg.servers[(serverOffset+j)%sLen]
+
+ p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
+ if err != nil {
+ dnsErr := &DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server,
+ }
+ if nerr, ok := err.(Error); ok && nerr.Timeout() {
+ dnsErr.IsTimeout = true
+ }
+ // Set IsTemporary for socket-level errors. Note that this flag
+ // may also be used to indicate a SERVFAIL response.
+ if _, ok := err.(*OpError); ok {
+ dnsErr.IsTemporary = true
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ if err := checkHeader(&p, h); err != nil {
+ dnsErr := &DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server,
+ }
+ if err == errServerTemporarilyMisbehaving {
+ dnsErr.IsTemporary = true
+ }
+ if err == errNoSuchHost {
+ // The name does not exist, so trying
+ // another server won't help.
+
+ dnsErr.IsNotFound = true
+ return p, server, dnsErr
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ err = skipToAnswer(&p, qtype)
+ if err == nil {
+ return p, server, nil
+ }
+ lastErr = &DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server,
+ }
+ if err == errNoSuchHost {
+ // The name does not exist, so trying another
+ // server won't help.
+
+ lastErr.(*DNSError).IsNotFound = true
+ return p, server, lastErr
+ }
+ }
+ }
+ return dnsmessage.Parser{}, "", lastErr
+}
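
Every attempt in tryOneName walks the full server list starting at serverOffset, so a single lookup always tries each configured server while rotation only changes the starting point between lookups. A small sketch of that iteration order, using a hypothetical helper:

    // retryOrder reproduces the loop above: for servers [a b c], offset 1 and
    // attempts 2 it yields b, c, a, b, c, a.
    func retryOrder(servers []string, offset uint32, attempts int) []string {
        n := uint32(len(servers))
        var order []string
        for i := 0; i < attempts; i++ {
            for j := uint32(0); j < n; j++ {
                order = append(order, servers[(offset+j)%n])
            }
        }
        return order
    }
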
+
+// A resolverConfig represents a DNS stub resolver configuration.
+type resolverConfig struct {
+ initOnce sync.Once // guards init of resolverConfig
+
+ // ch is used as a semaphore that only allows one lookup at a
+ // time to recheck resolv.conf.
+ ch chan struct{} // guards lastChecked and modTime
+ lastChecked time.Time // last time resolv.conf was checked
+
+ dnsConfig atomic.Pointer[dnsConfig] // parsed resolv.conf structure used in lookups
+}
+
+var resolvConf resolverConfig
+
+func getSystemDNSConfig() *dnsConfig {
+ resolvConf.tryUpdate("/etc/resolv.conf")
+ return resolvConf.dnsConfig.Load()
+}
+
+// init initializes conf and is only called via conf.initOnce.
+func (conf *resolverConfig) init() {
+ // Set dnsConfig and lastChecked so we don't parse
+ // resolv.conf twice the first time.
+ conf.dnsConfig.Store(dnsReadConfig("/etc/resolv.conf"))
+ conf.lastChecked = time.Now()
+
+ // Prepare ch so that only one update of resolverConfig may
+ // run at once.
+ conf.ch = make(chan struct{}, 1)
+}
+
+// tryUpdate tries to update conf with the named resolv.conf file.
+// The name variable only exists for testing. It is otherwise always
+// "/etc/resolv.conf".
+func (conf *resolverConfig) tryUpdate(name string) {
+ conf.initOnce.Do(conf.init)
+
+ if conf.dnsConfig.Load().noReload {
+ return
+ }
+
+ // Ensure only one update at a time checks resolv.conf.
+ if !conf.tryAcquireSema() {
+ return
+ }
+ defer conf.releaseSema()
+
+ now := time.Now()
+ if conf.lastChecked.After(now.Add(-5 * time.Second)) {
+ return
+ }
+ conf.lastChecked = now
+
+ switch runtime.GOOS {
+ case "windows":
+ // There's no file on disk, so don't bother checking
+ // and failing.
+ //
+ // The Windows implementation of dnsReadConfig (called
+ // below) ignores the name.
+ default:
+ var mtime time.Time
+ if fi, err := os.Stat(name); err == nil {
+ mtime = fi.ModTime()
+ }
+ if mtime.Equal(conf.dnsConfig.Load().mtime) {
+ return
+ }
+ }
+
+ dnsConf := dnsReadConfig(name)
+ conf.dnsConfig.Store(dnsConf)
+}
+
+func (conf *resolverConfig) tryAcquireSema() bool {
+ select {
+ case conf.ch <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (conf *resolverConfig) releaseSema() {
+ <-conf.ch
+}
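
tryAcquireSema and releaseSema implement a non-blocking, single-slot semaphore: a buffered channel of capacity one, acquired with a select/default so callers never wait, and released by draining the slot. The same pattern in isolation, as a generic sketch not tied to resolverConfig:

    type sema chan struct{}

    func newSema() sema { return make(sema, 1) }

    // tryAcquire reports whether the single slot was free; it never blocks.
    func (s sema) tryAcquire() bool {
        select {
        case s <- struct{}{}:
            return true
        default:
            return false
        }
    }

    // release frees the slot for the next caller.
    func (s sema) release() { <-s }
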
+
+func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
+ if !isDomainName(name) {
+ // We used to use "invalid domain name" as the error,
+ // but that is a detail of the specific lookup mechanism.
+ // Other lookups might allow broader name syntax
+ // (for example Multicast DNS allows UTF-8; see RFC 6762).
+ // For consistency with libc resolvers, report no such host.
+ return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+ }
+
+ if conf == nil {
+ conf = getSystemDNSConfig()
+ }
+
+ var (
+ p dnsmessage.Parser
+ server string
+ err error
+ )
+ for _, fqdn := range conf.nameList(name) {
+ p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
+ if err == nil {
+ break
+ }
+ if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
+ // If we hit a temporary error with StrictErrors enabled,
+ // stop immediately instead of trying more names.
+ break
+ }
+ }
+ if err == nil {
+ return p, server, nil
+ }
+ if err, ok := err.(*DNSError); ok {
+ // Show original name passed to lookup, not suffixed one.
+ // In general we might have tried many suffixes; showing
+ // just one is misleading. See also golang.org/issue/6324.
+ err.Name = name
+ }
+ return dnsmessage.Parser{}, "", err
+}
+
+// avoidDNS reports whether this is a hostname for which we should not
+// use DNS. Currently this includes only .onion, per RFC 7686. See
+// golang.org/issue/13705. Does not cover .local names (RFC 6762),
+// see golang.org/issue/16739.
+func avoidDNS(name string) bool {
+ if name == "" {
+ return true
+ }
+ if name[len(name)-1] == '.' {
+ name = name[:len(name)-1]
+ }
+ return stringsHasSuffixFold(name, ".onion")
+}
+
+// nameList returns a list of names for sequential DNS queries.
+func (conf *dnsConfig) nameList(name string) []string {
+ if avoidDNS(name) {
+ return nil
+ }
+
+ // Check name length (see isDomainName).
+ l := len(name)
+ rooted := l > 0 && name[l-1] == '.'
+ if l > 254 || l == 254 && !rooted {
+ return nil
+ }
+
+ // If name is rooted (trailing dot), try only that name.
+ if rooted {
+ return []string{name}
+ }
+
+ hasNdots := count(name, '.') >= conf.ndots
+ name += "."
+ l++
+
+ // Build list of search choices.
+ names := make([]string, 0, 1+len(conf.search))
+ // If name has enough dots, try unsuffixed first.
+ if hasNdots {
+ names = append(names, name)
+ }
+ // Try suffixes that are not too long (see isDomainName).
+ for _, suffix := range conf.search {
+ if l+len(suffix) <= 254 {
+ names = append(names, name+suffix)
+ }
+ }
+ // Try unsuffixed, if not tried first above.
+ if !hasNdots {
+ names = append(names, name)
+ }
+ return names
+}
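
The interplay of ndots and the search list is easiest to see with a concrete case. Assuming a config equivalent to ndots:1 with search domains example.com and sub.example.com (suffixes are stored rooted), a hypothetical in-package call would behave as follows; the values shown are derived from the logic above, not from running the patch:

    conf := &dnsConfig{ndots: 1, search: []string{"example.com.", "sub.example.com."}}

    conf.nameList("www")
    // "www" has fewer dots than ndots, so suffixed names come first:
    //   [www.example.com. www.sub.example.com. www.]

    conf.nameList("www.example.org")
    // "www.example.org" meets ndots, so the bare name is tried first:
    //   [www.example.org. www.example.org.example.com. www.example.org.sub.example.com.]
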
+
+// hostLookupOrder specifies the order of LookupHost lookup strategies.
+// It is basically a simplified representation of nsswitch.conf.
+// "files" means /etc/hosts.
+type hostLookupOrder int
+
+const (
+ // hostLookupCgo means defer to cgo.
+ hostLookupCgo hostLookupOrder = iota
+ hostLookupFilesDNS // files first
+ hostLookupDNSFiles // dns first
+ hostLookupFiles // only files
+ hostLookupDNS // only DNS
+)
+
+var lookupOrderName = map[hostLookupOrder]string{
+ hostLookupCgo: "cgo",
+ hostLookupFilesDNS: "files,dns",
+ hostLookupDNSFiles: "dns,files",
+ hostLookupFiles: "files",
+ hostLookupDNS: "dns",
+}
+
+func (o hostLookupOrder) String() string {
+ if s, ok := lookupOrderName[o]; ok {
+ return s
+ }
+ return "hostLookupOrder=" + itoa.Itoa(int(o)) + "??"
+}
+
+func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder, conf *dnsConfig) (addrs []string, err error) {
+ if order == hostLookupFilesDNS || order == hostLookupFiles {
+ // Use entries from /etc/hosts if they match.
+ addrs, _ = lookupStaticHost(name)
+ if len(addrs) > 0 {
+ return
+ }
+
+ if order == hostLookupFiles {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+ }
+ }
+ ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf)
+ if err != nil {
+ return
+ }
+ addrs = make([]string, 0, len(ips))
+ for _, ip := range ips {
+ addrs = append(addrs, ip.String())
+ }
+ return
+}
+
+// goLookupIPFiles looks up entries for name in /etc/hosts.
+func goLookupIPFiles(name string) (addrs []IPAddr, canonical string) {
+ addr, canonical := lookupStaticHost(name)
+ for _, haddr := range addr {
+ haddr, zone := splitHostZone(haddr)
+ if ip := ParseIP(haddr); ip != nil {
+ addr := IPAddr{IP: ip, Zone: zone}
+ addrs = append(addrs, addr)
+ }
+ }
+ sortByRFC6724(addrs)
+ return addrs, canonical
+}
+
+// goLookupIP is the native Go implementation of LookupIP.
+// The libc versions are in cgo_*.go.
+func (r *Resolver) goLookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
+ order, conf := systemConf().hostLookupOrder(r, host)
+ addrs, _, err = r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
+ return
+}
+
+func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name string, order hostLookupOrder, conf *dnsConfig) (addrs []IPAddr, cname dnsmessage.Name, err error) {
+ if order == hostLookupFilesDNS || order == hostLookupFiles {
+ var canonical string
+ addrs, canonical = goLookupIPFiles(name)
+
+ if len(addrs) > 0 {
+ var err error
+ cname, err = dnsmessage.NewName(canonical)
+ if err != nil {
+ return nil, dnsmessage.Name{}, err
+ }
+ return addrs, cname, nil
+ }
+
+ if order == hostLookupFiles {
+ return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+ }
+ }
+
+ if !isDomainName(name) {
+ // See comment in func lookup above about use of errNoSuchHost.
+ return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+ }
+ type result struct {
+ p dnsmessage.Parser
+ server string
+ error
+ }
+
+ if conf == nil {
+ conf = getSystemDNSConfig()
+ }
+
+ lane := make(chan result, 1)
+ qtypes := []dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
+ if network == "CNAME" {
+ qtypes = append(qtypes, dnsmessage.TypeCNAME)
+ }
+ switch ipVersion(network) {
+ case '4':
+ qtypes = []dnsmessage.Type{dnsmessage.TypeA}
+ case '6':
+ qtypes = []dnsmessage.Type{dnsmessage.TypeAAAA}
+ }
+ var queryFn func(fqdn string, qtype dnsmessage.Type)
+ var responseFn func(fqdn string, qtype dnsmessage.Type) result
+ if conf.singleRequest {
+ queryFn = func(fqdn string, qtype dnsmessage.Type) {}
+ responseFn = func(fqdn string, qtype dnsmessage.Type) result {
+ dnsWaitGroup.Add(1)
+ defer dnsWaitGroup.Done()
+ p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
+ return result{p, server, err}
+ }
+ } else {
+ queryFn = func(fqdn string, qtype dnsmessage.Type) {
+ dnsWaitGroup.Add(1)
+ go func(qtype dnsmessage.Type) {
+ p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
+ lane <- result{p, server, err}
+ dnsWaitGroup.Done()
+ }(qtype)
+ }
+ responseFn = func(fqdn string, qtype dnsmessage.Type) result {
+ return <-lane
+ }
+ }
+ var lastErr error
+ for _, fqdn := range conf.nameList(name) {
+ for _, qtype := range qtypes {
+ queryFn(fqdn, qtype)
+ }
+ hitStrictError := false
+ for _, qtype := range qtypes {
+ result := responseFn(fqdn, qtype)
+ if result.error != nil {
+ if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
+ // This error will abort the nameList loop.
+ hitStrictError = true
+ lastErr = result.error
+ } else if lastErr == nil || fqdn == name+"." {
+ // Prefer error for original name.
+ lastErr = result.error
+ }
+ continue
+ }
+
+ // Presotto says it's okay to assume that servers listed in
+ // /etc/resolv.conf are recursive resolvers.
+ //
+ // We asked for recursion, so it should have included all the
+ // answers we need in this one packet.
+ //
+ // Further, RFC 1034 section 4.3.1 says that "the recursive
+ // response to a query will be... The answer to the query,
+ // possibly preface by one or more CNAME RRs that specify
+ // aliases encountered on the way to an answer."
+ //
+ // Therefore, we should be able to assume that we can ignore
+ // CNAMEs and that the A and AAAA records we requested are
+ // for the canonical name.
+
+ loop:
+ for {
+ h, err := result.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: result.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := result.p.AResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := result.p.AAAAResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
+
+ case dnsmessage.TypeCNAME:
+ c, err := result.p.CNAMEResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: result.server,
+ }
+ break loop
+ }
+ if cname.Length == 0 && c.CNAME.Length > 0 {
+ cname = c.CNAME
+ }
+
+ default:
+ if err := result.p.SkipAnswer(); err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: result.server,
+ }
+ break loop
+ }
+ continue
+ }
+ }
+ }
+ if hitStrictError {
+ // If either family hit an error with StrictErrors enabled,
+ // discard all addresses. This ensures that network flakiness
+ // cannot turn a dualstack hostname into an IPv4-only or IPv6-only one.
+ addrs = nil
+ break
+ }
+ if len(addrs) > 0 || network == "CNAME" && cname.Length > 0 {
+ break
+ }
+ }
+ if lastErr, ok := lastErr.(*DNSError); ok {
+ // Show original name passed to lookup, not suffixed one.
+ // In general we might have tried many suffixes; showing
+ // just one is misleading. See also golang.org/issue/6324.
+ lastErr.Name = name
+ }
+ sortByRFC6724(addrs)
+ if len(addrs) == 0 && !(network == "CNAME" && cname.Length > 0) {
+ if order == hostLookupDNSFiles {
+ var canonical string
+ addrs, canonical = goLookupIPFiles(name)
+ if len(addrs) > 0 {
+ var err error
+ cname, err = dnsmessage.NewName(canonical)
+ if err != nil {
+ return nil, dnsmessage.Name{}, err
+ }
+ return addrs, cname, nil
+ }
+ }
+ if lastErr != nil {
+ return nil, dnsmessage.Name{}, lastErr
+ }
+ }
+ return addrs, cname, nil
+}
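
The queryFn/responseFn pair above is what switches between firing the A and AAAA queries concurrently (the default) and strictly one after the other (options single-request). Reduced to a standalone sketch with hypothetical names, the pattern looks like this:

    // fanOut runs fn for every item either sequentially or concurrently,
    // mirroring the single-request vs. parallel split in goLookupIPCNAMEOrder.
    func fanOut(items []string, sequential bool, fn func(string) error) []error {
        if sequential {
            errs := make([]error, 0, len(items))
            for _, it := range items {
                errs = append(errs, fn(it)) // next query starts only after this one returns
            }
            return errs
        }
        lane := make(chan error, len(items))
        for _, it := range items {
            go func(it string) { lane <- fn(it) }(it)
        }
        errs := make([]error, 0, len(items))
        for range items {
            errs = append(errs, <-lane)
        }
        return errs
    }
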
+
+// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
+func (r *Resolver) goLookupCNAME(ctx context.Context, host string, order hostLookupOrder, conf *dnsConfig) (string, error) {
+ _, cname, err := r.goLookupIPCNAMEOrder(ctx, "CNAME", host, order, conf)
+ return cname.String(), err
+}
+
+// goLookupPTR is the native Go implementation of LookupAddr.
+func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLookupOrder, conf *dnsConfig) ([]string, error) {
+ if order == hostLookupFiles || order == hostLookupFilesDNS {
+ names := lookupStaticAddr(addr)
+ if len(names) > 0 {
+ return names, nil
+ }
+
+ if order == hostLookupFiles {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: addr, IsNotFound: true}
+ }
+ }
+
+ arpa, err := reverseaddr(addr)
+ if err != nil {
+ return nil, err
+ }
+ p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR, conf)
+ if err != nil {
+ var dnsErr *DNSError
+ if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
+ if order == hostLookupDNSFiles {
+ names := lookupStaticAddr(addr)
+ if len(names) > 0 {
+ return names, nil
+ }
+ }
+ }
+ return nil, err
+ }
+ var ptrs []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypePTR {
+ err := p.SkipAnswer()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ continue
+ }
+ ptr, err := p.PTRResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ ptrs = append(ptrs, ptr.PTR.String())
+ }
+
+ return ptrs, nil
+}
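
goLookupPTR backs the exported LookupAddr when the Go resolver is in use, so from outside the package the whole PTR path, including reverseaddr and the hosts-file fallback, is reached with a single call. A usage sketch; the address comes from the RFC 5737 documentation range and may well have no PTR records:

    package main

    import (
        "context"
        "fmt"
        "net"
    )

    func main() {
        names, err := net.DefaultResolver.LookupAddr(context.Background(), "192.0.2.1")
        fmt.Println(names, err) // reverse names for the address, if any
    }
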
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go
new file mode 100644
index 0000000..8d435a5
--- /dev/null
+++ b/src/net/dnsclient_unix_test.go
@@ -0,0 +1,2600 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "os"
+ "path"
+ "path/filepath"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+var goResolver = Resolver{PreferGo: true}
+
+// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
+var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
+
+// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
+var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+
+func mustNewName(name string) dnsmessage.Name {
+ nn, err := dnsmessage.NewName(name)
+ if err != nil {
+ panic(fmt.Sprint("creating name: ", err))
+ }
+ return nn
+}
+
+func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
+ return dnsmessage.Question{
+ Name: mustNewName(name),
+ Type: qtype,
+ Class: class,
+ }
+}
+
+var dnsTransportFallbackTests = []struct {
+ server string
+ question dnsmessage.Question
+ timeout int
+ rcode dnsmessage.RCode
+}{
+ // Querying "com." with qtype=255 usually produces an answer
+ // too large to fit in a 512-byte UDP response.
+ {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
+ {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
+}
+
+func TestDNSTransportFallback(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if n == "udp" {
+ r.Header.Truncated = true
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ for _, tt := range dnsTransportFallbackTests {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP, false)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
+ continue
+ }
+ }
+}
+
+// See RFC 6761 for further information about the reserved, pseudo
+// domain names.
+var specialDomainNameTests = []struct {
+ question dnsmessage.Question
+ rcode dnsmessage.RCode
+}{
+ // Name resolution APIs and libraries should not recognize
+ // the following names as special.
+ {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
+
+ // Name resolution APIs and libraries should recognize the
+ // following names as special and should not send any queries.
+ // We still test them here to verify negative answers at the
+ // DNS query-response interaction level.
+ {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+}
+
+func TestSpecialDomainName(t *testing.T) {
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ switch q.Questions[0].Name.String() {
+ case "example.com.":
+ r.Header.RCode = dnsmessage.RCodeSuccess
+ default:
+ r.Header.RCode = dnsmessage.RCodeNameError
+ }
+
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ server := "8.8.8.8:53"
+ for _, tt := range specialDomainNameTests {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP, false)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
+ continue
+ }
+ }
+}
+
+// Issue 13705: don't try to resolve onion addresses, etc
+func TestAvoidDNSName(t *testing.T) {
+ tests := []struct {
+ name string
+ avoid bool
+ }{
+ {"foo.com", false},
+ {"foo.com.", false},
+
+ {"foo.onion.", true},
+ {"foo.onion", true},
+ {"foo.ONION", true},
+ {"foo.ONION.", true},
+
+ // But do resolve *.local addresses; Issue 16739
+ {"foo.local.", false},
+ {"foo.local", false},
+ {"foo.LOCAL", false},
+ {"foo.LOCAL.", false},
+
+ {"", true}, // will be rejected earlier too
+
+ // Without stuff before onion/local, they're fine to
+ // use DNS. With a search path,
+ // "onion.vegetables.com" can use DNS. Without a
+ // search path (or with a trailing dot), the queries
+ // are just kinda useless, but don't reveal anything
+ // private.
+ {"local", false},
+ {"onion", false},
+ {"local.", false},
+ {"onion.", false},
+ }
+ for _, tt := range tests {
+ got := avoidDNS(tt.name)
+ if got != tt.avoid {
+ t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
+ }
+ }
+}
+
+var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+ return r, nil
+}}
+
+// Issue 13705: don't try to resolve onion addresses, etc
+func TestLookupTorOnion(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
+ addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
+ if err != nil {
+ t.Fatalf("lookup = %v; want nil", err)
+ }
+ if len(addrs) > 0 {
+ t.Errorf("unexpected addresses: %v", addrs)
+ }
+}
+
+type resolvConfTest struct {
+ dir string
+ path string
+ *resolverConfig
+}
+
+func newResolvConfTest() (*resolvConfTest, error) {
+ dir, err := os.MkdirTemp("", "go-resolvconftest")
+ if err != nil {
+ return nil, err
+ }
+ conf := &resolvConfTest{
+ dir: dir,
+ path: path.Join(dir, "resolv.conf"),
+ resolverConfig: &resolvConf,
+ }
+ conf.initOnce.Do(conf.init)
+ return conf, nil
+}
+
+func (conf *resolvConfTest) write(lines []string) error {
+ f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ if err != nil {
+ return err
+ }
+ if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
+ f.Close()
+ return err
+ }
+ f.Close()
+ return nil
+}
+
+func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
+ return conf.writeAndUpdateWithLastCheckedTime(lines, time.Now().Add(time.Hour))
+}
+
+func (conf *resolvConfTest) writeAndUpdateWithLastCheckedTime(lines []string, lastChecked time.Time) error {
+ if err := conf.write(lines); err != nil {
+ return err
+ }
+ return conf.forceUpdate(conf.path, lastChecked)
+}
+
+func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
+ dnsConf := dnsReadConfig(name)
+ if !conf.forceUpdateConf(dnsConf, lastChecked) {
+ return fmt.Errorf("tryAcquireSema for %s failed", name)
+ }
+ return nil
+}
+
+func (conf *resolvConfTest) forceUpdateConf(c *dnsConfig, lastChecked time.Time) bool {
+ conf.dnsConfig.Store(c)
+ for i := 0; i < 5; i++ {
+ if conf.tryAcquireSema() {
+ conf.lastChecked = lastChecked
+ conf.releaseSema()
+ return true
+ }
+ }
+ return false
+}
+
+func (conf *resolvConfTest) servers() []string {
+ return conf.dnsConfig.Load().servers
+}
+
+func (conf *resolvConfTest) teardown() error {
+ err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
+ os.RemoveAll(conf.dir)
+ return err
+}
+
+var updateResolvConfTests = []struct {
+ name string // query name
+ lines []string // resolver configuration lines
+ servers []string // expected name servers
+}{
+ {
+ name: "golang.org",
+ lines: []string{"nameserver 8.8.8.8"},
+ servers: []string{"8.8.8.8:53"},
+ },
+ {
+ name: "",
+ lines: nil, // an empty resolv.conf should use defaultNS as name servers
+ servers: defaultNS,
+ },
+ {
+ name: "www.example.com",
+ lines: []string{"nameserver 8.8.4.4"},
+ servers: []string{"8.8.4.4:53"},
+ },
+}
+
+func TestUpdateResolvConf(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ for i, tt := range updateResolvConfTests {
+ if err := conf.writeAndUpdate(tt.lines); err != nil {
+ t.Error(err)
+ continue
+ }
+ if tt.name != "" {
+ var wg sync.WaitGroup
+ const N = 10
+ wg.Add(N)
+ for j := 0; j < N; j++ {
+ go func(name string) {
+ defer wg.Done()
+ ips, err := r.LookupIPAddr(context.Background(), name)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if len(ips) == 0 {
+ t.Errorf("no records for %s", name)
+ return
+ }
+ }(tt.name)
+ }
+ wg.Wait()
+ }
+ servers := conf.servers()
+ if !reflect.DeepEqual(servers, tt.servers) {
+ t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
+ continue
+ }
+ }
+}
+
+var goLookupIPWithResolverConfigTests = []struct {
+ name string
+ lines []string // resolver configuration lines
+ error
+ a, aaaa bool // whether response contains A, AAAA-record
+}{
+ // no records, transport timeout
+ {
+ "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
+ []string{
+ "options timeout:1 attempts:1",
+ "nameserver 255.255.255.255", // please forgive us for abuse of limited broadcast address
+ },
+ &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
+ false, false,
+ },
+
+ // no records, non-existent domain
+ {
+ "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
+ []string{
+ "options timeout:3 attempts:1",
+ "nameserver 8.8.8.8",
+ },
+ &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
+ false, false,
+ },
+
+ // a few A records, no AAAA records
+ {
+ "ipv4.google.com.",
+ []string{
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ true, false,
+ },
+ {
+ "ipv4.google.com",
+ []string{
+ "domain golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, false,
+ },
+ {
+ "ipv4.google.com",
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, false,
+ },
+
+ // no A records, a few AAAA records
+ {
+ "ipv6.google.com.",
+ []string{
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ false, true,
+ },
+ {
+ "ipv6.google.com",
+ []string{
+ "domain golang.org",
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ false, true,
+ },
+ {
+ "ipv6.google.com",
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ false, true,
+ },
+
+ // both A and AAAA records
+ {
+ "hostname.as112.net", // see RFC 7534
+ []string{
+ "domain golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, true,
+ },
+ {
+ "hostname.as112.net", // see RFC 7534
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, true,
+ },
+}
+
+func TestGoLookupIPWithResolverConfig(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ switch s {
+ case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
+ break
+ default:
+ time.Sleep(10 * time.Millisecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ switch question.Name.String() {
+ case "hostname.as112.net.":
+ break
+ case "ipv4.google.com.":
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ })
+ default:
+
+ }
+ case dnsmessage.TypeAAAA:
+ switch question.Name.String() {
+ case "hostname.as112.net.":
+ break
+ case "ipv6.google.com.":
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ })
+ }
+ }
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ for _, tt := range goLookupIPWithResolverConfigTests {
+ if err := conf.writeAndUpdate(tt.lines); err != nil {
+ t.Error(err)
+ continue
+ }
+ addrs, err := r.LookupIPAddr(context.Background(), tt.name)
+ if err != nil {
+ if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
+ t.Errorf("got %v; want %v", err, tt.error)
+ }
+ continue
+ }
+ if len(addrs) == 0 {
+ t.Errorf("no records for %s", tt.name)
+ }
+ if !tt.a && !tt.aaaa && len(addrs) > 0 {
+ t.Errorf("unexpected %v for %s", addrs, tt.name)
+ }
+ for _, addr := range addrs {
+ if !tt.a && addr.IP.To4() != nil {
+ t.Errorf("got %v; must not be IPv4 address", addr)
+ }
+ if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
+ t.Errorf("got %v; must not be IPv6 address", addr)
+ }
+ }
+ }
+}
+
+// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
+func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // Add a config that simulates no dns servers being available.
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{}); err != nil {
+ t.Fatal(err)
+ }
+ // Redirect host file lookups.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/hosts"
+
+ for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
+ name := fmt.Sprintf("order %v", order)
+ // First ensure that we get an error when contacting a non-existent host.
+ _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "notarealhost", order, nil)
+ if err == nil {
+ t.Errorf("%s: expected error while looking up name not in hosts file", name)
+ continue
+ }
+
+ // Now check that we get an address when the name appears in the hosts file.
+ addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "ip", "thor", order, nil) // entry is in "testdata/hosts"
+ if err != nil {
+ t.Errorf("%s: expected to successfully lookup host entry", name)
+ continue
+ }
+ if len(addrs) != 1 {
+ t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
+ continue
+ }
+ if got, want := addrs[0].String(), "127.1.1.1"; got != want {
+ t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
+ }
+ }
+}
+
+// Issue 12712.
+// When using search domains, return the error encountered
+// querying the original name instead of an error encountered
+// querying a generated name.
+func TestErrorForOriginalNameWhenSearching(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ const fqdn = "doesnotexist.domain"
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
+ t.Fatal(err)
+ }
+
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ switch q.Questions[0].Name.String() {
+ case fqdn + ".servfail.":
+ r.Header.RCode = dnsmessage.RCodeServerFailure
+ default:
+ r.Header.RCode = dnsmessage.RCodeNameError
+ }
+
+ return r, nil
+ }}
+
+ cases := []struct {
+ strictErrors bool
+ wantErr *DNSError
+ }{
+ {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
+ {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error(), IsNotFound: true}},
+ }
+ for _, tt := range cases {
+ r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
+ _, err = r.LookupIPAddr(context.Background(), fqdn)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+
+ want := tt.wantErr
+ if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
+ t.Errorf("got %v; want %v", err, want)
+ }
+ }
+}
+
+// Issue 15434. If a name server gives a lame referral, continue to the next.
+func TestIgnoreLameReferrals(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral
+ "nameserver 192.0.2.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ if s == "192.0.2.2:53" {
+ r.Header.RecursionAvailable = true
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+ }
+
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got := len(addrs); got != 1 {
+ t.Fatalf("got %d addresses, want 1", got)
+ }
+
+ if got, want := addrs[0].String(), "192.0.2.1"; got != want {
+ t.Fatalf("got address %v, want %v", got, want)
+ }
+}
+
+func BenchmarkGoLookupIP(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "www.example.com")
+ }
+}
+
+func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "some.nonexistent")
+ }
+}
+
+func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conf.teardown()
+
+ lines := []string{
+ "nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
+ "nameserver 8.8.8.8",
+ }
+ if err := conf.writeAndUpdate(lines); err != nil {
+ b.Fatal(err)
+ }
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "www.example.com")
+ }
+}
+
+type fakeDNSServer struct {
+ rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+ alwaysTCP bool
+}
+
+func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
+ if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
+ return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
+ }
+ return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
+}
+
+type fakeDNSConn struct {
+ Conn
+ tcp bool
+ server *fakeDNSServer
+ n string
+ s string
+ q dnsmessage.Message
+ t time.Time
+ buf []byte
+}
+
+func (f *fakeDNSConn) Close() error {
+ return nil
+}
+
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+ if len(f.buf) > 0 {
+ n := copy(b, f.buf)
+ f.buf = f.buf[n:]
+ return n, nil
+ }
+
+ resp, err := f.server.rh(f.n, f.s, f.q, f.t)
+ if err != nil {
+ return 0, err
+ }
+
+ bb := make([]byte, 2, 514)
+ bb, err = resp.AppendPack(bb)
+ if err != nil {
+ return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
+ }
+
+ if f.tcp {
+ l := len(bb) - 2
+ bb[0] = byte(l >> 8)
+ bb[1] = byte(l)
+ f.buf = bb
+ return f.Read(b)
+ }
+
+ bb = bb[2:]
+ if len(b) < len(bb) {
+ return 0, errors.New("read would fragment DNS message")
+ }
+
+ copy(b, bb)
+ return len(bb), nil
+}
+
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+ if f.tcp && len(b) >= 2 {
+ b = b[2:]
+ }
+ if f.q.Unpack(b) != nil {
+ return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
+ }
+ return len(b), nil
+}
+
+func (f *fakeDNSConn) SetDeadline(t time.Time) error {
+ f.t = t
+ return nil
+}
+
+type fakeDNSPacketConn struct {
+ PacketConn
+ fakeDNSConn
+}
+
+func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
+ return f.fakeDNSConn.SetDeadline(t)
+}
+
+func (f *fakeDNSPacketConn) Close() error {
+ return f.fakeDNSConn.Close()
+}
+
+// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
+func TestIgnoreDNSForgeries(t *testing.T) {
+ c, s := Pipe()
+ go func() {
+ b := make([]byte, maxDNSPacketSize)
+ n, err := s.Read(b)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ var msg dnsmessage.Message
+ if err := msg.Unpack(b[:n]); err != nil {
+ t.Error("invalid DNS query:", err)
+ return
+ }
+
+ s.Write([]byte("garbage DNS response packet"))
+
+ msg.Header.Response = true
+ msg.Header.ID++ // make invalid ID
+
+ if b, err = msg.Pack(); err != nil {
+ t.Error("failed to pack DNS response:", err)
+ return
+ }
+ s.Write(b)
+
+ msg.Header.ID-- // restore original ID
+ msg.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+
+ b, err = msg.Pack()
+ if err != nil {
+ t.Error("failed to pack DNS response:", err)
+ return
+ }
+ s.Write(b)
+ }()
+
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: 42,
+ },
+ Questions: []dnsmessage.Question{
+ {
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ },
+ }
+
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+
+ p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
+ if err != nil {
+ t.Fatalf("dnsPacketRoundTrip failed: %v", err)
+ }
+
+ p.SkipAllQuestions()
+ as, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal("AllAnswers failed:", err)
+ }
+ if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
+ t.Errorf("got address %v, want %v", got, TestAddr)
+ }
+}
+
+// Issue 16865. If a name server times out, continue to the next.
+func TestRetryTimeout(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ testConf := []string{
+ "nameserver 192.0.2.1", // the one that will timeout
+ "nameserver 192.0.2.2",
+ }
+ if err := conf.writeAndUpdate(testConf); err != nil {
+ t.Fatal(err)
+ }
+
+ var deadline0 time.Time
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q, deadline)
+
+ if deadline.IsZero() {
+ t.Error("zero deadline")
+ }
+
+ if s == "192.0.2.1:53" {
+ deadline0 = deadline
+ time.Sleep(10 * time.Millisecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }
+
+ if deadline.Equal(deadline0) {
+ t.Error("deadline didn't change")
+ }
+
+ return mockTXTResponse(q), nil
+ }}
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ _, err = r.LookupTXT(context.Background(), "www.golang.org")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if deadline0.IsZero() {
+ t.Error("deadline0 still zero", deadline0)
+ }
+}
+
+func TestRotate(t *testing.T) {
+ // without rotation, always uses the first server
+ testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
+
+ // with rotation, rotates through back to first
+ testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
+}
+
+func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ var confLines []string
+ for _, ns := range nameservers {
+ confLines = append(confLines, "nameserver "+ns)
+ }
+ if rotate {
+ confLines = append(confLines, "options rotate")
+ }
+
+ if err := conf.writeAndUpdate(confLines); err != nil {
+ t.Fatal(err)
+ }
+
+ var usedServers []string
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ usedServers = append(usedServers, s)
+ return mockTXTResponse(q), nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // len(nameservers) + 1 to allow rotation to get back to start
+ for i := 0; i < len(nameservers)+1; i++ {
+ if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if !reflect.DeepEqual(usedServers, wantServers) {
+ t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
+ }
+}
+
+func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RecursionAvailable: true,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"ok"},
+ },
+ },
+ },
+ }
+
+ return r
+}
+
+// Issue 17448. With StrictErrors enabled, temporary errors should make
+// LookupIP fail rather than return a partial result.
+func TestStrictErrorsLookupIP(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ confData := []string{
+ "nameserver 192.0.2.53",
+ "search x.golang.org y.golang.org",
+ }
+ if err := conf.writeAndUpdate(confData); err != nil {
+ t.Fatal(err)
+ }
+
+ const name = "test-issue19592"
+ const server = "192.0.2.53:53"
+ const searchX = "test-issue19592.x.golang.org."
+ const searchY = "test-issue19592.y.golang.org."
+ const ip4 = "192.0.2.1"
+ const ip6 = "2001:db8::1"
+
+ type resolveWhichEnum int
+ const (
+ resolveOK resolveWhichEnum = iota
+ resolveOpError
+ resolveServfail
+ resolveTimeout
+ )
+
+ makeTempError := func(err string) error {
+ return &DNSError{
+ Err: err,
+ Name: name,
+ Server: server,
+ IsTemporary: true,
+ }
+ }
+ makeTimeout := func() error {
+ return &DNSError{
+ Err: os.ErrDeadlineExceeded.Error(),
+ Name: name,
+ Server: server,
+ IsTimeout: true,
+ }
+ }
+ makeNxDomain := func() error {
+ return &DNSError{
+ Err: errNoSuchHost.Error(),
+ Name: name,
+ Server: server,
+ IsNotFound: true,
+ }
+ }
+
+ cases := []struct {
+ desc string
+ resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
+ wantStrictErr error
+ wantLaxErr error
+ wantIPs []string
+ }{
+ {
+ desc: "No errors",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ return resolveOK
+ },
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX error fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX IPv4-only timeout fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX IPv6-only servfail fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
+ return resolveServfail
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTempError("server misbehaving"),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchY error always fails",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantLaxErr: makeNxDomain(), // This one reaches the "test." FQDN.
+ },
+ {
+ desc: "searchY IPv4-only socket error fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
+ return resolveOpError
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTempError("write: socket on fire"),
+ wantIPs: []string{ip6},
+ },
+ {
+ desc: "searchY IPv6-only timeout fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4},
+ },
+ }
+
+ for i, tt := range cases {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+
+ switch tt.resolveWhich(q.Questions[0]) {
+ case resolveOK:
+ // Handle below.
+ case resolveOpError:
+ return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
+ case resolveServfail:
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ Questions: q.Questions,
+ }, nil
+ case resolveTimeout:
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ default:
+ t.Fatal("Impossible resolveWhich")
+ }
+
+ switch q.Questions[0].Name.String() {
+ case searchX, name + ".":
+ // Return NXDOMAIN to utilize the search list.
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ },
+ Questions: q.Questions,
+ }, nil
+ case searchY:
+ // Return records below.
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
+ }
+
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ case dnsmessage.TypeAAAA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ },
+ }
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
+ }
+ return r, nil
+ }}
+
+ for _, strict := range []bool{true, false} {
+ r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
+ ips, err := r.LookupIPAddr(context.Background(), name)
+
+ var wantErr error
+ if strict {
+ wantErr = tt.wantStrictErr
+ } else {
+ wantErr = tt.wantLaxErr
+ }
+ if !reflect.DeepEqual(err, wantErr) {
+ t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
+ }
+
+ gotIPs := map[string]struct{}{}
+ for _, ip := range ips {
+ gotIPs[ip.String()] = struct{}{}
+ }
+ wantIPs := map[string]struct{}{}
+ if wantErr == nil {
+ for _, ip := range tt.wantIPs {
+ wantIPs[ip] = struct{}{}
+ }
+ }
+ if !reflect.DeepEqual(gotIPs, wantIPs) {
+ t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
+ }
+ }
+ }
+}
+
+// Issue 17448. With StrictErrors enabled, temporary errors should make
+// LookupTXT stop walking the search list.
+func TestStrictErrorsLookupTXT(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ confData := []string{
+ "nameserver 192.0.2.53",
+ "search x.golang.org y.golang.org",
+ }
+ if err := conf.writeAndUpdate(confData); err != nil {
+ t.Fatal(err)
+ }
+
+ const name = "test"
+ const server = "192.0.2.53:53"
+ const searchX = "test.x.golang.org."
+ const searchY = "test.y.golang.org."
+ const txt = "Hello World"
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+
+ switch q.Questions[0].Name.String() {
+ case searchX:
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ case searchY:
+ return mockTXTResponse(q), nil
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
+ }
+ }}
+
+ for _, strict := range []bool{true, false} {
+ r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
+ p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT, nil)
+ var wantErr error
+ var wantRRs int
+ if strict {
+ wantErr = &DNSError{
+ Err: os.ErrDeadlineExceeded.Error(),
+ Name: name,
+ Server: server,
+ IsTimeout: true,
+ }
+ } else {
+ wantRRs = 1
+ }
+ if !reflect.DeepEqual(err, wantErr) {
+ t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
+ }
+ a, err := p.AllAnswers()
+ if err != nil {
+ a = nil
+ }
+ if len(a) != wantRRs {
+ t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
+ }
+ }
+}
+
+// Test for a race between uninstalling the test hooks and closing a
+// socket connection. This used to fail when testing with -race.
+func TestDNSGoroutineRace(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
+ time.Sleep(10 * time.Microsecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // The timeout here is less than the timeout used by the server,
+ // so the goroutine started to query the (fake) server will hang
+ // around after this test is done if we don't call dnsWaitGroup.Wait.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
+ defer cancel()
+ _, err := r.LookupIPAddr(ctx, "where.are.they.now")
+ if err == nil {
+ t.Fatal("fake DNS lookup unexpectedly succeeded")
+ }
+}
+
+func lookupWithFake(fake fakeDNSServer, name string, typ dnsmessage.Type) error {
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf := getSystemDNSConfig()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, _, err := r.tryOneName(ctx, conf, name, typ)
+ return err
+}
+
+// Issue 8434: verify that Temporary returns true on an error when rcode
+// is SERVFAIL
+func TestIssue8434(t *testing.T) {
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ }, "golang.org.", dnsmessage.TypeALL)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ } else if !ne.Temporary() {
+ t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsTemporary {
+ t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
+ }
+}
+
+func TestIssueNoSuchHostExists(t *testing.T) {
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ }, "golang.org.", dnsmessage.TypeALL)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if _, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsNotFound {
+ t.Fatalf("IsNotFound = false for err = %#v; want IsNotFound == true", err)
+ }
+}
+
+// TestNoSuchHost verifies that tryOneName works correctly when the domain does
+// not exist.
+//
+// Issue 12778: verify that NXDOMAIN without RA bit errors as "no such host"
+// and not "server misbehaving"
+//
+// Issue 25336: verify that NXDOMAIN errors fail fast.
+//
+// Issue 27525: verify that empty answers fail fast.
+func TestNoSuchHost(t *testing.T) {
+ tests := []struct {
+ name string
+ f func(string, string, dnsmessage.Message, time.Time) (dnsmessage.Message, error)
+ }{
+ {
+ "NXDOMAIN",
+ func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ RecursionAvailable: false,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ },
+ {
+ "no answers",
+ func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ RecursionAvailable: false,
+ Authoritative: true,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ lookups := 0
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, s string, q dnsmessage.Message, d time.Time) (dnsmessage.Message, error) {
+ lookups++
+ return test.f(n, s, q, d)
+ },
+ }, ".", dnsmessage.TypeALL)
+
+ if lookups != 1 {
+ t.Errorf("got %d lookups, wanted 1", lookups)
+ }
+
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ de, ok := err.(*DNSError)
+ if !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ }
+ if de.Err != errNoSuchHost.Error() {
+ t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
+ }
+ if !de.IsNotFound {
+ t.Fatalf("IsNotFound = %v wanted true", de.IsNotFound)
+ }
+ })
+ }
+}
+
+// Issue 26573: verify that Conns that don't implement PacketConn are treated
+// as streams even when udp was requested.
+func TestDNSDialTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ },
+ alwaysTCP: true,
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx := context.Background()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP, false)
+ if err != nil {
+ t.Fatal("exchange failed:", err)
+ }
+}
+
+// Issue 27763: verify that two strings in one TXT record are concatenated.
+func TestTXTRecordTwoStrings(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"string1 ", "string2"},
+ },
+ },
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"onestring"},
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ txt, err := r.lookupTXT(context.Background(), "golang.org")
+ if err != nil {
+ t.Fatal("LookupTXT failed:", err)
+ }
+ if want := 2; len(txt) != want {
+ t.Fatalf("len(txt), got %d, want %d", len(txt), want)
+ }
+ if want := "string1 string2"; txt[0] != want {
+ t.Errorf("txt[0], got %q, want %q", txt[0], want)
+ }
+ if want := "onestring"; txt[1] != want {
+ t.Errorf("txt[1], got %q, want %q", txt[1], want)
+ }
+}
+
+// Issue 29644: support single-request resolv.conf option in pure Go resolver.
+// The A and AAAA queries will be sent sequentially, not in parallel.
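+// The handler below delays the A answer for slowipv4.example.net and records
+// which query type arrived first; even then, the A query must be the first
+// one the server observes.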
+func TestSingleRequestLookup(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ var (
+ firstcalled int32
+ ipv4 int32 = 1
+ ipv6 int32 = 2
+ )
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ if question.Name.String() == "slowipv4.example.net." {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if !atomic.CompareAndSwapInt32(&firstcalled, 0, ipv4) {
+ t.Errorf("the A query was received after the AAAA query")
+ }
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ })
+ case dnsmessage.TypeAAAA:
+ atomic.CompareAndSwapInt32(&firstcalled, 0, ipv6)
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ })
+ }
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+ if err := conf.writeAndUpdate([]string{"options single-request"}); err != nil {
+ t.Fatal(err)
+ }
+ for _, name := range []string{"hostname.example.net", "slowipv4.example.net"} {
+ firstcalled = 0
+ _, err := r.LookupIPAddr(context.Background(), name)
+ if err != nil {
+ t.Error(err)
+ }
+ }
+}
+
+// Issue 29358. Add configuration knob to force TCP-only DNS requests in the pure Go resolver.
+func TestDNSUseTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if n == "udp" {
+ t.Fatal("udp protocol was used instead of tcp")
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly, false)
+ if err != nil {
+ t.Fatal("exchange failed:", err)
+ }
+}
+
+// Issue 34660: PTR responses that also contain non-PTR answers should have the non-PTR records ignored.
+func TestPTRandNonPTR(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("golang.org."),
+ },
+ },
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"PTR 8 6 60 ..."}, // fake RRSIG
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ names, err := r.lookupAddr(context.Background(), "192.0.2.123")
+ if err != nil {
+ t.Fatalf("LookupAddr: %v", err)
+ }
+ if want := []string{"golang.org."}; !reflect.DeepEqual(names, want) {
+ t.Errorf("names = %q; want %q", names, want)
+ }
+}
+
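+// TestCVE202133195 covers CVE-2021-33195: answers whose names are not valid
+// domain names (here "<html>.golang.org.") must be dropped from the results
+// and reported via errMalformedDNSRecordsDetail.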
+func TestCVE202133195(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ RecursionAvailable: true,
+ },
+ Questions: q.Questions,
+ }
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeCNAME:
+ r.Answers = []dnsmessage.Resource{}
+ case dnsmessage.TypeA: // CNAME lookup uses an A/AAAA query as a proxy
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ )
+ case dnsmessage.TypeSRV:
+ n := q.Questions[0].Name
+ if n.String() == "_hdr._tcp.golang.org." {
+ n = dnsmessage.MustNewName("<html>.golang.org.")
+ }
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: n,
+ Type: dnsmessage.TypeSRV,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.SRVResource{
+ Target: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: n,
+ Type: dnsmessage.TypeSRV,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.SRVResource{
+ Target: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypeMX:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypeNS:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypePTR:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ }
+ return r, nil
+ },
+ }
+
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ // Change the default resolver to match our manipulated resolver
+ originalDefault := DefaultResolver
+ DefaultResolver = &r
+ defer func() { DefaultResolver = originalDefault }()
+ // Redirect host file lookups.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/hosts"
+
+ tests := []struct {
+ name string
+ f func(*testing.T)
+ }{
+ {
+ name: "CNAME",
+ f: func(t *testing.T) {
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ _, err := r.LookupCNAME(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ _, err = LookupCNAME("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ },
+ },
+ {
+ name: "SRV (bad record)",
+ f: func(t *testing.T) {
+ expected := []*SRV{
+ {
+ Target: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ _, records, err = LookupSRV("target", "tcp", "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "SRV (bad header)",
+ f: func(t *testing.T) {
+ _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
+ if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+ t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
+ }
+ _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
+ if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+ t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
+ }
+ },
+ },
+ {
+ name: "MX",
+ f: func(t *testing.T) {
+ expected := []*MX{
+ {
+ Host: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ records, err := r.LookupMX(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupMX("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "NS",
+ f: func(t *testing.T) {
+ expected := []*NS{
+ {
+ Host: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ records, err := r.LookupNS(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupNS("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "Addr",
+ f: func(t *testing.T) {
+ expected := []string{"good.golang.org."}
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
+ records, err := r.LookupAddr(context.Background(), "192.0.2.42")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupAddr("192.0.2.42")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, tc.f)
+ }
+
+}
+
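+// TestNullMX verifies that a null MX record (target ".", as described by
+// RFC 7505) is returned as-is rather than rejected.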
+func TestNullMX(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("."),
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ rrset, err := r.LookupMX(context.Background(), "golang.org")
+ if err != nil {
+ t.Fatalf("LookupMX: %v", err)
+ }
+ if want := []*MX{{Host: "."}}; !reflect.DeepEqual(rrset, want) {
+ records := []string{}
+ for _, rr := range rrset {
+ records = append(records, fmt.Sprintf("%v", rr))
+ }
+ t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
+ }
+}
+
+func TestRootNS(t *testing.T) {
+ // See https://golang.org/issue/45715.
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("i.root-servers.net."),
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ rrset, err := r.LookupNS(context.Background(), ".")
+ if err != nil {
+ t.Fatalf("LookupNS: %v", err)
+ }
+ if want := []*NS{{Host: "i.root-servers.net."}}; !reflect.DeepEqual(rrset, want) {
+ records := []string{}
+ for _, rr := range rrset {
+ records = append(records, fmt.Sprintf("%v", rr))
+ }
+ t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
+ }
+}
+
+func TestGoLookupIPCNAMEOrderHostsAliasesFilesOnlyMode(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/aliases"
+ mode := hostLookupFiles
+
+ for _, v := range lookupStaticHostAliasesTest {
+ testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+ }
+}
+
+func TestGoLookupIPCNAMEOrderHostsAliasesFilesDNSMode(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/aliases"
+ mode := hostLookupFilesDNS
+
+ for _, v := range lookupStaticHostAliasesTest {
+ testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+ }
+}
+
+var goLookupIPCNAMEOrderDNSFilesModeTests = []struct {
+ lookup, res string
+}{
+ // 127.0.1.1
+ {"invalid.invalid", "invalid.test"},
+}
+
+func TestGoLookupIPCNAMEOrderHostsAliasesDNSFilesMode(t *testing.T) {
+ if testenv.Builder() == "" {
+ t.Skip("Makes assumptions about local networks and (re)naming that aren't always true")
+ }
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/aliases"
+ mode := hostLookupDNSFiles
+
+ for _, v := range goLookupIPCNAMEOrderDNSFilesModeTests {
+ testGoLookupIPCNAMEOrderHostsAliases(t, mode, v.lookup, absDomainName(v.res))
+ }
+}
+
+func testGoLookupIPCNAMEOrderHostsAliases(t *testing.T, mode hostLookupOrder, lookup, lookupRes string) {
+ ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)}
+ for _, in := range ins {
+ _, res, err := goResolver.goLookupIPCNAMEOrder(context.Background(), "ip", in, mode, nil)
+ if err != nil {
+ t.Errorf("expected err == nil, but got error: %v", err)
+ }
+ if res.String() != lookupRes {
+ t.Errorf("goLookupIPCNAMEOrder(%v): got %v, want %v", in, res, lookupRes)
+ }
+ }
+}
+
+// Test that we advertise support for a larger DNS packet size.
+// This isn't a great test as it just tests the dnsmessage package
+// against itself.
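+//
+// The advertised size travels in the class field of the EDNS(0) OPT
+// pseudo-record in the Additionals section, which the handler below reads
+// back as q.Additionals[0].Header.Class.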
+func TestDNSPacketSize(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ if len(q.Additionals) == 0 {
+ t.Error("missing EDNS record")
+ } else if opt, ok := q.Additionals[0].Body.(*dnsmessage.OPTResource); !ok {
+ t.Errorf("additional record type %T, expected OPTResource", q.Additionals[0].Body)
+ } else if len(opt.Options) != 0 {
+ t.Errorf("found %d Options, expected none", len(opt.Options))
+ } else {
+ got := int(q.Additionals[0].Header.Class)
+ t.Logf("EDNS packet size == %d", got)
+ if got != maxDNSPacketSize {
+ t.Errorf("EDNS packet size == %d, want %d", got, maxDNSPacketSize)
+ }
+ }
+
+ // Hand back a dummy answer to verify that
+ // LookupIPAddr completes.
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+ return r, nil
+ },
+ }
+
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+ if _, err := r.LookupIPAddr(context.Background(), "go.dev"); err != nil {
+ t.Errorf("lookup failed: %v", err)
+ }
+}
+
+func TestLongDNSNames(t *testing.T) {
+ const longDNSsuffix = ".go.dev."
+ const longDNSsuffixNoEndingDot = ".go.dev"
+
+ var longDNSPrefix = strings.Repeat("verylongdomainlabel.", 20)
+
+ var longDNSNamesTests = []struct {
+ req string
+ fail bool
+ }{
+ {req: longDNSPrefix[:255-len(longDNSsuffix)] + longDNSsuffix, fail: true},
+ {req: longDNSPrefix[:254-len(longDNSsuffix)] + longDNSsuffix},
+ {req: longDNSPrefix[:253-len(longDNSsuffix)] + longDNSsuffix},
+
+ {req: longDNSPrefix[:253-len(longDNSsuffixNoEndingDot)] + longDNSsuffixNoEndingDot},
+ {req: longDNSPrefix[:254-len(longDNSsuffixNoEndingDot)] + longDNSsuffixNoEndingDot, fail: true},
+ }
+
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: q.Questions[0].Type,
+ Class: dnsmessage.ClassINET,
+ },
+ },
+ },
+ }
+
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeA:
+ r.Answers[0].Body = &dnsmessage.AResource{A: TestAddr}
+ case dnsmessage.TypeAAAA:
+ r.Answers[0].Body = &dnsmessage.AAAAResource{AAAA: TestAddr6}
+ case dnsmessage.TypeTXT:
+ r.Answers[0].Body = &dnsmessage.TXTResource{TXT: []string{"."}}
+ case dnsmessage.TypeMX:
+ r.Answers[0].Body = &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("go.dev."),
+ }
+ case dnsmessage.TypeNS:
+ r.Answers[0].Body = &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("go.dev."),
+ }
+ case dnsmessage.TypeSRV:
+ r.Answers[0].Body = &dnsmessage.SRVResource{
+ Target: dnsmessage.MustNewName("go.dev."),
+ }
+ case dnsmessage.TypeCNAME:
+ r.Answers[0].Body = &dnsmessage.CNAMEResource{
+ CNAME: dnsmessage.MustNewName("fake.cname."),
+ }
+ default:
+ panic("unknown dnsmessage type")
+ }
+
+ return r, nil
+ },
+ }
+
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ methodTests := []string{"CNAME", "Host", "IP", "IPAddr", "MX", "NS", "NetIP", "SRV", "TXT"}
+ query := func(t string, req string) error {
+ switch t {
+ case "CNAME":
+ _, err := r.LookupCNAME(context.Background(), req)
+ return err
+ case "Host":
+ _, err := r.LookupHost(context.Background(), req)
+ return err
+ case "IP":
+ _, err := r.LookupIP(context.Background(), "ip", req)
+ return err
+ case "IPAddr":
+ _, err := r.LookupIPAddr(context.Background(), req)
+ return err
+ case "MX":
+ _, err := r.LookupMX(context.Background(), req)
+ return err
+ case "NS":
+ _, err := r.LookupNS(context.Background(), req)
+ return err
+ case "NetIP":
+ _, err := r.LookupNetIP(context.Background(), "ip", req)
+ return err
+ case "SRV":
+ const service = "service"
+ const proto = "proto"
+ req = req[len(service)+len(proto)+4:]
+ _, _, err := r.LookupSRV(context.Background(), service, proto, req)
+ return err
+ case "TXT":
+ _, err := r.LookupTXT(context.Background(), req)
+ return err
+ }
+ panic("unknown query method")
+ }
+
+ for i, v := range longDNSNamesTests {
+ for _, testName := range methodTests {
+ err := query(testName, v.req)
+ if v.fail {
+ if err == nil {
+ t.Errorf("%v: Lookup%v: unexpected success", i, testName)
+ break
+ }
+
+ expectedErr := DNSError{Err: errNoSuchHost.Error(), Name: v.req, IsNotFound: true}
+ var dnsErr *DNSError
+ errors.As(err, &dnsErr)
+ if dnsErr == nil || *dnsErr != expectedErr {
+ t.Errorf("%v: Lookup%v: unexpected error: %v", i, testName, err)
+ }
+ break
+ }
+ if err != nil {
+ t.Errorf("%v: Lookup%v: unexpected error: %v", i, testName, err)
+ }
+ }
+ }
+}
+
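+// TestDNSTrustAD verifies that outgoing queries carry the AD (authentic data)
+// bit only when resolv.conf contains "options trust-ad".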
+func TestDNSTrustAD(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ if q.Questions[0].Name.String() == "notrustad.go.dev." && q.Header.AuthenticData {
+ t.Error("unexpected AD bit")
+ }
+
+ if q.Questions[0].Name.String() == "trustad.go.dev." && !q.Header.AuthenticData {
+ t.Error("expected AD bit")
+ }
+
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+
+ return r, nil
+ }}
+
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ err = conf.writeAndUpdate([]string{"nameserver 127.0.0.1"})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := r.LookupIPAddr(context.Background(), "notrustad.go.dev"); err != nil {
+ t.Errorf("lookup failed: %v", err)
+ }
+
+ err = conf.writeAndUpdate([]string{"nameserver 127.0.0.1", "options trust-ad"})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := r.LookupIPAddr(context.Background(), "trustad.go.dev"); err != nil {
+ t.Errorf("lookup failed: %v", err)
+ }
+}
+
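+// TestDNSConfigNoReload verifies that with "options no-reload" the resolver
+// keeps using the configuration it already loaded, even after resolv.conf
+// changes on disk.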
+func TestDNSConfigNoReload(t *testing.T) {
+ r := &Resolver{PreferGo: true, Dial: func(ctx context.Context, network, address string) (Conn, error) {
+ if address != "192.0.2.1:53" {
+ return nil, errors.New("configuration unexpectedly changed")
+ }
+ return fakeDNSServerSuccessful.DialContext(ctx, network, address)
+ }}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ err = conf.writeAndUpdateWithLastCheckedTime([]string{"nameserver 192.0.2.1", "options no-reload"}, time.Now().Add(-time.Hour))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err = r.LookupHost(context.Background(), "go.dev"); err != nil {
+ t.Fatal(err)
+ }
+
+ err = conf.write([]string{"nameserver 192.0.2.200"})
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err = r.LookupHost(context.Background(), "go.dev"); err != nil {
+ t.Fatal(err)
+ }
+}
+
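+// TestLookupOrderFilesNoSuchHost verifies that, in a files-only lookup order
+// backed by an empty hosts file, the lookup functions return a *DNSError
+// whose Err is errNoSuchHost and whose IsNotFound field is set.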
+func TestLookupOrderFilesNoSuchHost(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ if runtime.GOOS != "openbsd" {
+ defer setSystemNSS(getSystemNSS(), 0)
+ setSystemNSS(nssStr(t, "hosts: files"), time.Hour)
+ }
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ resolvConf := dnsConfig{servers: defaultNS}
+ if runtime.GOOS == "openbsd" {
+ // Set the error to ErrNotExist so that hostLookupOrder
+ // returns hostLookupFiles on OpenBSD.
+ resolvConf.err = os.ErrNotExist
+ }
+
+ if !conf.forceUpdateConf(&resolvConf, time.Now().Add(time.Hour)) {
+ t.Fatal("failed to update resolv config")
+ }
+
+ tmpFile := filepath.Join(t.TempDir(), "hosts")
+ if err := os.WriteFile(tmpFile, []byte{}, 0660); err != nil {
+ t.Fatal(err)
+ }
+ testHookHostsPath = tmpFile
+
+ const testName = "test.invalid"
+
+ order, _ := systemConf().hostLookupOrder(DefaultResolver, testName)
+ if order != hostLookupFiles {
+ // Skip the test on systems that do not return hostLookupFiles.
+ t.Skipf("hostLookupOrder did not return hostLookupFiles")
+ }
+
+ var lookupTests = []struct {
+ name string
+ lookup func(name string) error
+ }{
+ {
+ name: "Host",
+ lookup: func(name string) error {
+ _, err = DefaultResolver.LookupHost(context.Background(), name)
+ return err
+ },
+ },
+ {
+ name: "IP",
+ lookup: func(name string) error {
+ _, err = DefaultResolver.LookupIP(context.Background(), "ip", name)
+ return err
+ },
+ },
+ {
+ name: "IPAddr",
+ lookup: func(name string) error {
+ _, err = DefaultResolver.LookupIPAddr(context.Background(), name)
+ return err
+ },
+ },
+ {
+ name: "NetIP",
+ lookup: func(name string) error {
+ _, err = DefaultResolver.LookupNetIP(context.Background(), "ip", name)
+ return err
+ },
+ },
+ }
+
+ for _, v := range lookupTests {
+ err := v.lookup(testName)
+
+ if err == nil {
+ t.Errorf("Lookup%v: unexpected success", v.name)
+ continue
+ }
+
+ expectedErr := DNSError{Err: errNoSuchHost.Error(), Name: testName, IsNotFound: true}
+ var dnsErr *DNSError
+ errors.As(err, &dnsErr)
+ if dnsErr == nil || *dnsErr != expectedErr {
+ t.Errorf("Lookup%v: unexpected error: %v", v.name, err)
+ }
+ }
+}
diff --git a/src/net/dnsconfig.go b/src/net/dnsconfig.go
new file mode 100644
index 0000000..c86a70b
--- /dev/null
+++ b/src/net/dnsconfig.go
@@ -0,0 +1,45 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "sync/atomic"
+ "time"
+)
+
+var (
+ defaultNS = []string{"127.0.0.1:53", "[::1]:53"}
+ getHostname = os.Hostname // variable for testing
+)
+
+type dnsConfig struct {
+ servers []string // server addresses (in host:port form) to use
+ search []string // rooted suffixes to append to local name
+ ndots int // number of dots in name to trigger absolute lookup
+ timeout time.Duration // wait before giving up on a query, including retries
+ attempts int // lost packets before giving up on server
+ rotate bool // round robin among servers
+ unknownOpt bool // anything unknown was encountered
+ lookup []string // OpenBSD top-level database "lookup" order
+ err error // any error that occurs during open of resolv.conf
+ mtime time.Time // time of resolv.conf modification
+ soffset uint32 // used by serverOffset
+ singleRequest bool // use sequential A and AAAA queries instead of parallel queries
+ useTCP bool // force usage of TCP for DNS resolutions
+ trustAD bool // add AD flag to queries
+ noReload bool // do not check for config file updates
+}
+
+// serverOffset returns an offset that can be used to determine
+// indices of servers in c.servers when making queries.
+// When the rotate option is enabled, this offset increases.
+// Otherwise it is always 0.
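+// Callers are expected to reduce the offset modulo the number of servers, so
+// unsigned wraparound of the counter is harmless.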
+func (c *dnsConfig) serverOffset() uint32 {
+ if c.rotate {
+ return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
+ }
+ return 0
+}
diff --git a/src/net/dnsconfig_unix.go b/src/net/dnsconfig_unix.go
new file mode 100644
index 0000000..69b3004
--- /dev/null
+++ b/src/net/dnsconfig_unix.go
@@ -0,0 +1,167 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !windows
+
+// Read system DNS config from /etc/resolv.conf
+
+package net
+
+import (
+ "internal/bytealg"
+ "net/netip"
+ "time"
+)
+
+// See resolv.conf(5) on a Linux machine.
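+//
+// An illustrative (not exhaustive) resolv.conf handled by this parser:
+//
+//  nameserver 8.8.8.8
+//  search example.com
+//  options ndots:2 timeout:3 attempts:2 rotate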
+func dnsReadConfig(filename string) *dnsConfig {
+ conf := &dnsConfig{
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ }
+ file, err := open(filename)
+ if err != nil {
+ conf.servers = defaultNS
+ conf.search = dnsDefaultSearch()
+ conf.err = err
+ return conf
+ }
+ defer file.close()
+ if fi, err := file.file.Stat(); err == nil {
+ conf.mtime = fi.ModTime()
+ } else {
+ conf.servers = defaultNS
+ conf.search = dnsDefaultSearch()
+ conf.err = err
+ return conf
+ }
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ if len(line) > 0 && (line[0] == ';' || line[0] == '#') {
+ // comment.
+ continue
+ }
+ f := getFields(line)
+ if len(f) < 1 {
+ continue
+ }
+ switch f[0] {
+ case "nameserver": // add one name server
+ if len(f) > 1 && len(conf.servers) < 3 { // small, but the standard limit
+ // One more check: make sure server name is
+ // just an IP address. Otherwise we need DNS
+ // to look it up.
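+ // netip.ParseAddr also accepts zoned IPv6 literals such as
+ // fe80::1%lo0, which JoinHostPort then wraps in brackets.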
+ if _, err := netip.ParseAddr(f[1]); err == nil {
+ conf.servers = append(conf.servers, JoinHostPort(f[1], "53"))
+ }
+ }
+
+ case "domain": // set search path to just this domain
+ if len(f) > 1 {
+ conf.search = []string{ensureRooted(f[1])}
+ }
+
+ case "search": // set search path to given servers
+ conf.search = make([]string, 0, len(f)-1)
+ for i := 1; i < len(f); i++ {
+ name := ensureRooted(f[i])
+ if name == "." {
+ continue
+ }
+ conf.search = append(conf.search, name)
+ }
+
+ case "options": // magic options
+ for _, s := range f[1:] {
+ switch {
+ case hasPrefix(s, "ndots:"):
+ n, _, _ := dtoi(s[6:])
+ if n < 0 {
+ n = 0
+ } else if n > 15 {
+ n = 15
+ }
+ conf.ndots = n
+ case hasPrefix(s, "timeout:"):
+ n, _, _ := dtoi(s[8:])
+ if n < 1 {
+ n = 1
+ }
+ conf.timeout = time.Duration(n) * time.Second
+ case hasPrefix(s, "attempts:"):
+ n, _, _ := dtoi(s[9:])
+ if n < 1 {
+ n = 1
+ }
+ conf.attempts = n
+ case s == "rotate":
+ conf.rotate = true
+ case s == "single-request" || s == "single-request-reopen":
+ // Linux option:
+ // http://man7.org/linux/man-pages/man5/resolv.conf.5.html
+ // "By default, glibc performs IPv4 and IPv6 lookups in parallel [...]
+ // This option disables the behavior and makes glibc
+ // perform the IPv6 and IPv4 requests sequentially."
+ conf.singleRequest = true
+ case s == "use-vc" || s == "usevc" || s == "tcp":
+ // Linux (use-vc), FreeBSD (usevc) and OpenBSD (tcp) option:
+ // http://man7.org/linux/man-pages/man5/resolv.conf.5.html
+ // "Sets RES_USEVC in _res.options.
+ // This option forces the use of TCP for DNS resolutions."
+ // https://www.freebsd.org/cgi/man.cgi?query=resolv.conf&sektion=5&manpath=freebsd-release-ports
+ // https://man.openbsd.org/resolv.conf.5
+ conf.useTCP = true
+ case s == "trust-ad":
+ conf.trustAD = true
+ case s == "edns0":
+ // We use EDNS by default.
+ // Ignore this option.
+ case s == "no-reload":
+ conf.noReload = true
+ default:
+ conf.unknownOpt = true
+ }
+ }
+
+ case "lookup":
+ // OpenBSD option:
+ // https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
+ // "the legal space-separated values are: bind, file, yp"
+ conf.lookup = f[1:]
+
+ default:
+ conf.unknownOpt = true
+ }
+ }
+ if len(conf.servers) == 0 {
+ conf.servers = defaultNS
+ }
+ if len(conf.search) == 0 {
+ conf.search = dnsDefaultSearch()
+ }
+ return conf
+}
+
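+// dnsDefaultSearch derives a default search list from the host name: the part
+// after the first dot, rooted with a trailing dot. It returns nil when the
+// hostname cannot be read or has no domain part.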
+func dnsDefaultSearch() []string {
+ hn, err := getHostname()
+ if err != nil {
+ // best effort
+ return nil
+ }
+ if i := bytealg.IndexByteString(hn, '.'); i >= 0 && i < len(hn)-1 {
+ return []string{ensureRooted(hn[i+1:])}
+ }
+ return nil
+}
+
+func hasPrefix(s, prefix string) bool {
+ return len(s) >= len(prefix) && s[:len(prefix)] == prefix
+}
+
+func ensureRooted(s string) string {
+ if len(s) > 0 && s[len(s)-1] == '.' {
+ return s
+ }
+ return s + "."
+}
diff --git a/src/net/dnsconfig_unix_test.go b/src/net/dnsconfig_unix_test.go
new file mode 100644
index 0000000..0aae2ba
--- /dev/null
+++ b/src/net/dnsconfig_unix_test.go
@@ -0,0 +1,314 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "errors"
+ "io/fs"
+ "os"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+)
+
+var dnsReadConfigTests = []struct {
+ name string
+ want *dnsConfig
+}{
+ {
+ name: "testdata/resolv.conf",
+ want: &dnsConfig{
+ servers: []string{"8.8.8.8:53", "[2001:4860:4860::8888]:53", "[fe80::1%lo0]:53"},
+ search: []string{"localdomain."},
+ ndots: 5,
+ timeout: 10 * time.Second,
+ attempts: 3,
+ rotate: true,
+ unknownOpt: true, // the "options attempts 3" line
+ },
+ },
+ {
+ name: "testdata/domain-resolv.conf",
+ want: &dnsConfig{
+ servers: []string{"8.8.8.8:53"},
+ search: []string{"localdomain."},
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ },
+ },
+ {
+ name: "testdata/search-resolv.conf",
+ want: &dnsConfig{
+ servers: []string{"8.8.8.8:53"},
+ search: []string{"test.", "invalid."},
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ },
+ },
+ {
+ name: "testdata/search-single-dot-resolv.conf",
+ want: &dnsConfig{
+ servers: []string{"8.8.8.8:53"},
+ search: []string{},
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ },
+ },
+ {
+ name: "testdata/empty-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/invalid-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 0,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/large-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 15,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/negative-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 0,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/openbsd-resolv.conf",
+ want: &dnsConfig{
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ lookup: []string{"file", "bind"},
+ servers: []string{"169.254.169.254:53", "10.240.0.1:53"},
+ search: []string{"c.symbolic-datum-552.internal."},
+ },
+ },
+ {
+ name: "testdata/single-request-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ singleRequest: true,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/single-request-reopen-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ singleRequest: true,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/linux-use-vc-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ useTCP: true,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/freebsd-usevc-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ useTCP: true,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/openbsd-tcp-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ useTCP: true,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+}
+
+func TestDNSReadConfig(t *testing.T) {
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+ getHostname = func() (string, error) { return "host.domain.local", nil }
+
+ for _, tt := range dnsReadConfigTests {
+ want := *tt.want
+ if len(want.search) == 0 {
+ want.search = dnsDefaultSearch()
+ }
+ conf := dnsReadConfig(tt.name)
+ if conf.err != nil {
+ t.Fatal(conf.err)
+ }
+ conf.mtime = time.Time{}
+ if !reflect.DeepEqual(conf, &want) {
+ t.Errorf("%s:\ngot: %+v\nwant: %+v", tt.name, conf, want)
+ }
+ }
+}
+
+func TestDNSReadMissingFile(t *testing.T) {
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+ getHostname = func() (string, error) { return "host.domain.local", nil }
+
+ conf := dnsReadConfig("a-nonexistent-file")
+ if !os.IsNotExist(conf.err) {
+ t.Errorf("missing resolv.conf:\ngot: %v\nwant: %v", conf.err, fs.ErrNotExist)
+ }
+ conf.err = nil
+ want := &dnsConfig{
+ servers: defaultNS,
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ }
+ if !reflect.DeepEqual(conf, want) {
+ t.Errorf("missing resolv.conf:\ngot: %+v\nwant: %+v", conf, want)
+ }
+}
+
+var dnsDefaultSearchTests = []struct {
+ name string
+ err error
+ want []string
+}{
+ {
+ name: "host.long.domain.local",
+ want: []string{"long.domain.local."},
+ },
+ {
+ name: "host.local",
+ want: []string{"local."},
+ },
+ {
+ name: "host",
+ want: nil,
+ },
+ {
+ name: "host.domain.local",
+ err: errors.New("errored"),
+ want: nil,
+ },
+ {
+ // ensures we don't return []string{""}
+ // which causes duplicate lookups
+ name: "foo.",
+ want: nil,
+ },
+}
+
+func TestDNSDefaultSearch(t *testing.T) {
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+
+ for _, tt := range dnsDefaultSearchTests {
+ getHostname = func() (string, error) { return tt.name, tt.err }
+ got := dnsDefaultSearch()
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("dnsDefaultSearch with hostname %q and error %+v = %q, wanted %q", tt.name, tt.err, got, tt.want)
+ }
+ }
+}
+
+func TestDNSNameLength(t *testing.T) {
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+ getHostname = func() (string, error) { return "host.domain.local", nil }
+
+ var char63 = ""
+ for i := 0; i < 63; i++ {
+ char63 += "a"
+ }
+ longDomain := strings.Repeat(char63+".", 5) + "example"
+
+ for _, tt := range dnsReadConfigTests {
+ conf := dnsReadConfig(tt.name)
+ if conf.err != nil {
+ t.Fatal(conf.err)
+ }
+
+ suffixList := tt.want.search
+ if len(suffixList) == 0 {
+ suffixList = dnsDefaultSearch()
+ }
+
+ var shortestSuffix int
+ for _, suffix := range suffixList {
+ if shortestSuffix == 0 || len(suffix) < shortestSuffix {
+ shortestSuffix = len(suffix)
+ }
+ }
+
+ // Test a name that will be maximally long when prefixing the shortest
+ // suffix (accounting for the intervening dot).
+ longName := longDomain[len(longDomain)-254+1+shortestSuffix:]
+ if longName[0] == '.' || longName[1] == '.' {
+ longName = "aa." + longName[3:]
+ }
+ for _, fqdn := range conf.nameList(longName) {
+ if len(fqdn) > 254 {
+ t.Errorf("got %d; want less than or equal to 254", len(fqdn))
+ }
+ }
+
+ // Now test a name that's too long for suffixing.
+ unsuffixable := "a." + longName[1:]
+ unsuffixableResults := conf.nameList(unsuffixable)
+ if len(unsuffixableResults) != 1 {
+ t.Errorf("suffixed names %v; want []", unsuffixableResults[1:])
+ }
+
+ // Now test a name that's too long for DNS.
+ tooLong := "a." + longDomain
+ tooLongResults := conf.nameList(tooLong)
+ if tooLongResults != nil {
+ t.Errorf("suffixed names %v; want nil", tooLongResults)
+ }
+ }
+}
diff --git a/src/net/dnsconfig_windows.go b/src/net/dnsconfig_windows.go
new file mode 100644
index 0000000..f3d2423
--- /dev/null
+++ b/src/net/dnsconfig_windows.go
@@ -0,0 +1,63 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/windows"
+ "syscall"
+ "time"
+)
+
+func dnsReadConfig(ignoredFilename string) (conf *dnsConfig) {
+ conf = &dnsConfig{
+ ndots: 1,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ }
+ defer func() {
+ if len(conf.servers) == 0 {
+ conf.servers = defaultNS
+ }
+ }()
+ aas, err := adapterAddresses()
+ if err != nil {
+ return
+ }
+ // TODO(bradfitz): this just collects all the DNS servers on all
+ // the interfaces in some random order. It should order it by
+ // default route, or only use the default route(s) instead.
+ // In practice, however, it mostly works.
+ for _, aa := range aas {
+ for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
+ // Only use DNS servers from interfaces whose OperStatus is IfOperStatusUp (0x01).
+ if aa.OperStatus != windows.IfOperStatusUp {
+ continue
+ }
+ sa, err := dns.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ continue
+ }
+ var ip IP
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ip = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
+ case *syscall.SockaddrInet6:
+ ip = make(IP, IPv6len)
+ copy(ip, sa.Addr[:])
+ if ip[0] == 0xfe && ip[1] == 0xc0 {
+ // Ignore these fec0::/10 (deprecated site-local) addresses.
+ // Windows seems to populate them as defaults on some of its
+ // miscellaneous interfaces.
+ continue
+ }
+ default:
+ // Unexpected type.
+ continue
+ }
+ conf.servers = append(conf.servers, JoinHostPort(ip.String(), "53"))
+ }
+ }
+ return conf
+}
diff --git a/src/net/dnsname_test.go b/src/net/dnsname_test.go
new file mode 100644
index 0000000..4a5f01a
--- /dev/null
+++ b/src/net/dnsname_test.go
@@ -0,0 +1,86 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "strings"
+ "testing"
+)
+
+type dnsNameTest struct {
+ name string
+ result bool
+}
+
+var dnsNameTests = []dnsNameTest{
+ // RFC 2181, section 11.
+ {"_xmpp-server._tcp.google.com", true},
+ {"foo.com", true},
+ {"1foo.com", true},
+ {"26.0.0.73.com", true},
+ {"10-0-0-1", true},
+ {"fo-o.com", true},
+ {"fo1o.com", true},
+ {"foo1.com", true},
+ {"a.b..com", false},
+ {"a.b-.com", false},
+ {"a.b.com-", false},
+ {"a.b..", false},
+ {"b.com.", true},
+}
+
+func emitDNSNameTest(ch chan<- dnsNameTest) {
+ defer close(ch)
+ var char63 = ""
+ for i := 0; i < 63; i++ {
+ char63 += "a"
+ }
+ char64 := char63 + "a"
+ longDomain := strings.Repeat(char63+".", 5) + "example"
+
+ for _, tc := range dnsNameTests {
+ ch <- tc
+ }
+
+ ch <- dnsNameTest{char63 + ".com", true}
+ ch <- dnsNameTest{char64 + ".com", false}
+
+ // Remember: wire format is two octets longer than presentation
+ // (length octets for the first and [root] last labels).
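+ // In presentation form that leaves room for 253 characters
+ // (253 + 2 = 255, the maximum wire-format length).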
+ // 253 is fine:
+ ch <- dnsNameTest{longDomain[len(longDomain)-253:], true}
+ // A terminal dot doesn't contribute to length:
+ ch <- dnsNameTest{longDomain[len(longDomain)-253:] + ".", true}
+ // 254 is bad:
+ ch <- dnsNameTest{longDomain[len(longDomain)-254:], false}
+}
+
+func TestDNSName(t *testing.T) {
+ ch := make(chan dnsNameTest)
+ go emitDNSNameTest(ch)
+ for tc := range ch {
+ if isDomainName(tc.name) != tc.result {
+ t.Errorf("isDomainName(%q) = %v; want %v", tc.name, !tc.result, tc.result)
+ }
+ }
+}
+
+func BenchmarkDNSName(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ benchmarks := append(dnsNameTests, []dnsNameTest{
+ {strings.Repeat("a", 63), true},
+ {strings.Repeat("a", 64), false},
+ }...)
+ for n := 0; n < b.N; n++ {
+ for _, tc := range benchmarks {
+ if isDomainName(tc.name) != tc.result {
+ b.Errorf("isDomainName(%q) = %v; want %v", tc.name, !tc.result, tc.result)
+ }
+ }
+ }
+}
diff --git a/src/net/error_plan9.go b/src/net/error_plan9.go
new file mode 100644
index 0000000..caad133
--- /dev/null
+++ b/src/net/error_plan9.go
@@ -0,0 +1,9 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+func isConnError(err error) bool {
+ return false
+}
diff --git a/src/net/error_plan9_test.go b/src/net/error_plan9_test.go
new file mode 100644
index 0000000..1270af1
--- /dev/null
+++ b/src/net/error_plan9_test.go
@@ -0,0 +1,23 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "syscall"
+
+var (
+ errTimedout = syscall.ETIMEDOUT
+ errOpNotSupported = syscall.EPLAN9
+
+ abortedConnRequestErrors []error
+)
+
+func isPlatformError(err error) bool {
+ _, ok := err.(syscall.ErrorString)
+ return ok
+}
+
+func isENOBUFS(err error) bool {
+ return false // ENOBUFS is Unix-specific
+}
diff --git a/src/net/error_posix.go b/src/net/error_posix.go
new file mode 100644
index 0000000..c8dc069
--- /dev/null
+++ b/src/net/error_posix.go
@@ -0,0 +1,21 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+// wrapSyscallError takes an error and a syscall name. If the error is
+// a syscall.Errno, it wraps it in an os.SyscallError using the syscall name.
+func wrapSyscallError(name string, err error) error {
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError(name, err)
+ }
+ return err
+}
diff --git a/src/net/error_posix_test.go b/src/net/error_posix_test.go
new file mode 100644
index 0000000..081176f
--- /dev/null
+++ b/src/net/error_posix_test.go
@@ -0,0 +1,34 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !plan9
+
+package net
+
+import (
+ "os"
+ "syscall"
+ "testing"
+)
+
+func TestSpuriousENOTAVAIL(t *testing.T) {
+ for _, tt := range []struct {
+ error
+ ok bool
+ }{
+ {syscall.EADDRNOTAVAIL, true},
+ {&os.SyscallError{Syscall: "syscall", Err: syscall.EADDRNOTAVAIL}, true},
+ {&OpError{Op: "op", Err: syscall.EADDRNOTAVAIL}, true},
+ {&OpError{Op: "op", Err: &os.SyscallError{Syscall: "syscall", Err: syscall.EADDRNOTAVAIL}}, true},
+
+ {syscall.EINVAL, false},
+ {&os.SyscallError{Syscall: "syscall", Err: syscall.EINVAL}, false},
+ {&OpError{Op: "op", Err: syscall.EINVAL}, false},
+ {&OpError{Op: "op", Err: &os.SyscallError{Syscall: "syscall", Err: syscall.EINVAL}}, false},
+ } {
+ if ok := spuriousENOTAVAIL(tt.error); ok != tt.ok {
+ t.Errorf("spuriousENOTAVAIL(%v) = %v; want %v", tt.error, ok, tt.ok)
+ }
+ }
+}
diff --git a/src/net/error_test.go b/src/net/error_test.go
new file mode 100644
index 0000000..4538765
--- /dev/null
+++ b/src/net/error_test.go
@@ -0,0 +1,810 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "internal/poll"
+ "io"
+ "io/fs"
+ "net/internal/socktest"
+ "os"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
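+// isValid reports whether e is a well-formed OpError: Op, Net, and Err must
+// be set, and Source/Addr, when present, must be non-nil concrete addresses
+// of a known type.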
+func (e *OpError) isValid() error {
+ if e.Op == "" {
+ return fmt.Errorf("OpError.Op is empty: %v", e)
+ }
+ if e.Net == "" {
+ return fmt.Errorf("OpError.Net is empty: %v", e)
+ }
+ for _, addr := range []Addr{e.Source, e.Addr} {
+ switch addr := addr.(type) {
+ case nil:
+ case *TCPAddr:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case *UDPAddr:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case *IPAddr:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case *IPNet:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case *UnixAddr:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case *pipeAddr:
+ if addr == nil {
+ return fmt.Errorf("OpError.Source or Addr is non-nil interface: %#v, %v", addr, e)
+ }
+ case fileAddr:
+ if addr == "" {
+ return fmt.Errorf("OpError.Source or Addr is empty: %#v, %v", addr, e)
+ }
+ default:
+ return fmt.Errorf("OpError.Source or Addr is unknown type: %T, %v", addr, e)
+ }
+ }
+ if e.Err == nil {
+ return fmt.Errorf("OpError.Err is empty: %v", e)
+ }
+ return nil
+}
+
+// parseDialError parses nestedErr and reports whether it is a valid
+// error value from the Dial and Listen functions.
+// It returns nil when nestedErr is valid.
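+//
+// The checks walk at most three levels of nesting: the *OpError itself, the
+// error it wraps (possibly an *os.SyscallError or *fs.PathError), and the
+// error inside that wrapper.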
+func parseDialError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *AddrError, *timeoutError, *DNSError, InvalidAddrError, *ParseError, *poll.DeadlineExceededError, UnknownNetworkError:
+ return nil
+ case interface{ isAddrinfoErrno() }:
+ return nil
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ case *fs.PathError: // for Plan 9
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case errCanceled, ErrClosed, errMissingAddress, errNoSuitableAddress,
+ context.DeadlineExceeded, context.Canceled:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+var dialErrorTests = []struct {
+ network, address string
+}{
+ {"foo", ""},
+ {"bar", "baz"},
+ {"datakit", "mh/astro/r70"},
+ {"tcp", ""},
+ {"tcp", "127.0.0.1:☺"},
+ {"tcp", "no-such-name:80"},
+ {"tcp", "mh/astro/r70:http"},
+
+ {"tcp", JoinHostPort("127.0.0.1", "-1")},
+ {"tcp", JoinHostPort("127.0.0.1", "123456789")},
+ {"udp", JoinHostPort("127.0.0.1", "-1")},
+ {"udp", JoinHostPort("127.0.0.1", "123456789")},
+ {"ip:icmp", "127.0.0.1"},
+
+ {"unix", "/path/to/somewhere"},
+ {"unixgram", "/path/to/somewhere"},
+ {"unixpacket", "/path/to/somewhere"},
+}
+
+func TestDialError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ return nil, &DNSError{Err: "dial error test", Name: "name", Server: "server", IsTimeout: true}
+ }
+ sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
+ return nil, errOpNotSupported
+ })
+ defer sw.Set(socktest.FilterConnect, nil)
+
+ d := Dialer{Timeout: someTimeout}
+ for i, tt := range dialErrorTests {
+ c, err := d.Dial(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("#%d: should fail; %s:%s->%s", i, c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr())
+ c.Close()
+ continue
+ }
+ if tt.network == "tcp" || tt.network == "udp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
+ continue
+ }
+ }
+ if c != nil {
+ t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c)
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
+func TestProtocolDialError(t *testing.T) {
+ switch runtime.GOOS {
+ case "solaris", "illumos":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, network := range []string{"tcp", "udp", "ip:4294967296", "unix", "unixpacket", "unixgram"} {
+ var err error
+ switch network {
+ case "tcp":
+ _, err = DialTCP(network, nil, &TCPAddr{Port: 1 << 16})
+ case "udp":
+ _, err = DialUDP(network, nil, &UDPAddr{Port: 1 << 16})
+ case "ip:4294967296":
+ _, err = DialIP(network, nil, nil)
+ case "unix", "unixpacket", "unixgram":
+ _, err = DialUnix(network, nil, &UnixAddr{Name: "//"})
+ }
+ if err == nil {
+ t.Errorf("%s: should fail", network)
+ continue
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("%s: %v", network, err)
+ continue
+ }
+ }
+}
+
+func TestDialAddrError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ for _, tt := range []struct {
+ network string
+ lit string
+ addr *TCPAddr
+ }{
+ {"tcp4", "::1", nil},
+ {"tcp4", "", &TCPAddr{IP: IPv6loopback}},
+ // We don't test the {"tcp6", "byte sequence", nil}
+ // case for now because there is no easy way to
+ // control name resolution.
+ {"tcp6", "", &TCPAddr{IP: IP{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}}},
+ } {
+ var err error
+ var c Conn
+ var op string
+ if tt.lit != "" {
+ c, err = Dial(tt.network, JoinHostPort(tt.lit, "0"))
+ op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0"))
+ } else {
+ c, err = DialTCP(tt.network, nil, tt.addr)
+ op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr)
+ }
+ if err == nil {
+ c.Close()
+ t.Errorf("%s succeeded, want error", op)
+ continue
+ }
+ if perr := parseDialError(err); perr != nil {
+ t.Errorf("%s: %v", op, perr)
+ continue
+ }
+ operr := err.(*OpError).Err
+ aerr, ok := operr.(*AddrError)
+ if !ok {
+ t.Errorf("%s: %v is %T, want *AddrError", op, err, operr)
+ continue
+ }
+ want := tt.lit
+ if tt.lit == "" {
+ want = tt.addr.IP.String()
+ }
+ if aerr.Addr != want {
+ t.Errorf("%s: %v, error Addr=%q, want %q", op, err, aerr.Addr, want)
+ }
+ }
+}
+
+var listenErrorTests = []struct {
+ network, address string
+}{
+ {"foo", ""},
+ {"bar", "baz"},
+ {"datakit", "mh/astro/r70"},
+ {"tcp", "127.0.0.1:☺"},
+ {"tcp", "no-such-name:80"},
+ {"tcp", "mh/astro/r70:http"},
+
+ {"tcp", JoinHostPort("127.0.0.1", "-1")},
+ {"tcp", JoinHostPort("127.0.0.1", "123456789")},
+
+ {"unix", "/path/to/somewhere"},
+ {"unixpacket", "/path/to/somewhere"},
+}
+
+func TestListenError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
+ }
+ sw.Set(socktest.FilterListen, func(so *socktest.Status) (socktest.AfterFilter, error) {
+ return nil, errOpNotSupported
+ })
+ defer sw.Set(socktest.FilterListen, nil)
+
+ for i, tt := range listenErrorTests {
+ ln, err := Listen(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr())
+ ln.Close()
+ continue
+ }
+ if tt.network == "tcp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
+ continue
+ }
+ }
+ if ln != nil {
+ t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln)
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
+var listenPacketErrorTests = []struct {
+ network, address string
+}{
+ {"foo", ""},
+ {"bar", "baz"},
+ {"datakit", "mh/astro/r70"},
+ {"udp", "127.0.0.1:☺"},
+ {"udp", "no-such-name:80"},
+ {"udp", "mh/astro/r70:http"},
+
+ {"udp", JoinHostPort("127.0.0.1", "-1")},
+ {"udp", JoinHostPort("127.0.0.1", "123456789")},
+}
+
+func TestListenPacketError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = func(_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ return nil, &DNSError{Err: "listen error test", Name: "name", Server: "server", IsTimeout: true}
+ }
+
+ for i, tt := range listenPacketErrorTests {
+ c, err := ListenPacket(tt.network, tt.address)
+ if err == nil {
+ t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr())
+ c.Close()
+ continue
+ }
+ if c != nil {
+ t.Errorf("ListenPacket returned non-nil interface %T(%v) with err != nil", c, c)
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
+func TestProtocolListenError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, network := range []string{"tcp", "udp", "ip:4294967296", "unix", "unixpacket", "unixgram"} {
+ var err error
+ switch network {
+ case "tcp":
+ _, err = ListenTCP(network, &TCPAddr{Port: 1 << 16})
+ case "udp":
+ _, err = ListenUDP(network, &UDPAddr{Port: 1 << 16})
+ case "ip:4294967296":
+ _, err = ListenIP(network, nil)
+ case "unix", "unixpacket":
+ _, err = ListenUnix(network, &UnixAddr{Name: "//"})
+ case "unixgram":
+ _, err = ListenUnixgram(network, &UnixAddr{Name: "//"})
+ }
+ if err == nil {
+ t.Errorf("%s: should fail", network)
+ continue
+ }
+ if err = parseDialError(err); err != nil {
+ t.Errorf("%s: %v", network, err)
+ continue
+ }
+ }
+}
+
+// parseReadError parses nestedErr and reports whether it is a valid
+// error value from Read functions.
+// It returns nil when nestedErr is valid.
+func parseReadError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ if nestedErr == io.EOF {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case ErrClosed, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+// parseWriteError parses nestedErr and reports whether it is a valid
+// error value from Write functions.
+// It returns nil when nestedErr is valid.
+func parseWriteError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *AddrError, *timeoutError, *DNSError, InvalidAddrError, *ParseError, *poll.DeadlineExceededError, UnknownNetworkError:
+ return nil
+ case interface{ isAddrinfoErrno() }:
+ return nil
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case errCanceled, ErrClosed, errMissingAddress, errTimeout, os.ErrDeadlineExceeded, ErrWriteToConnected, io.ErrUnexpectedEOF:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+// parseCloseError parses nestedErr and reports whether it is a valid
+// error value from Close functions.
+// It returns nil when nestedErr is valid.
+func parseCloseError(nestedErr error, isShutdown bool) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ // Because historically we have not exported the error that we
+ // return for an operation on a closed network connection,
+ // there are programs that test for the exact error string.
+ // Verify that string here so that we don't break those
+ // programs unexpectedly. See issues #4373 and #19252.
+ want := "use of closed network connection"
+ if !isShutdown && !strings.Contains(nestedErr.Error(), want) {
+ return fmt.Errorf("error string %q does not contain expected string %q", nestedErr, want)
+ }
+
+ if !isShutdown && !errors.Is(nestedErr, ErrClosed) {
+ return fmt.Errorf("errors.Is(%v, errClosed) returns false, want true", nestedErr)
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ case *fs.PathError: // for Plan 9
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case ErrClosed:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch nestedErr {
+ case fs.ErrClosed: // for Plan 9
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+func TestCloseError(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ for i := 0; i < 3; i++ {
+ err = c.(*TCPConn).CloseRead()
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ }
+ for i := 0; i < 3; i++ {
+ err = c.(*TCPConn).CloseWrite()
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ }
+ for i := 0; i < 3; i++ {
+ err = c.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ err = ln.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ }
+
+ pc, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer pc.Close()
+
+ for i := 0; i < 3; i++ {
+ err = pc.Close()
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ }
+}
+
+// parseAcceptError parses nestedErr and reports whether it is a valid
+// error value from Accept functions.
+// It returns nil when nestedErr is valid.
+func parseAcceptError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ case *fs.PathError: // for Plan 9
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case ErrClosed, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+func TestAcceptError(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ for {
+ ln.(*TCPListener).SetDeadline(time.Now().Add(5 * time.Millisecond))
+ c, err := ln.Accept()
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ if err != nil {
+ if c != nil {
+ t.Errorf("Accept returned non-nil interface %T(%v) with err != nil", c, c)
+ }
+ if nerr, ok := err.(Error); !ok || (!nerr.Timeout() && !nerr.Temporary()) {
+ return
+ }
+ continue
+ }
+ c.Close()
+ }
+ }
+ ls := newLocalServer(t, "tcp")
+ if err := ls.buildup(handler); err != nil {
+ ls.teardown()
+ t.Fatal(err)
+ }
+
+ time.Sleep(100 * time.Millisecond)
+ ls.teardown()
+}
+
+// parseCommonError parses nestedErr and reports whether it is a valid
+// error value from miscellaneous functions.
+// It returns nil when nestedErr is valid.
+func parseCommonError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch err := nestedErr.(type) {
+ case *OpError:
+ if err := err.isValid(); err != nil {
+ return err
+ }
+ nestedErr = err.Err
+ goto second
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+
+second:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ switch err := nestedErr.(type) {
+ case *os.SyscallError:
+ nestedErr = err.Err
+ goto third
+ case *os.LinkError:
+ nestedErr = err.Err
+ goto third
+ case *fs.PathError:
+ nestedErr = err.Err
+ goto third
+ }
+ switch nestedErr {
+ case ErrClosed:
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
+
+third:
+ if isPlatformError(nestedErr) {
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
+}
+
+func TestFileError(t *testing.T) {
+ switch runtime.GOOS {
+ case "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ f, err := os.CreateTemp("", "go-nettest")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(f.Name())
+ defer f.Close()
+
+ c, err := FileConn(f)
+ if err != nil {
+ if c != nil {
+ t.Errorf("FileConn returned non-nil interface %T(%v) with err != nil", c, c)
+ }
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ } else {
+ c.Close()
+ t.Error("should fail")
+ }
+ ln, err := FileListener(f)
+ if err != nil {
+ if ln != nil {
+ t.Errorf("FileListener returned non-nil interface %T(%v) with err != nil", ln, ln)
+ }
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ } else {
+ ln.Close()
+ t.Error("should fail")
+ }
+ pc, err := FilePacketConn(f)
+ if err != nil {
+ if pc != nil {
+ t.Errorf("FilePacketConn returned non-nil interface %T(%v) with err != nil", pc, pc)
+ }
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ } else {
+ pc.Close()
+ t.Error("should fail")
+ }
+
+ ln = newLocalListener(t, "tcp")
+
+ for i := 0; i < 3; i++ {
+ f, err := ln.(*TCPListener).File()
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ } else {
+ f.Close()
+ }
+ ln.Close()
+ }
+}
+
+func parseLookupPortError(nestedErr error) error {
+ if nestedErr == nil {
+ return nil
+ }
+
+ switch nestedErr.(type) {
+ case *AddrError, *DNSError:
+ return nil
+ case *fs.PathError: // for Plan 9
+ return nil
+ }
+ return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr)
+}
+
+func TestContextError(t *testing.T) {
+ if !errors.Is(errCanceled, context.Canceled) {
+ t.Error("errCanceled is not context.Canceled")
+ }
+ if !errors.Is(errTimeout, context.DeadlineExceeded) {
+ t.Error("errTimeout is not context.DeadlineExceeded")
+ }
+}
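
The parse*Error helpers above all validate the same nesting: a *OpError on the outside, often an *os.SyscallError in the middle, and a raw syscall.Errno at the bottom. A minimal sketch of how calling code can walk that same chain with the errors package; the dial target 127.0.0.1:1 is only a placeholder that is very likely to refuse connections:

package main

import (
	"errors"
	"fmt"
	"net"
	"os"
	"syscall"
)

func main() {
	_, err := net.Dial("tcp", "127.0.0.1:1")
	if err == nil {
		return
	}

	// 1st nested level: the *OpError wrapper.
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		fmt.Println("op:", opErr.Op, "net:", opErr.Net)
	}
	// 2nd nested level: the *os.SyscallError wrapper, when present.
	var sysErr *os.SyscallError
	if errors.As(err, &sysErr) {
		fmt.Println("syscall:", sysErr.Syscall)
	}
	// 3rd nested level: errors.Is sees through both wrappers to the errno.
	if errors.Is(err, syscall.ECONNREFUSED) {
		fmt.Println("connection refused")
	}
}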
diff --git a/src/net/error_unix.go b/src/net/error_unix.go
new file mode 100644
index 0000000..d694867
--- /dev/null
+++ b/src/net/error_unix.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || js || wasip1
+
+package net
+
+import "syscall"
+
+func isConnError(err error) bool {
+ if se, ok := err.(syscall.Errno); ok {
+ return se == syscall.ECONNRESET || se == syscall.ECONNABORTED
+ }
+ return false
+}
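
isConnError reports ECONNRESET and ECONNABORTED at the errno level. Because *OpError and *os.SyscallError both implement Unwrap, callers outside this package can usually rely on errors.Is rather than unwrapping by hand; a small sketch, assuming the caller simply wants to treat a reset peer as end of stream:

package connerr

import (
	"errors"
	"io"
	"net"
	"syscall"
)

// readOrEOF treats a reset or aborted peer as a normal end of stream;
// it is the caller-side counterpart of the errno check in isConnError.
func readOrEOF(c net.Conn, buf []byte) (int, error) {
	n, err := c.Read(buf)
	if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.ECONNABORTED) {
		return n, io.EOF
	}
	return n, err
}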
diff --git a/src/net/error_unix_test.go b/src/net/error_unix_test.go
new file mode 100644
index 0000000..291a723
--- /dev/null
+++ b/src/net/error_unix_test.go
@@ -0,0 +1,39 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !plan9 && !windows
+
+package net
+
+import (
+ "errors"
+ "os"
+ "syscall"
+)
+
+var (
+ errTimedout = syscall.ETIMEDOUT
+ errOpNotSupported = syscall.EOPNOTSUPP
+
+ abortedConnRequestErrors = []error{syscall.ECONNABORTED} // see accept in fd_unix.go
+)
+
+func isPlatformError(err error) bool {
+ _, ok := err.(syscall.Errno)
+ return ok
+}
+
+func samePlatformError(err, want error) bool {
+ if op, ok := err.(*OpError); ok {
+ err = op.Err
+ }
+ if sys, ok := err.(*os.SyscallError); ok {
+ err = sys.Err
+ }
+ return err == want
+}
+
+func isENOBUFS(err error) bool {
+ return errors.Is(err, syscall.ENOBUFS)
+}
diff --git a/src/net/error_windows.go b/src/net/error_windows.go
new file mode 100644
index 0000000..570b97b
--- /dev/null
+++ b/src/net/error_windows.go
@@ -0,0 +1,14 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "syscall"
+
+func isConnError(err error) bool {
+ if se, ok := err.(syscall.Errno); ok {
+ return se == syscall.WSAECONNRESET || se == syscall.WSAECONNABORTED
+ }
+ return false
+}
diff --git a/src/net/error_windows_test.go b/src/net/error_windows_test.go
new file mode 100644
index 0000000..25825f9
--- /dev/null
+++ b/src/net/error_windows_test.go
@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "syscall"
+)
+
+var (
+ errTimedout = syscall.ETIMEDOUT
+ errOpNotSupported = syscall.EOPNOTSUPP
+
+ abortedConnRequestErrors = []error{syscall.ERROR_NETNAME_DELETED, syscall.WSAECONNRESET} // see accept in fd_windows.go
+)
+
+func isPlatformError(err error) bool {
+ _, ok := err.(syscall.Errno)
+ return ok
+}
+
+func isENOBUFS(err error) bool {
+ // syscall.ENOBUFS is a completely made-up value on Windows: we don't expect
+ // a real system call to ever actually return it. However, since it is already
+ // defined in the syscall package we may as well check for it.
+ return errors.Is(err, syscall.ENOBUFS)
+}
diff --git a/src/net/example_test.go b/src/net/example_test.go
new file mode 100644
index 0000000..2c045d7
--- /dev/null
+++ b/src/net/example_test.go
@@ -0,0 +1,387 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "time"
+)
+
+func ExampleListener() {
+ // Listen on TCP port 2000 on all available unicast and
+ // anycast IP addresses of the local system.
+ l, err := net.Listen("tcp", ":2000")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer l.Close()
+ for {
+ // Wait for a connection.
+ conn, err := l.Accept()
+ if err != nil {
+ log.Fatal(err)
+ }
+ // Handle the connection in a new goroutine.
+ // The loop then returns to accepting, so that
+ // multiple connections may be served concurrently.
+ go func(c net.Conn) {
+ // Echo all incoming data.
+ io.Copy(c, c)
+ // Shut down the connection.
+ c.Close()
+ }(conn)
+ }
+}
+
+func ExampleDialer() {
+ var d net.Dialer
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+
+ conn, err := d.DialContext(ctx, "tcp", "localhost:12345")
+ if err != nil {
+ log.Fatalf("Failed to dial: %v", err)
+ }
+ defer conn.Close()
+
+ if _, err := conn.Write([]byte("Hello, World!")); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleDialer_unix() {
+ // DialUnix does not take a context.Context parameter. This example shows
+ // how to dial a Unix socket with a Context. Note that the Context only
+ // applies to the dial operation; it does not apply to the connection once
+ // it has been established.
+ var d net.Dialer
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+
+ d.LocalAddr = nil // if you have a local addr, add it here
+ raddr := net.UnixAddr{Name: "/path/to/unix.sock", Net: "unix"}
+ conn, err := d.DialContext(ctx, "unix", raddr.String())
+ if err != nil {
+ log.Fatalf("Failed to dial: %v", err)
+ }
+ defer conn.Close()
+ if _, err := conn.Write([]byte("Hello, socket!")); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleIPv4() {
+ fmt.Println(net.IPv4(8, 8, 8, 8))
+
+ // Output:
+ // 8.8.8.8
+}
+
+func ExampleParseCIDR() {
+ ipv4Addr, ipv4Net, err := net.ParseCIDR("192.0.2.1/24")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(ipv4Addr)
+ fmt.Println(ipv4Net)
+
+ ipv6Addr, ipv6Net, err := net.ParseCIDR("2001:db8:a0b:12f0::1/32")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(ipv6Addr)
+ fmt.Println(ipv6Net)
+
+ // Output:
+ // 192.0.2.1
+ // 192.0.2.0/24
+ // 2001:db8:a0b:12f0::1
+ // 2001:db8::/32
+}
+
+func ExampleParseIP() {
+ fmt.Println(net.ParseIP("192.0.2.1"))
+ fmt.Println(net.ParseIP("2001:db8::68"))
+ fmt.Println(net.ParseIP("192.0.2"))
+
+ // Output:
+ // 192.0.2.1
+ // 2001:db8::68
+ // <nil>
+}
+
+func ExampleIP_DefaultMask() {
+ ip := net.ParseIP("192.0.2.1")
+ fmt.Println(ip.DefaultMask())
+
+ // Output:
+ // ffffff00
+}
+
+func ExampleIP_Equal() {
+ ipv4DNS := net.ParseIP("8.8.8.8")
+ ipv4Lo := net.ParseIP("127.0.0.1")
+ ipv6DNS := net.ParseIP("0:0:0:0:0:FFFF:0808:0808")
+
+ fmt.Println(ipv4DNS.Equal(ipv4DNS))
+ fmt.Println(ipv4DNS.Equal(ipv4Lo))
+ fmt.Println(ipv4DNS.Equal(ipv6DNS))
+
+ // Output:
+ // true
+ // false
+ // true
+}
+
+func ExampleIP_IsGlobalUnicast() {
+ ipv6Global := net.ParseIP("2000::")
+ ipv6UniqLocal := net.ParseIP("fc00::")
+ ipv6Multi := net.ParseIP("FF00::")
+
+ ipv4Private := net.ParseIP("10.255.0.0")
+ ipv4Public := net.ParseIP("8.8.8.8")
+ ipv4Broadcast := net.ParseIP("255.255.255.255")
+
+ fmt.Println(ipv6Global.IsGlobalUnicast())
+ fmt.Println(ipv6UniqLocal.IsGlobalUnicast())
+ fmt.Println(ipv6Multi.IsGlobalUnicast())
+
+ fmt.Println(ipv4Private.IsGlobalUnicast())
+ fmt.Println(ipv4Public.IsGlobalUnicast())
+ fmt.Println(ipv4Broadcast.IsGlobalUnicast())
+
+ // Output:
+ // true
+ // true
+ // false
+ // true
+ // true
+ // false
+}
+
+func ExampleIP_IsInterfaceLocalMulticast() {
+ ipv6InterfaceLocalMulti := net.ParseIP("ff01::1")
+ ipv6Global := net.ParseIP("2000::")
+ ipv4 := net.ParseIP("255.0.0.0")
+
+ fmt.Println(ipv6InterfaceLocalMulti.IsInterfaceLocalMulticast())
+ fmt.Println(ipv6Global.IsInterfaceLocalMulticast())
+ fmt.Println(ipv4.IsInterfaceLocalMulticast())
+
+ // Output:
+ // true
+ // false
+ // false
+}
+
+func ExampleIP_IsLinkLocalMulticast() {
+ ipv6LinkLocalMulti := net.ParseIP("ff02::2")
+ ipv6LinkLocalUni := net.ParseIP("fe80::")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+ ipv4LinkLocalUni := net.ParseIP("169.254.0.0")
+
+ fmt.Println(ipv6LinkLocalMulti.IsLinkLocalMulticast())
+ fmt.Println(ipv6LinkLocalUni.IsLinkLocalMulticast())
+ fmt.Println(ipv4LinkLocalMulti.IsLinkLocalMulticast())
+ fmt.Println(ipv4LinkLocalUni.IsLinkLocalMulticast())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsLinkLocalUnicast() {
+ ipv6LinkLocalUni := net.ParseIP("fe80::")
+ ipv6Global := net.ParseIP("2000::")
+ ipv4LinkLocalUni := net.ParseIP("169.254.0.0")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+
+ fmt.Println(ipv6LinkLocalUni.IsLinkLocalUnicast())
+ fmt.Println(ipv6Global.IsLinkLocalUnicast())
+ fmt.Println(ipv4LinkLocalUni.IsLinkLocalUnicast())
+ fmt.Println(ipv4LinkLocalMulti.IsLinkLocalUnicast())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsLoopback() {
+ ipv6Lo := net.ParseIP("::1")
+ ipv6 := net.ParseIP("ff02::1")
+ ipv4Lo := net.ParseIP("127.0.0.0")
+ ipv4 := net.ParseIP("128.0.0.0")
+
+ fmt.Println(ipv6Lo.IsLoopback())
+ fmt.Println(ipv6.IsLoopback())
+ fmt.Println(ipv4Lo.IsLoopback())
+ fmt.Println(ipv4.IsLoopback())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsMulticast() {
+ ipv6Multi := net.ParseIP("FF00::")
+ ipv6LinkLocalMulti := net.ParseIP("ff02::1")
+ ipv6Lo := net.ParseIP("::1")
+ ipv4Multi := net.ParseIP("239.0.0.0")
+ ipv4LinkLocalMulti := net.ParseIP("224.0.0.0")
+ ipv4Lo := net.ParseIP("127.0.0.0")
+
+ fmt.Println(ipv6Multi.IsMulticast())
+ fmt.Println(ipv6LinkLocalMulti.IsMulticast())
+ fmt.Println(ipv6Lo.IsMulticast())
+ fmt.Println(ipv4Multi.IsMulticast())
+ fmt.Println(ipv4LinkLocalMulti.IsMulticast())
+ fmt.Println(ipv4Lo.IsMulticast())
+
+ // Output:
+ // true
+ // true
+ // false
+ // true
+ // true
+ // false
+}
+
+func ExampleIP_IsPrivate() {
+ ipv6Private := net.ParseIP("fc00::")
+ ipv6Public := net.ParseIP("fe00::")
+ ipv4Private := net.ParseIP("10.255.0.0")
+ ipv4Public := net.ParseIP("11.0.0.0")
+
+ fmt.Println(ipv6Private.IsPrivate())
+ fmt.Println(ipv6Public.IsPrivate())
+ fmt.Println(ipv4Private.IsPrivate())
+ fmt.Println(ipv4Public.IsPrivate())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_IsUnspecified() {
+ ipv6Unspecified := net.ParseIP("::")
+ ipv6Specified := net.ParseIP("fe00::")
+ ipv4Unspecified := net.ParseIP("0.0.0.0")
+ ipv4Specified := net.ParseIP("8.8.8.8")
+
+ fmt.Println(ipv6Unspecified.IsUnspecified())
+ fmt.Println(ipv6Specified.IsUnspecified())
+ fmt.Println(ipv4Unspecified.IsUnspecified())
+ fmt.Println(ipv4Specified.IsUnspecified())
+
+ // Output:
+ // true
+ // false
+ // true
+ // false
+}
+
+func ExampleIP_Mask() {
+ ipv4Addr := net.ParseIP("192.0.2.1")
+ // This mask corresponds to a /24 subnet for IPv4.
+ ipv4Mask := net.CIDRMask(24, 32)
+ fmt.Println(ipv4Addr.Mask(ipv4Mask))
+
+ ipv6Addr := net.ParseIP("2001:db8:a0b:12f0::1")
+ // This mask corresponds to a /32 subnet for IPv6.
+ ipv6Mask := net.CIDRMask(32, 128)
+ fmt.Println(ipv6Addr.Mask(ipv6Mask))
+
+ // Output:
+ // 192.0.2.0
+ // 2001:db8::
+}
+
+func ExampleIP_String() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.String())
+ fmt.Println(ipv4.String())
+
+ // Output:
+ // fc00::
+ // 10.255.0.0
+}
+
+func ExampleIP_To16() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.To16())
+ fmt.Println(ipv4.To16())
+
+ // Output:
+ // fc00::
+ // 10.255.0.0
+}
+
+func ExampleIP_To4() {
+ ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ ipv4 := net.IPv4(10, 255, 0, 0)
+
+ fmt.Println(ipv6.To4())
+ fmt.Println(ipv4.To4())
+
+ // Output:
+ // <nil>
+ // 10.255.0.0
+}
+
+func ExampleCIDRMask() {
+ // This mask corresponds to a /31 subnet for IPv4.
+ fmt.Println(net.CIDRMask(31, 32))
+
+ // This mask corresponds to a /64 subnet for IPv6.
+ fmt.Println(net.CIDRMask(64, 128))
+
+ // Output:
+ // fffffffe
+ // ffffffffffffffff0000000000000000
+}
+
+func ExampleIPv4Mask() {
+ fmt.Println(net.IPv4Mask(255, 255, 255, 0))
+
+ // Output:
+ // ffffff00
+}
+
+func ExampleUDPConn_WriteTo() {
+ // Unlike Dial, ListenPacket creates a connection without any
+ // association with peers.
+ conn, err := net.ListenPacket("udp", ":0")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer conn.Close()
+
+ dst, err := net.ResolveUDPAddr("udp", "192.0.2.1:2000")
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // The connection can write data to the desired address.
+ _, err = conn.WriteTo([]byte("data"), dst)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/src/net/external_test.go b/src/net/external_test.go
new file mode 100644
index 0000000..0709b9d
--- /dev/null
+++ b/src/net/external_test.go
@@ -0,0 +1,168 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "fmt"
+ "internal/testenv"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestResolveGoogle(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ for _, network := range []string{"tcp", "tcp4", "tcp6"} {
+ addr, err := ResolveTCPAddr(network, "www.google.com:http")
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ switch {
+ case network == "tcp" && addr.IP.To4() == nil:
+ fallthrough
+ case network == "tcp4" && addr.IP.To4() == nil:
+ t.Errorf("got %v; want an IPv4 address on %s", addr, network)
+ case network == "tcp6" && (addr.IP.To16() == nil || addr.IP.To4() != nil):
+ t.Errorf("got %v; want an IPv6 address on %s", addr, network)
+ }
+ }
+}
+
+var dialGoogleTests = []struct {
+ dial func(string, string) (Conn, error)
+ unreachableNetwork string
+ networks []string
+ addrs []string
+}{
+ {
+ dial: (&Dialer{DualStack: true}).Dial,
+ networks: []string{"tcp", "tcp4", "tcp6"},
+ addrs: []string{"www.google.com:http"},
+ },
+ {
+ dial: Dial,
+ unreachableNetwork: "tcp6",
+ networks: []string{"tcp", "tcp4"},
+ },
+ {
+ dial: Dial,
+ unreachableNetwork: "tcp4",
+ networks: []string{"tcp", "tcp6"},
+ },
+}
+
+func TestDialGoogle(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ var err error
+ dialGoogleTests[1].addrs, dialGoogleTests[2].addrs, err = googleLiteralAddrs()
+ if err != nil {
+ t.Error(err)
+ }
+ for _, tt := range dialGoogleTests {
+ for _, network := range tt.networks {
+ disableSocketConnect(tt.unreachableNetwork)
+ for _, addr := range tt.addrs {
+ if err := fetchGoogle(tt.dial, network, addr); err != nil {
+ t.Error(err)
+ }
+ }
+ enableSocketConnect()
+ }
+ }
+}
+
+var (
+ literalAddrs4 = [...]string{
+ "%d.%d.%d.%d:80",
+ "www.google.com:80",
+ "%d.%d.%d.%d:http",
+ "www.google.com:http",
+ "%03d.%03d.%03d.%03d:0080",
+ "[::ffff:%d.%d.%d.%d]:80",
+ "[::ffff:%02x%02x:%02x%02x]:80",
+ "[0:0:0:0:0000:ffff:%d.%d.%d.%d]:80",
+ "[0:0:0:0:000000:ffff:%d.%d.%d.%d]:80",
+ "[0:0:0:0::ffff:%d.%d.%d.%d]:80",
+ }
+ literalAddrs6 = [...]string{
+ "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:80",
+ "ipv6.google.com:80",
+ "[%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x]:http",
+ "ipv6.google.com:http",
+ }
+)
+
+func googleLiteralAddrs() (lits4, lits6 []string, err error) {
+ ips, err := LookupIP("www.google.com")
+ if err != nil {
+ return nil, nil, err
+ }
+ if len(ips) == 0 {
+ return nil, nil, nil
+ }
+ var ip4, ip6 IP
+ for _, ip := range ips {
+ if ip4 == nil && ip.To4() != nil {
+ ip4 = ip.To4()
+ }
+ if ip6 == nil && ip.To16() != nil && ip.To4() == nil {
+ ip6 = ip.To16()
+ }
+ if ip4 != nil && ip6 != nil {
+ break
+ }
+ }
+ if ip4 != nil {
+ for i, lit4 := range literalAddrs4 {
+ if strings.Contains(lit4, "%") {
+ literalAddrs4[i] = fmt.Sprintf(lit4, ip4[0], ip4[1], ip4[2], ip4[3])
+ }
+ }
+ lits4 = literalAddrs4[:]
+ }
+ if ip6 != nil {
+ for i, lit6 := range literalAddrs6 {
+ if strings.Contains(lit6, "%") {
+ literalAddrs6[i] = fmt.Sprintf(lit6, ip6[0], ip6[1], ip6[2], ip6[3], ip6[4], ip6[5], ip6[6], ip6[7], ip6[8], ip6[9], ip6[10], ip6[11], ip6[12], ip6[13], ip6[14], ip6[15])
+ }
+ }
+ lits6 = literalAddrs6[:]
+ }
+ return
+}
+
+func fetchGoogle(dial func(string, string) (Conn, error), network, address string) error {
+ c, err := dial(network, address)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+ req := []byte("GET /robots.txt HTTP/1.0\r\nHost: www.google.com\r\n\r\n")
+ if _, err := c.Write(req); err != nil {
+ return err
+ }
+ b := make([]byte, 1000)
+ n, err := io.ReadFull(c, b)
+ if err != nil {
+ return err
+ }
+ if n < 1000 {
+ return fmt.Errorf("short read from %s:%s->%s", network, c.RemoteAddr(), c.LocalAddr())
+ }
+ return nil
+}
diff --git a/src/net/fd_plan9.go b/src/net/fd_plan9.go
new file mode 100644
index 0000000..da41bc0
--- /dev/null
+++ b/src/net/fd_plan9.go
@@ -0,0 +1,187 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+ "os"
+ "syscall"
+ "time"
+)
+
+// Network file descriptor.
+type netFD struct {
+ pfd poll.FD
+
+ // immutable until Close
+ net string
+ n string
+ dir string
+ listen, ctl, data *os.File
+ laddr, raddr Addr
+ isStream bool
+}
+
+var netdir = "/net" // default network
+
+func newFD(net, name string, listen, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) {
+ ret := &netFD{
+ net: net,
+ n: name,
+ dir: netdir + "/" + net + "/" + name,
+ listen: listen,
+ ctl: ctl, data: data,
+ laddr: laddr,
+ raddr: raddr,
+ }
+ ret.pfd.Destroy = ret.destroy
+ return ret, nil
+}
+
+func (fd *netFD) init() error {
+ // stub for future fd.pd.Init(fd)
+ return nil
+}
+
+func (fd *netFD) name() string {
+ var ls, rs string
+ if fd.laddr != nil {
+ ls = fd.laddr.String()
+ }
+ if fd.raddr != nil {
+ rs = fd.raddr.String()
+ }
+ return fd.net + ":" + ls + "->" + rs
+}
+
+func (fd *netFD) ok() bool { return fd != nil && fd.ctl != nil }
+
+func (fd *netFD) destroy() {
+ if !fd.ok() {
+ return
+ }
+ err := fd.ctl.Close()
+ if fd.data != nil {
+ if err1 := fd.data.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
+ if fd.listen != nil {
+ if err1 := fd.listen.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
+ fd.ctl = nil
+ fd.data = nil
+ fd.listen = nil
+}
+
+func (fd *netFD) Read(b []byte) (n int, err error) {
+ if !fd.ok() || fd.data == nil {
+ return 0, syscall.EINVAL
+ }
+ n, err = fd.pfd.Read(fd.data.Read, b)
+ if fd.net == "udp" && err == io.EOF {
+ n = 0
+ err = nil
+ }
+ return
+}
+
+func (fd *netFD) Write(b []byte) (n int, err error) {
+ if !fd.ok() || fd.data == nil {
+ return 0, syscall.EINVAL
+ }
+ return fd.pfd.Write(fd.data.Write, b)
+}
+
+func (fd *netFD) closeRead() error {
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ return syscall.EPLAN9
+}
+
+func (fd *netFD) closeWrite() error {
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ return syscall.EPLAN9
+}
+
+func (fd *netFD) Close() error {
+ if err := fd.pfd.Close(); err != nil {
+ return err
+ }
+ if !fd.ok() {
+ return syscall.EINVAL
+ }
+ if fd.net == "tcp" {
+ // The following line is required to unblock Reads.
+ _, err := fd.ctl.WriteString("close")
+ if err != nil {
+ return err
+ }
+ }
+ err := fd.ctl.Close()
+ if fd.data != nil {
+ if err1 := fd.data.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
+ if fd.listen != nil {
+ if err1 := fd.listen.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
+ fd.ctl = nil
+ fd.data = nil
+ fd.listen = nil
+ return err
+}
+
+// This method is only called via Conn.
+func (fd *netFD) dup() (*os.File, error) {
+ if !fd.ok() || fd.data == nil {
+ return nil, syscall.EINVAL
+ }
+ return fd.file(fd.data, fd.dir+"/data")
+}
+
+func (l *TCPListener) dup() (*os.File, error) {
+ if !l.fd.ok() {
+ return nil, syscall.EINVAL
+ }
+ return l.fd.file(l.fd.ctl, l.fd.dir+"/ctl")
+}
+
+func (fd *netFD) file(f *os.File, s string) (*os.File, error) {
+ dfd, err := syscall.Dup(int(f.Fd()), -1)
+ if err != nil {
+ return nil, os.NewSyscallError("dup", err)
+ }
+ return os.NewFile(uintptr(dfd), s), nil
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+ return syscall.EPLAN9
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+ return syscall.EPLAN9
+}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ return fd.pfd.SetWriteDeadline(t)
+}
diff --git a/src/net/fd_posix.go b/src/net/fd_posix.go
new file mode 100644
index 0000000..ffb9bcf
--- /dev/null
+++ b/src/net/fd_posix.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || windows
+
+package net
+
+import (
+ "internal/poll"
+ "runtime"
+ "syscall"
+ "time"
+)
+
+// Network file descriptor.
+type netFD struct {
+ pfd poll.FD
+
+ // immutable until Close
+ family int
+ sotype int
+ isConnected bool // handshake completed or use of association with peer
+ net string
+ laddr Addr
+ raddr Addr
+}
+
+func (fd *netFD) setAddr(laddr, raddr Addr) {
+ fd.laddr = laddr
+ fd.raddr = raddr
+ runtime.SetFinalizer(fd, (*netFD).Close)
+}
+
+func (fd *netFD) Close() error {
+ runtime.SetFinalizer(fd, nil)
+ return fd.pfd.Close()
+}
+
+func (fd *netFD) shutdown(how int) error {
+ err := fd.pfd.Shutdown(how)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("shutdown", err)
+}
+
+func (fd *netFD) closeRead() error {
+ return fd.shutdown(syscall.SHUT_RD)
+}
+
+func (fd *netFD) closeWrite() error {
+ return fd.shutdown(syscall.SHUT_WR)
+}
+
+func (fd *netFD) Read(p []byte) (n int, err error) {
+ n, err = fd.pfd.Read(p)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readSyscallName, err)
+}
+
+func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
+ n, sa, err = fd.pfd.ReadFrom(p)
+ runtime.KeepAlive(fd)
+ return n, sa, wrapSyscallError(readFromSyscallName, err)
+}
+func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) {
+ n, err = fd.pfd.ReadFromInet4(p, from)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readFromSyscallName, err)
+}
+
+func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) {
+ n, err = fd.pfd.ReadFromInet6(p, from)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readFromSyscallName, err)
+}
+
+func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
+ n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
+ runtime.KeepAlive(fd)
+ return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
+}
+
+func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
+ n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
+}
+
+func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
+ n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
+}
+
+func (fd *netFD) Write(p []byte) (nn int, err error) {
+ nn, err = fd.pfd.Write(p)
+ runtime.KeepAlive(fd)
+ return nn, wrapSyscallError(writeSyscallName, err)
+}
+
+func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
+ n, err = fd.pfd.WriteTo(p, sa)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(writeToSyscallName, err)
+}
+
+func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ n, err = fd.pfd.WriteToInet4(p, sa)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(writeToSyscallName, err)
+}
+
+func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ n, err = fd.pfd.WriteToInet6(p, sa)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(writeToSyscallName, err)
+}
+
+func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
+ n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
+}
+
+func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
+ n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
+}
+
+func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
+ n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
+ runtime.KeepAlive(fd)
+ return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
+}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ return fd.pfd.SetWriteDeadline(t)
+}
diff --git a/src/net/fd_unix.go b/src/net/fd_unix.go
new file mode 100644
index 0000000..a8d3a25
--- /dev/null
+++ b/src/net/fd_unix.go
@@ -0,0 +1,206 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+ "os"
+ "runtime"
+ "syscall"
+)
+
+const (
+ readSyscallName = "read"
+ readFromSyscallName = "recvfrom"
+ readMsgSyscallName = "recvmsg"
+ writeSyscallName = "write"
+ writeToSyscallName = "sendto"
+ writeMsgSyscallName = "sendmsg"
+)
+
+func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
+ ret := &netFD{
+ pfd: poll.FD{
+ Sysfd: sysfd,
+ IsStream: sotype == syscall.SOCK_STREAM,
+ ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW,
+ },
+ family: family,
+ sotype: sotype,
+ net: net,
+ }
+ return ret, nil
+}
+
+func (fd *netFD) init() error {
+ return fd.pfd.Init(fd.net, true)
+}
+
+func (fd *netFD) name() string {
+ var ls, rs string
+ if fd.laddr != nil {
+ ls = fd.laddr.String()
+ }
+ if fd.raddr != nil {
+ rs = fd.raddr.String()
+ }
+ return fd.net + ":" + ls + "->" + rs
+}
+
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa syscall.Sockaddr, ret error) {
+ // Do not need to call fd.writeLock here,
+ // because fd is not yet accessible to user,
+ // so no concurrent operations are possible.
+ switch err := connectFunc(fd.pfd.Sysfd, ra); err {
+ case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
+ case nil, syscall.EISCONN:
+ select {
+ case <-ctx.Done():
+ return nil, mapErr(ctx.Err())
+ default:
+ }
+ if err := fd.pfd.Init(fd.net, true); err != nil {
+ return nil, err
+ }
+ runtime.KeepAlive(fd)
+ return nil, nil
+ case syscall.EINVAL:
+ // On Solaris and illumos we can see EINVAL if the socket has
+ // already been accepted and closed by the server. Treat this
+ // as a successful connection--writes to the socket will see
+ // EOF. For details and a test case in C see
+ // https://golang.org/issue/6828.
+ if runtime.GOOS == "solaris" || runtime.GOOS == "illumos" {
+ return nil, nil
+ }
+ fallthrough
+ default:
+ return nil, os.NewSyscallError("connect", err)
+ }
+ if err := fd.pfd.Init(fd.net, true); err != nil {
+ return nil, err
+ }
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ fd.pfd.SetWriteDeadline(deadline)
+ defer fd.pfd.SetWriteDeadline(noDeadline)
+ }
+
+ // Start the "interrupter" goroutine, if this context might be canceled.
+ //
+ // The interrupter goroutine waits for the context to be done and
+ // interrupts the dial (by altering the fd's write deadline, which
+ // wakes up waitWrite).
+ ctxDone := ctx.Done()
+ if ctxDone != nil {
+ // Wait for the interrupter goroutine to exit before returning
+ // from connect.
+ done := make(chan struct{})
+ interruptRes := make(chan error)
+ defer func() {
+ close(done)
+ if ctxErr := <-interruptRes; ctxErr != nil && ret == nil {
+ // The interrupter goroutine called SetWriteDeadline,
+ // but the connect code below had returned from
+ // waitWrite already and did a successful connect (ret
+ // == nil). Because we've now poisoned the connection
+ // by making it unwritable, don't return a successful
+ // dial. This was issue 16523.
+ ret = mapErr(ctxErr)
+ fd.Close() // prevent a leak
+ }
+ }()
+ go func() {
+ select {
+ case <-ctxDone:
+ // Force the runtime's poller to immediately give up
+ // waiting for writability, unblocking waitWrite
+ // below.
+ fd.pfd.SetWriteDeadline(aLongTimeAgo)
+ testHookCanceledDial()
+ interruptRes <- ctx.Err()
+ case <-done:
+ interruptRes <- nil
+ }
+ }()
+ }
+
+ for {
+ // Performing multiple connect system calls on a
+ // non-blocking socket under Unix variants does not
+ // necessarily result in earlier errors being
+ // returned. Instead, once the runtime-integrated network
+ // poller tells us that the socket is ready, get the
+ // SO_ERROR socket option to see if the connection
+ // succeeded or failed. See issue 7474 for further
+ // details.
+ if err := fd.pfd.WaitWrite(); err != nil {
+ select {
+ case <-ctxDone:
+ return nil, mapErr(ctx.Err())
+ default:
+ }
+ return nil, err
+ }
+ nerr, err := getsockoptIntFunc(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ switch err := syscall.Errno(nerr); err {
+ case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR:
+ case syscall.EISCONN:
+ return nil, nil
+ case syscall.Errno(0):
+ // The runtime poller can wake us up spuriously;
+ // see issues 14548 and 19289. Check that we are
+ // really connected; if not, wait again.
+ if rsa, err := syscall.Getpeername(fd.pfd.Sysfd); err == nil {
+ return rsa, nil
+ }
+ default:
+ return nil, os.NewSyscallError("connect", err)
+ }
+ runtime.KeepAlive(fd)
+ }
+}
+
+func (fd *netFD) accept() (netfd *netFD, err error) {
+ d, rsa, errcall, err := fd.pfd.Accept()
+ if err != nil {
+ if errcall != "" {
+ err = wrapSyscallError(errcall, err)
+ }
+ return nil, err
+ }
+
+ if netfd, err = newFD(d, fd.family, fd.sotype, fd.net); err != nil {
+ poll.CloseFunc(d)
+ return nil, err
+ }
+ if err = netfd.init(); err != nil {
+ netfd.Close()
+ return nil, err
+ }
+ lsa, _ := syscall.Getsockname(netfd.pfd.Sysfd)
+ netfd.setAddr(netfd.addrFunc()(lsa), netfd.addrFunc()(rsa))
+ return netfd, nil
+}
+
+// Defined in os package.
+func newUnixFile(fd int, name string) *os.File
+
+func (fd *netFD) dup() (f *os.File, err error) {
+ ns, call, err := fd.pfd.Dup()
+ if err != nil {
+ if call != "" {
+ err = os.NewSyscallError(call, err)
+ }
+ return nil, err
+ }
+
+ return newUnixFile(ns, fd.name()), nil
+}
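
Most of the connect path above exists to make an in-flight dial interruptible by its context: the interrupter goroutine forces the poller awake by moving the write deadline into the past. From the caller's side that machinery is reached through Dialer.DialContext; a minimal sketch, where 192.0.2.1 (TEST-NET-1) is a placeholder address that should never answer:

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	// Cancel shortly after the dial starts.
	go func() {
		time.Sleep(50 * time.Millisecond)
		cancel()
	}()

	var d net.Dialer
	_, err := d.DialContext(ctx, "tcp", "192.0.2.1:80")
	if errors.Is(err, context.Canceled) {
		// errCanceled satisfies errors.Is(err, context.Canceled),
		// as TestContextError above verifies.
		fmt.Println("dial interrupted:", err)
		return
	}
	fmt.Println("dial finished:", err)
}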
diff --git a/src/net/fd_wasip1.go b/src/net/fd_wasip1.go
new file mode 100644
index 0000000..74d0b0b
--- /dev/null
+++ b/src/net/fd_wasip1.go
@@ -0,0 +1,184 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build wasip1
+
+package net
+
+import (
+ "internal/poll"
+ "runtime"
+ "syscall"
+ "time"
+)
+
+const (
+ readSyscallName = "fd_read"
+ writeSyscallName = "fd_write"
+)
+
+// Network file descriptor.
+type netFD struct {
+ pfd poll.FD
+
+ // immutable until Close
+ family int
+ sotype int
+ isConnected bool // handshake completed or use of association with peer
+ net string
+ laddr Addr
+ raddr Addr
+
+ // The only networking available in WASI preview 1 is the ability to
+ // sock_accept on a pre-opened socket, and then fd_read, fd_write,
+ // fd_close, and sock_shutdown on the resulting connection. We
+ // intercept applicable netFD calls on this instance, and then pass
+ // the remainder of the netFD calls to fakeNetFD.
+ *fakeNetFD
+}
+
+func newFD(net string, sysfd int) *netFD {
+ return newPollFD(net, poll.FD{
+ Sysfd: sysfd,
+ IsStream: true,
+ ZeroReadIsEOF: true,
+ })
+}
+
+func newPollFD(net string, pfd poll.FD) *netFD {
+ var laddr Addr
+ var raddr Addr
+ // WASI preview 1 does not have functions like getsockname/getpeername,
+ // so we cannot get access to the underlying IP address used by connections.
+ //
+ // However, listeners created by FileListener are of type *TCPListener,
+ // which can be asserted by a Go program. The (*TCPListener).Addr method
+ // documents that the returned value will be of type *TCPAddr, so we satisfy
+ // the documented behavior by creating addresses of the expected type here.
+ switch net {
+ case "tcp":
+ laddr = new(TCPAddr)
+ raddr = new(TCPAddr)
+ case "udp":
+ laddr = new(UDPAddr)
+ raddr = new(UDPAddr)
+ default:
+ laddr = unknownAddr{}
+ raddr = unknownAddr{}
+ }
+ return &netFD{
+ pfd: pfd,
+ net: net,
+ laddr: laddr,
+ raddr: raddr,
+ }
+}
+
+func (fd *netFD) init() error {
+ return fd.pfd.Init(fd.net, true)
+}
+
+func (fd *netFD) name() string {
+ return "unknown"
+}
+
+func (fd *netFD) accept() (netfd *netFD, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.accept()
+ }
+ d, _, errcall, err := fd.pfd.Accept()
+ if err != nil {
+ if errcall != "" {
+ err = wrapSyscallError(errcall, err)
+ }
+ return nil, err
+ }
+ netfd = newFD("tcp", d)
+ if err = netfd.init(); err != nil {
+ netfd.Close()
+ return nil, err
+ }
+ return netfd, nil
+}
+
+func (fd *netFD) setAddr(laddr, raddr Addr) {
+ fd.laddr = laddr
+ fd.raddr = raddr
+ runtime.SetFinalizer(fd, (*netFD).Close)
+}
+
+func (fd *netFD) Close() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Close()
+ }
+ runtime.SetFinalizer(fd, nil)
+ return fd.pfd.Close()
+}
+
+func (fd *netFD) shutdown(how int) error {
+ if fd.fakeNetFD != nil {
+ return nil
+ }
+ err := fd.pfd.Shutdown(how)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("shutdown", err)
+}
+
+func (fd *netFD) closeRead() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.closeRead()
+ }
+ return fd.shutdown(syscall.SHUT_RD)
+}
+
+func (fd *netFD) closeWrite() error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.closeWrite()
+ }
+ return fd.shutdown(syscall.SHUT_WR)
+}
+
+func (fd *netFD) Read(p []byte) (n int, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Read(p)
+ }
+ n, err = fd.pfd.Read(p)
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError(readSyscallName, err)
+}
+
+func (fd *netFD) Write(p []byte) (nn int, err error) {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.Write(p)
+ }
+ nn, err = fd.pfd.Write(p)
+ runtime.KeepAlive(fd)
+ return nn, wrapSyscallError(writeSyscallName, err)
+}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetDeadline(t)
+ }
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetReadDeadline(t)
+ }
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ if fd.fakeNetFD != nil {
+ return fd.fakeNetFD.SetWriteDeadline(t)
+ }
+ return fd.pfd.SetWriteDeadline(t)
+}
+
+type unknownAddr struct{}
+
+func (unknownAddr) Network() string { return "unknown" }
+func (unknownAddr) String() string { return "unknown" }
diff --git a/src/net/fd_windows.go b/src/net/fd_windows.go
new file mode 100644
index 0000000..eeb994d
--- /dev/null
+++ b/src/net/fd_windows.go
@@ -0,0 +1,205 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+ "internal/syscall/windows"
+ "os"
+ "runtime"
+ "syscall"
+ "unsafe"
+)
+
+const (
+ readSyscallName = "wsarecv"
+ readFromSyscallName = "wsarecvfrom"
+ readMsgSyscallName = "wsarecvmsg"
+ writeSyscallName = "wsasend"
+ writeToSyscallName = "wsasendto"
+ writeMsgSyscallName = "wsasendmsg"
+)
+
+// canUseConnectEx reports whether we can use the ConnectEx Windows API call
+// for the given network type.
+func canUseConnectEx(net string) bool {
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ return true
+ }
+ // ConnectEx windows API does not support connectionless sockets.
+ return false
+}
+
+func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) {
+ ret := &netFD{
+ pfd: poll.FD{
+ Sysfd: sysfd,
+ IsStream: sotype == syscall.SOCK_STREAM,
+ ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW,
+ },
+ family: family,
+ sotype: sotype,
+ net: net,
+ }
+ return ret, nil
+}
+
+func (fd *netFD) init() error {
+ errcall, err := fd.pfd.Init(fd.net, true)
+ if errcall != "" {
+ err = wrapSyscallError(errcall, err)
+ }
+ return err
+}
+
+// Always returns nil for connected peer address result.
+func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) {
+ // Do not need to call fd.writeLock here,
+ // because fd is not yet accessible to user,
+ // so no concurrent operations are possible.
+ if err := fd.init(); err != nil {
+ return nil, err
+ }
+ if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+ fd.pfd.SetWriteDeadline(deadline)
+ defer fd.pfd.SetWriteDeadline(noDeadline)
+ }
+ if !canUseConnectEx(fd.net) {
+ err := connectFunc(fd.pfd.Sysfd, ra)
+ return nil, os.NewSyscallError("connect", err)
+ }
+ // ConnectEx windows API requires an unconnected, previously bound socket.
+ if la == nil {
+ switch ra.(type) {
+ case *syscall.SockaddrInet4:
+ la = &syscall.SockaddrInet4{}
+ case *syscall.SockaddrInet6:
+ la = &syscall.SockaddrInet6{}
+ default:
+ panic("unexpected type in connect")
+ }
+ if err := syscall.Bind(fd.pfd.Sysfd, la); err != nil {
+ return nil, os.NewSyscallError("bind", err)
+ }
+ }
+
+ var isloopback bool
+ switch ra := ra.(type) {
+ case *syscall.SockaddrInet4:
+ isloopback = ra.Addr[0] == 127
+ case *syscall.SockaddrInet6:
+ isloopback = ra.Addr == [16]byte(IPv6loopback)
+ default:
+ panic("unexpected type in connect")
+ }
+ if isloopback {
+ // This makes ConnectEx() fail faster if the target port on localhost
+ // is not reachable, instead of waiting for 2s.
+ params := windows.TCP_INITIAL_RTO_PARAMETERS{
+ Rtt: windows.TCP_INITIAL_RTO_UNSPECIFIED_RTT, // use the default or overridden by the Administrator
+ MaxSynRetransmissions: 1, // minimum possible value before Windows 10.0.16299
+ }
+ if windows.Support_TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS() {
+ // In Windows 10.0.16299, TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS makes ConnectEx() fail instantly.
+ params.MaxSynRetransmissions = windows.TCP_INITIAL_RTO_NO_SYN_RETRANSMISSIONS
+ }
+ var out uint32
+ // Don't abort the connection if WSAIoctl fails, as it is only an optimization.
+ // If it fails reliably, we expect TestDialClosedPortFailFast to detect it.
+ _ = fd.pfd.WSAIoctl(windows.SIO_TCP_INITIAL_RTO, (*byte)(unsafe.Pointer(&params)), uint32(unsafe.Sizeof(params)), nil, 0, &out, nil, 0)
+ }
+
+ // Wait for the goroutine converting context.Done into a write timeout
+ // to exit; otherwise our caller might cancel the context and
+ // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
+ done := make(chan bool) // must be unbuffered
+ defer func() { done <- true }()
+ go func() {
+ select {
+ case <-ctx.Done():
+ // Force the runtime's poller to immediately give
+ // up waiting for writability.
+ fd.pfd.SetWriteDeadline(aLongTimeAgo)
+ <-done
+ case <-done:
+ }
+ }()
+
+ // Call ConnectEx API.
+ if err := fd.pfd.ConnectEx(ra); err != nil {
+ select {
+ case <-ctx.Done():
+ return nil, mapErr(ctx.Err())
+ default:
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("connectex", err)
+ }
+ return nil, err
+ }
+ }
+ // Refresh socket properties.
+ return nil, os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.pfd.Sysfd)), int32(unsafe.Sizeof(fd.pfd.Sysfd))))
+}
+
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.writeBuffers(v)
+ if err != nil {
+ return n, &OpError{Op: "wsasend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, nil
+}
+
+func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) {
+ n, err := fd.pfd.Writev((*[][]byte)(buf))
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError("wsasend", err)
+}
+
+func (fd *netFD) accept() (*netFD, error) {
+ s, rawsa, rsan, errcall, err := fd.pfd.Accept(func() (syscall.Handle, error) {
+ return sysSocket(fd.family, fd.sotype, 0)
+ })
+
+ if err != nil {
+ if errcall != "" {
+ err = wrapSyscallError(errcall, err)
+ }
+ return nil, err
+ }
+
+ // Associate our new socket with IOCP.
+ netfd, err := newFD(s, fd.family, fd.sotype, fd.net)
+ if err != nil {
+ poll.CloseFunc(s)
+ return nil, err
+ }
+ if err := netfd.init(); err != nil {
+ fd.Close()
+ return nil, err
+ }
+
+ // Get local and peer addr out of AcceptEx buffer.
+ var lrsa, rrsa *syscall.RawSockaddrAny
+ var llen, rlen int32
+ syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&rawsa[0])),
+ 0, rsan, rsan, &lrsa, &llen, &rrsa, &rlen)
+ lsa, _ := lrsa.Sockaddr()
+ rsa, _ := rrsa.Sockaddr()
+
+ netfd.setAddr(netfd.addrFunc()(lsa), netfd.addrFunc()(rsa))
+ return netfd, nil
+}
+
+// Unimplemented functions.
+
+func (fd *netFD) dup() (*os.File, error) {
+ // TODO: Implement this
+ return nil, syscall.EWINDOWS
+}
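
writeBuffers above is what backs the exported Buffers type: the slices are handed to the kernel in a single WSASend (writev on Unix) rather than one Write call per slice. A short sketch of the caller-side API; the dial target is only a placeholder:

package main

import (
	"log"
	"net"
)

func main() {
	c, err := net.Dial("tcp", "example.com:80") // placeholder target
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	// The three slices reach the kernel as one gathered write where the
	// platform supports it.
	bufs := net.Buffers{
		[]byte("GET / HTTP/1.0\r\n"),
		[]byte("Host: example.com\r\n"),
		[]byte("\r\n"),
	}
	if _, err := bufs.WriteTo(c); err != nil {
		log.Fatal(err)
	}
}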
diff --git a/src/net/file.go b/src/net/file.go
new file mode 100644
index 0000000..c13332c
--- /dev/null
+++ b/src/net/file.go
@@ -0,0 +1,51 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "os"
+
+// BUG(mikio): On JS and Windows, the FileConn, FileListener and
+// FilePacketConn functions are not implemented.
+
+type fileAddr string
+
+func (fileAddr) Network() string { return "file+net" }
+func (f fileAddr) String() string { return string(f) }
+
+// FileConn returns a copy of the network connection corresponding to
+// the open file f.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func FileConn(f *os.File) (c Conn, err error) {
+ c, err = fileConn(f)
+ if err != nil {
+ err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
+ }
+ return
+}
+
+// FileListener returns a copy of the network listener corresponding
+// to the open file f.
+// It is the caller's responsibility to close ln when finished.
+// Closing ln does not affect f, and closing f does not affect ln.
+func FileListener(f *os.File) (ln Listener, err error) {
+ ln, err = fileListener(f)
+ if err != nil {
+ err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
+ }
+ return
+}
+
+// FilePacketConn returns a copy of the packet network connection
+// corresponding to the open file f.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func FilePacketConn(f *os.File) (c PacketConn, err error) {
+ c, err = filePacketConn(f)
+ if err != nil {
+ err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
+ }
+ return
+}
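
FileConn, FileListener and FilePacketConn are one half of a round trip: the File methods on *TCPConn, *TCPListener and friends produce the *os.File, and the functions above rebuild a net object from it, which is the usual way to hand a socket to a child process. A minimal single-process sketch, on platforms where FileListener is implemented (see the BUG note above):

package main

import (
	"log"
	"net"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	// Extract the listening socket as an *os.File...
	f, err := ln.(*net.TCPListener).File()
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// ...and rebuild an independent Listener from it. As documented above,
	// closing one does not affect the other.
	ln2, err := net.FileListener(f)
	if err != nil {
		log.Fatal(err)
	}
	defer ln2.Close()

	log.Println("duplicated listener at", ln2.Addr())
}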
diff --git a/src/net/file_plan9.go b/src/net/file_plan9.go
new file mode 100644
index 0000000..64aabf9
--- /dev/null
+++ b/src/net/file_plan9.go
@@ -0,0 +1,135 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "io"
+ "os"
+ "syscall"
+)
+
+func (fd *netFD) status(ln int) (string, error) {
+ if !fd.ok() {
+ return "", syscall.EINVAL
+ }
+
+ status, err := os.Open(fd.dir + "/status")
+ if err != nil {
+ return "", err
+ }
+ defer status.Close()
+ buf := make([]byte, ln)
+ n, err := io.ReadFull(status, buf[:])
+ if err != nil {
+ return "", err
+ }
+ return string(buf[:n]), nil
+}
+
+func newFileFD(f *os.File) (net *netFD, err error) {
+ var ctl *os.File
+ close := func(fd int) {
+ if err != nil {
+ syscall.Close(fd)
+ }
+ }
+
+ path, err := syscall.Fd2path(int(f.Fd()))
+ if err != nil {
+ return nil, os.NewSyscallError("fd2path", err)
+ }
+ comp := splitAtBytes(path, "/")
+ n := len(comp)
+ if n < 3 || comp[0][0:3] != "net" {
+ return nil, syscall.EPLAN9
+ }
+
+ name := comp[2]
+ switch file := comp[n-1]; file {
+ case "ctl", "clone":
+ fd, err := syscall.Dup(int(f.Fd()), -1)
+ if err != nil {
+ return nil, os.NewSyscallError("dup", err)
+ }
+ defer close(fd)
+
+ dir := netdir + "/" + comp[n-2]
+ ctl = os.NewFile(uintptr(fd), dir+"/"+file)
+ ctl.Seek(0, io.SeekStart)
+ var buf [16]byte
+ n, err := ctl.Read(buf[:])
+ if err != nil {
+ return nil, err
+ }
+ name = string(buf[:n])
+ default:
+ if len(comp) < 4 {
+ return nil, errors.New("could not find control file for connection")
+ }
+ dir := netdir + "/" + comp[1] + "/" + name
+ ctl, err = os.OpenFile(dir+"/ctl", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer close(int(ctl.Fd()))
+ }
+ dir := netdir + "/" + comp[1] + "/" + name
+ laddr, err := readPlan9Addr(comp[1], dir+"/local")
+ if err != nil {
+ return nil, err
+ }
+ return newFD(comp[1], name, nil, ctl, nil, laddr, nil)
+}
+
+func fileConn(f *os.File) (Conn, error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ if !fd.ok() {
+ return nil, syscall.EINVAL
+ }
+
+ fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ switch fd.laddr.(type) {
+ case *TCPAddr:
+ return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
+ case *UDPAddr:
+ return newUDPConn(fd), nil
+ }
+ return nil, syscall.EPLAN9
+}
+
+func fileListener(f *os.File) (Listener, error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ switch fd.laddr.(type) {
+ case *TCPAddr:
+ default:
+ return nil, syscall.EPLAN9
+ }
+
+ // check that file corresponds to a listener
+ s, err := fd.status(len("Listen"))
+ if err != nil {
+ return nil, err
+ }
+ if s != "Listen" {
+ return nil, errors.New("file does not represent a listener")
+ }
+
+ return &TCPListener{fd: fd}, nil
+}
+
+func filePacketConn(f *os.File) (PacketConn, error) {
+ return nil, syscall.EPLAN9
+}
diff --git a/src/net/file_stub.go b/src/net/file_stub.go
new file mode 100644
index 0000000..91df926
--- /dev/null
+++ b/src/net/file_stub.go
@@ -0,0 +1,16 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js && wasm
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func fileConn(f *os.File) (Conn, error) { return nil, syscall.ENOPROTOOPT }
+func fileListener(f *os.File) (Listener, error) { return nil, syscall.ENOPROTOOPT }
+func filePacketConn(f *os.File) (PacketConn, error) { return nil, syscall.ENOPROTOOPT }
diff --git a/src/net/file_test.go b/src/net/file_test.go
new file mode 100644
index 0000000..53cd3c1
--- /dev/null
+++ b/src/net/file_test.go
@@ -0,0 +1,340 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "os"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+// The full stack test cases for IPConn have been moved to the
+// following:
+// golang.org/x/net/ipv4
+// golang.org/x/net/ipv6
+// golang.org/x/net/icmp
+
+var fileConnTests = []struct {
+ network string
+}{
+ {"tcp"},
+ {"udp"},
+ {"unix"},
+ {"unixpacket"},
+}
+
+func TestFileConn(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range fileConnTests {
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
+ continue
+ }
+
+ var network, address string
+ switch tt.network {
+ case "udp":
+ c := newLocalPacketListener(t, tt.network)
+ defer c.Close()
+ network = c.LocalAddr().Network()
+ address = c.LocalAddr().String()
+ default:
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer c.Close()
+ var b [1]byte
+ c.Read(b[:])
+ }
+ ls := newLocalServer(t, tt.network)
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ network = ls.Listener.Addr().Network()
+ address = ls.Listener.Addr().String()
+ }
+
+ c1, err := Dial(network, address)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ addr := c1.LocalAddr()
+
+ var f *os.File
+ switch c1 := c1.(type) {
+ case *TCPConn:
+ f, err = c1.File()
+ case *UDPConn:
+ f, err = c1.File()
+ case *UnixConn:
+ f, err = c1.File()
+ }
+ if err := c1.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ c2, err := FileConn(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+ t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
+ }
+ }
+}
+
+var fileListenerTests = []struct {
+ network string
+}{
+ {"tcp"},
+ {"unix"},
+ {"unixpacket"},
+}
+
+func TestFileListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range fileListenerTests {
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
+ continue
+ }
+
+ ln1 := newLocalListener(t, tt.network)
+ switch tt.network {
+ case "unix", "unixpacket":
+ defer os.Remove(ln1.Addr().String())
+ }
+ addr := ln1.Addr()
+
+ var (
+ f *os.File
+ err error
+ )
+ switch ln1 := ln1.(type) {
+ case *TCPListener:
+ f, err = ln1.File()
+ case *UnixListener:
+ f, err = ln1.File()
+ }
+ switch tt.network {
+ case "unix", "unixpacket":
+ defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
+ default:
+ if err := ln1.Close(); err != nil {
+ t.Error(err)
+ }
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ ln2, err := FileListener(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer ln2.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ return
+ }
+ c.Close()
+ }()
+ c, err := ln2.Accept()
+ if err != nil {
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ c.Close()
+ wg.Wait()
+ if !reflect.DeepEqual(ln2.Addr(), addr) {
+ t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
+ }
+ }
+}
+
+var filePacketConnTests = []struct {
+ network string
+}{
+ {"udp"},
+ {"unixgram"},
+}
+
+func TestFilePacketConn(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range filePacketConnTests {
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
+ continue
+ }
+
+ c1 := newLocalPacketListener(t, tt.network)
+ switch tt.network {
+ case "unixgram":
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ addr := c1.LocalAddr()
+
+ var (
+ f *os.File
+ err error
+ )
+ switch c1 := c1.(type) {
+ case *UDPConn:
+ f, err = c1.File()
+ case *UnixConn:
+ f, err = c1.File()
+ }
+ if err := c1.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ c2, err := FilePacketConn(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+ t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
+ }
+ }
+}
+
+// Issue 24483.
+func TestFileCloseRace(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !testableNetwork("tcp") {
+ t.Skip("tcp not supported")
+ }
+
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer c.Close()
+ var b [1]byte
+ c.Read(b[:])
+ }
+
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ const tries = 100
+ for i := 0; i < tries; i++ {
+ c1, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ tc := c1.(*TCPConn)
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ f, err := tc.File()
+ if err == nil {
+ f.Close()
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ c1.Close()
+ }()
+ wg.Wait()
+ }
+}
diff --git a/src/net/file_unix.go b/src/net/file_unix.go
new file mode 100644
index 0000000..8b9fc38
--- /dev/null
+++ b/src/net/file_unix.go
@@ -0,0 +1,119 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "internal/poll"
+ "os"
+ "syscall"
+)
+
+func dupSocket(f *os.File) (int, error) {
+ s, call, err := poll.DupCloseOnExec(int(f.Fd()))
+ if err != nil {
+ if call != "" {
+ err = os.NewSyscallError(call, err)
+ }
+ return -1, err
+ }
+ if err := syscall.SetNonblock(s, true); err != nil {
+ poll.CloseFunc(s)
+ return -1, os.NewSyscallError("setnonblock", err)
+ }
+ return s, nil
+}
+
+func newFileFD(f *os.File) (*netFD, error) {
+ s, err := dupSocket(f)
+ if err != nil {
+ return nil, err
+ }
+ family := syscall.AF_UNSPEC
+ sotype, err := syscall.GetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ poll.CloseFunc(s)
+ return nil, os.NewSyscallError("getsockopt", err)
+ }
+ lsa, _ := syscall.Getsockname(s)
+ rsa, _ := syscall.Getpeername(s)
+ switch lsa.(type) {
+ case *syscall.SockaddrInet4:
+ family = syscall.AF_INET
+ case *syscall.SockaddrInet6:
+ family = syscall.AF_INET6
+ case *syscall.SockaddrUnix:
+ family = syscall.AF_UNIX
+ default:
+ poll.CloseFunc(s)
+ return nil, syscall.EPROTONOSUPPORT
+ }
+ fd, err := newFD(s, family, sotype, "")
+ if err != nil {
+ poll.CloseFunc(s)
+ return nil, err
+ }
+ laddr := fd.addrFunc()(lsa)
+ raddr := fd.addrFunc()(rsa)
+ fd.net = laddr.Network()
+ if err := fd.init(); err != nil {
+ fd.Close()
+ return nil, err
+ }
+ fd.setAddr(laddr, raddr)
+ return fd, nil
+}
+
+func fileConn(f *os.File) (Conn, error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ switch fd.laddr.(type) {
+ case *TCPAddr:
+ return newTCPConn(fd, defaultTCPKeepAlive, testHookSetKeepAlive), nil
+ case *UDPAddr:
+ return newUDPConn(fd), nil
+ case *IPAddr:
+ return newIPConn(fd), nil
+ case *UnixAddr:
+ return newUnixConn(fd), nil
+ }
+ fd.Close()
+ return nil, syscall.EINVAL
+}
+
+func fileListener(f *os.File) (Listener, error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ switch laddr := fd.laddr.(type) {
+ case *TCPAddr:
+ return &TCPListener{fd: fd}, nil
+ case *UnixAddr:
+ return &UnixListener{fd: fd, path: laddr.Name, unlink: false}, nil
+ }
+ fd.Close()
+ return nil, syscall.EINVAL
+}
+
+func filePacketConn(f *os.File) (PacketConn, error) {
+ fd, err := newFileFD(f)
+ if err != nil {
+ return nil, err
+ }
+ switch fd.laddr.(type) {
+ case *UDPAddr:
+ return newUDPConn(fd), nil
+ case *IPAddr:
+ return newIPConn(fd), nil
+ case *UnixAddr:
+ return newUnixConn(fd), nil
+ }
+ fd.Close()
+ return nil, syscall.EINVAL
+}
diff --git a/src/net/file_unix_test.go b/src/net/file_unix_test.go
new file mode 100644
index 0000000..0499a02
--- /dev/null
+++ b/src/net/file_unix_test.go
@@ -0,0 +1,101 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "internal/syscall/unix"
+ "testing"
+)
+
+// For backward compatibility, opening a net.Conn, turning it into an os.File,
+// and calling the Fd method should return a blocking descriptor.
+func TestFileFdBlocks(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skipf("skipping: unix sockets not supported")
+ }
+
+ ls := newLocalServer(t, "unix")
+ defer ls.teardown()
+
+ errc := make(chan error, 1)
+ done := make(chan bool)
+ handler := func(ls *localServer, ln Listener) {
+ server, err := ln.Accept()
+ errc <- err
+ if err != nil {
+ return
+ }
+ defer server.Close()
+ <-done
+ }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ defer close(done)
+
+ client, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+
+ if err := <-errc; err != nil {
+ t.Fatalf("server error: %v", err)
+ }
+
+ // The socket should be non-blocking.
+ rawconn, err := client.(*UnixConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rawconn.Control(func(fd uintptr) {
+ nonblock, err := unix.IsNonblock(int(fd))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !nonblock {
+ t.Fatal("unix socket is in blocking mode")
+ }
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ file, err := client.(*UnixConn).File()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // At this point the descriptor should still be non-blocking.
+ rawconn, err = file.SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rawconn.Control(func(fd uintptr) {
+ nonblock, err := unix.IsNonblock(int(fd))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !nonblock {
+ t.Fatal("unix socket as os.File is in blocking mode")
+ }
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ fd := file.Fd()
+
+ // Calling Fd should have put the descriptor into blocking mode.
+ nonblock, err := unix.IsNonblock(int(fd))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if nonblock {
+ t.Error("unix socket through os.File.Fd is non-blocking")
+ }
+}
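
Related to the blocking behaviour checked above (a minimal sketch, not part of the diff): the descriptor can be inspected without the switch to blocking mode that File().Fd() performs, by going through SyscallConn instead. The dial target example.com:80 is purely illustrative.

package main

import (
	"log"
	"net"
)

func main() {
	c, err := net.Dial("tcp", "example.com:80") // illustrative address
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()

	// SyscallConn exposes the fd for the duration of Control without
	// duplicating it or putting the socket into blocking mode.
	sc, err := c.(*net.TCPConn).SyscallConn()
	if err != nil {
		log.Fatal(err)
	}
	if err := sc.Control(func(fd uintptr) {
		log.Printf("fd=%d (still non-blocking)", fd)
	}); err != nil {
		log.Fatal(err)
	}
}
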
diff --git a/src/net/file_wasip1.go b/src/net/file_wasip1.go
new file mode 100644
index 0000000..a3624ef
--- /dev/null
+++ b/src/net/file_wasip1.go
@@ -0,0 +1,102 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build wasip1
+
+package net
+
+import (
+ "os"
+ "syscall"
+ _ "unsafe" // for go:linkname
+)
+
+func fileListener(f *os.File) (Listener, error) {
+ filetype, err := fd_fdstat_get_type(f.PollFD().Sysfd)
+ if err != nil {
+ return nil, err
+ }
+ net, err := fileListenNet(filetype)
+ if err != nil {
+ return nil, err
+ }
+ pfd := f.PollFD().Copy()
+ fd := newPollFD(net, pfd)
+ if err := fd.init(); err != nil {
+ pfd.Close()
+ return nil, err
+ }
+ return newFileListener(fd), nil
+}
+
+func fileConn(f *os.File) (Conn, error) {
+ filetype, err := fd_fdstat_get_type(f.PollFD().Sysfd)
+ if err != nil {
+ return nil, err
+ }
+ net, err := fileConnNet(filetype)
+ if err != nil {
+ return nil, err
+ }
+ pfd := f.PollFD().Copy()
+ fd := newPollFD(net, pfd)
+ if err := fd.init(); err != nil {
+ pfd.Close()
+ return nil, err
+ }
+ return newFileConn(fd), nil
+}
+
+func filePacketConn(f *os.File) (PacketConn, error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func fileListenNet(filetype syscall.Filetype) (string, error) {
+ switch filetype {
+ case syscall.FILETYPE_SOCKET_STREAM:
+ return "tcp", nil
+ case syscall.FILETYPE_SOCKET_DGRAM:
+ return "", syscall.EOPNOTSUPP
+ default:
+ return "", syscall.ENOTSOCK
+ }
+}
+
+func fileConnNet(filetype syscall.Filetype) (string, error) {
+ switch filetype {
+ case syscall.FILETYPE_SOCKET_STREAM:
+ return "tcp", nil
+ case syscall.FILETYPE_SOCKET_DGRAM:
+ return "udp", nil
+ default:
+ return "", syscall.ENOTSOCK
+ }
+}
+
+func newFileListener(fd *netFD) Listener {
+ switch fd.net {
+ case "tcp":
+ return &TCPListener{fd: fd}
+ default:
+ panic("unsupported network for file listener: " + fd.net)
+ }
+}
+
+func newFileConn(fd *netFD) Conn {
+ switch fd.net {
+ case "tcp":
+ return &TCPConn{conn{fd: fd}}
+ case "udp":
+ return &UDPConn{conn{fd: fd}}
+ default:
+ panic("unsupported network for file connection: " + fd.net)
+ }
+}
+
+// This helper is implemented in the syscall package. It means we don't have
+// to redefine the fd_fdstat_get host import or the fdstat struct it
+// populates.
+//
+//go:linkname fd_fdstat_get_type syscall.fd_fdstat_get_type
+func fd_fdstat_get_type(fd int) (uint8, error)
diff --git a/src/net/file_wasip1_test.go b/src/net/file_wasip1_test.go
new file mode 100644
index 0000000..4f42590
--- /dev/null
+++ b/src/net/file_wasip1_test.go
@@ -0,0 +1,112 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build wasip1
+
+package net
+
+import (
+ "syscall"
+ "testing"
+)
+
+// The tests in this file validate that net.FileConn and net.FileListener
+// handle both TCP and UDP sockets. Ideally we would test
+// the public interface by constructing an *os.File from a file descriptor
+// opened on a socket, but the WASI preview 1 specification is too limited to
+// support this approach for UDP sockets. Instead, we test the internals that
+// make it possible for WASI host runtimes and guest programs to integrate
+// socket extensions with the net package using net.FileConn/net.FileListener.
+//
+// Note that the creation of net.Conn and net.Listener values for TCP sockets
+// has an end-to-end test in src/runtime/internal/wasitest; here we only
+// verify the code paths specific to UDP and the error handling for invalid
+// use of the functions.
+
+func TestWasip1FileConnNet(t *testing.T) {
+ tests := []struct {
+ filetype syscall.Filetype
+ network string
+ error error
+ }{
+ {syscall.FILETYPE_SOCKET_STREAM, "tcp", nil},
+ {syscall.FILETYPE_SOCKET_DGRAM, "udp", nil},
+ {syscall.FILETYPE_BLOCK_DEVICE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_CHARACTER_DEVICE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_DIRECTORY, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_REGULAR_FILE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_SYMBOLIC_LINK, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_UNKNOWN, "", syscall.ENOTSOCK},
+ }
+ for _, test := range tests {
+ net, err := fileConnNet(test.filetype)
+ if net != test.network {
+ t.Errorf("fileConnNet: network mismatch: want=%q got=%q", test.network, net)
+ }
+ if err != test.error {
+ t.Errorf("fileConnNet: error mismatch: want=%v got=%v", test.error, err)
+ }
+ }
+}
+
+func TestWasip1FileListenNet(t *testing.T) {
+ tests := []struct {
+ filetype syscall.Filetype
+ network string
+ error error
+ }{
+ {syscall.FILETYPE_SOCKET_STREAM, "tcp", nil},
+ {syscall.FILETYPE_SOCKET_DGRAM, "", syscall.EOPNOTSUPP},
+ {syscall.FILETYPE_BLOCK_DEVICE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_CHARACTER_DEVICE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_DIRECTORY, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_REGULAR_FILE, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_SYMBOLIC_LINK, "", syscall.ENOTSOCK},
+ {syscall.FILETYPE_UNKNOWN, "", syscall.ENOTSOCK},
+ }
+ for _, test := range tests {
+ net, err := fileListenNet(test.filetype)
+ if net != test.network {
+ t.Errorf("fileListenNet: network mismatch: want=%q got=%q", test.network, net)
+ }
+ if err != test.error {
+ t.Errorf("fileListenNet: error mismatch: want=%v got=%v", test.error, err)
+ }
+ }
+}
+
+func TestWasip1NewFileListener(t *testing.T) {
+ if l, ok := newFileListener(newFD("tcp", -1)).(*TCPListener); !ok {
+ t.Errorf("newFileListener: tcp listener type mismatch: %T", l)
+ } else {
+ testIsTCPAddr(t, "Addr", l.Addr())
+ }
+}
+
+func TestWasip1NewFileConn(t *testing.T) {
+ if c, ok := newFileConn(newFD("tcp", -1)).(*TCPConn); !ok {
+ t.Errorf("newFileConn: tcp conn type mismatch: %T", c)
+ } else {
+ testIsTCPAddr(t, "LocalAddr", c.LocalAddr())
+ testIsTCPAddr(t, "RemoteAddr", c.RemoteAddr())
+ }
+ if c, ok := newFileConn(newFD("udp", -1)).(*UDPConn); !ok {
+ t.Errorf("newFileConn: udp conn type mismatch: %T", c)
+ } else {
+ testIsUDPAddr(t, "LocalAddr", c.LocalAddr())
+ testIsUDPAddr(t, "RemoteAddr", c.RemoteAddr())
+ }
+}
+
+func testIsTCPAddr(t *testing.T, method string, addr Addr) {
+ if _, ok := addr.(*TCPAddr); !ok {
+ t.Errorf("%s: returned address is not a *TCPAddr: %T", method, addr)
+ }
+}
+
+func testIsUDPAddr(t *testing.T, method string, addr Addr) {
+ if _, ok := addr.(*UDPAddr); !ok {
+ t.Errorf("%s: returned address is not a *UDPAddr: %T", method, addr)
+ }
+}
diff --git a/src/net/file_windows.go b/src/net/file_windows.go
new file mode 100644
index 0000000..241fa17
--- /dev/null
+++ b/src/net/file_windows.go
@@ -0,0 +1,25 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func fileConn(f *os.File) (Conn, error) {
+ // TODO: Implement this
+ return nil, syscall.EWINDOWS
+}
+
+func fileListener(f *os.File) (Listener, error) {
+ // TODO: Implement this
+ return nil, syscall.EWINDOWS
+}
+
+func filePacketConn(f *os.File) (PacketConn, error) {
+ // TODO: Implement this
+ return nil, syscall.EWINDOWS
+}
diff --git a/src/net/hook.go b/src/net/hook.go
new file mode 100644
index 0000000..ea71803
--- /dev/null
+++ b/src/net/hook.go
@@ -0,0 +1,26 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "time"
+)
+
+var (
+ // if non-nil, overrides dialTCP.
+ testHookDialTCP func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error)
+
+ testHookHostsPath = "/etc/hosts"
+ testHookLookupIP = func(
+ ctx context.Context,
+ fn func(context.Context, string, string) ([]IPAddr, error),
+ network string,
+ host string,
+ ) ([]IPAddr, error) {
+ return fn(ctx, network, host)
+ }
+ testHookSetKeepAlive = func(time.Duration) {}
+)
diff --git a/src/net/hook_plan9.go b/src/net/hook_plan9.go
new file mode 100644
index 0000000..e053348
--- /dev/null
+++ b/src/net/hook_plan9.go
@@ -0,0 +1,9 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "time"
+
+var testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
diff --git a/src/net/hook_unix.go b/src/net/hook_unix.go
new file mode 100644
index 0000000..4e20f59
--- /dev/null
+++ b/src/net/hook_unix.go
@@ -0,0 +1,20 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1
+
+package net
+
+import "syscall"
+
+var (
+ testHookDialChannel = func() {} // for golang.org/issue/5349
+ testHookCanceledDial = func() {} // for golang.org/issue/16523
+
+ // Placeholders for socket system calls.
+ socketFunc func(int, int, int) (int, error) = syscall.Socket
+ connectFunc func(int, syscall.Sockaddr) error = syscall.Connect
+ listenFunc func(int, int) error = syscall.Listen
+ getsockoptIntFunc func(int, int, int) (int, error) = syscall.GetsockoptInt
+)
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
new file mode 100644
index 0000000..ab8656c
--- /dev/null
+++ b/src/net/hook_windows.go
@@ -0,0 +1,21 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/windows"
+ "syscall"
+ "time"
+)
+
+var (
+ testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
+
+ // Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
+ wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
+ connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
+ listenFunc func(syscall.Handle, int) error = syscall.Listen
+)
diff --git a/src/net/hosts.go b/src/net/hosts.go
new file mode 100644
index 0000000..56e6674
--- /dev/null
+++ b/src/net/hosts.go
@@ -0,0 +1,165 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "internal/bytealg"
+ "io/fs"
+ "net/netip"
+ "sync"
+ "time"
+)
+
+const cacheMaxAge = 5 * time.Second
+
+func parseLiteralIP(addr string) string {
+ ip, err := netip.ParseAddr(addr)
+ if err != nil {
+ return ""
+ }
+ return ip.String()
+}
+
+type byName struct {
+ addrs []string
+ canonicalName string
+}
+
+// hosts contains known host entries.
+var hosts struct {
+ sync.Mutex
+
+ // Key for the list of literal IP addresses must be a host
+ // name: a DNS label, an FQDN, or an absolute FQDN (with a
+ // trailing dot).
+ // For now the key is converted to lower case for convenience.
+ byName map[string]byName
+
+ // Key for the list of host names must be a literal IP address
+ // including IPv6 address with zone identifier.
+ // We don't support old-classful IP address notation.
+ byAddr map[string][]string
+
+ expire time.Time
+ path string
+ mtime time.Time
+ size int64
+}
+
+func readHosts() {
+ now := time.Now()
+ hp := testHookHostsPath
+
+ if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 {
+ return
+ }
+ mtime, size, err := stat(hp)
+ if err == nil && hosts.path == hp && hosts.mtime.Equal(mtime) && hosts.size == size {
+ hosts.expire = now.Add(cacheMaxAge)
+ return
+ }
+
+ hs := make(map[string]byName)
+ is := make(map[string][]string)
+
+ file, err := open(hp)
+ if err != nil {
+ if !errors.Is(err, fs.ErrNotExist) && !errors.Is(err, fs.ErrPermission) {
+ return
+ }
+ }
+
+ if file != nil {
+ defer file.close()
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ if i := bytealg.IndexByteString(line, '#'); i >= 0 {
+ // Discard comments.
+ line = line[0:i]
+ }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ addr := parseLiteralIP(f[0])
+ if addr == "" {
+ continue
+ }
+
+ var canonical string
+ for i := 1; i < len(f); i++ {
+ name := absDomainName(f[i])
+ h := []byte(f[i])
+ lowerASCIIBytes(h)
+ key := absDomainName(string(h))
+
+ if i == 1 {
+ canonical = key
+ }
+
+ is[addr] = append(is[addr], name)
+
+ if v, ok := hs[key]; ok {
+ hs[key] = byName{
+ addrs: append(v.addrs, addr),
+ canonicalName: v.canonicalName,
+ }
+ continue
+ }
+
+ hs[key] = byName{
+ addrs: []string{addr},
+ canonicalName: canonical,
+ }
+ }
+ }
+ }
+ // Update the data cache.
+ hosts.expire = now.Add(cacheMaxAge)
+ hosts.path = hp
+ hosts.byName = hs
+ hosts.byAddr = is
+ hosts.mtime = mtime
+ hosts.size = size
+}
+
+// lookupStaticHost looks up the addresses and the canonical name for the given host from /etc/hosts.
+func lookupStaticHost(host string) ([]string, string) {
+ hosts.Lock()
+ defer hosts.Unlock()
+ readHosts()
+ if len(hosts.byName) != 0 {
+ if hasUpperCase(host) {
+ lowerHost := []byte(host)
+ lowerASCIIBytes(lowerHost)
+ host = string(lowerHost)
+ }
+ if byName, ok := hosts.byName[absDomainName(host)]; ok {
+ ipsCp := make([]string, len(byName.addrs))
+ copy(ipsCp, byName.addrs)
+ return ipsCp, byName.canonicalName
+ }
+ }
+ return nil, ""
+}
+
+// lookupStaticAddr looks up the hosts for the given address from /etc/hosts.
+func lookupStaticAddr(addr string) []string {
+ hosts.Lock()
+ defer hosts.Unlock()
+ readHosts()
+ addr = parseLiteralIP(addr)
+ if addr == "" {
+ return nil
+ }
+ if len(hosts.byAddr) != 0 {
+ if hosts, ok := hosts.byAddr[addr]; ok {
+ hostsCp := make([]string, len(hosts))
+ copy(hostsCp, hosts)
+ return hostsCp
+ }
+ }
+ return nil
+}
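
For orientation (a minimal sketch, not part of the diff): when the Go resolver is in use, the /etc/hosts cache built above answers the public lookup functions. The sketch assumes "localhost" and 127.0.0.1 have entries in the local hosts file.

package main

import (
	"fmt"
	"net"
)

func main() {
	// Forward lookup: served from the byName map when the name appears
	// in /etc/hosts.
	addrs, err := net.LookupHost("localhost")
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(addrs)

	// Reverse lookup: served from the byAddr map.
	names, err := net.LookupAddr("127.0.0.1")
	if err != nil {
		fmt.Println("reverse lookup failed:", err)
		return
	}
	fmt.Println(names)
}
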
diff --git a/src/net/hosts_test.go b/src/net/hosts_test.go
new file mode 100644
index 0000000..b3f189e
--- /dev/null
+++ b/src/net/hosts_test.go
@@ -0,0 +1,214 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type staticHostEntry struct {
+ in string
+ out []string
+}
+
+var lookupStaticHostTests = []struct {
+ name string
+ ents []staticHostEntry
+}{
+ {
+ "testdata/hosts",
+ []staticHostEntry{
+ {"odin", []string{"127.0.0.2", "127.0.0.3", "::2"}},
+ {"thor", []string{"127.1.1.1"}},
+ {"ullr", []string{"127.1.1.2"}},
+ {"ullrhost", []string{"127.1.1.2"}},
+ {"localhost", []string{"fe80::1%lo0"}},
+ },
+ },
+ {
+ "testdata/singleline-hosts", // see golang.org/issue/6646
+ []staticHostEntry{
+ {"odin", []string{"127.0.0.2"}},
+ },
+ },
+ {
+ "testdata/ipv4-hosts",
+ []staticHostEntry{
+ {"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}},
+ {"localhost.localdomain", []string{"127.0.0.3"}},
+ },
+ },
+ {
+ "testdata/ipv6-hosts", // see golang.org/issue/8996
+ []staticHostEntry{
+ {"localhost", []string{"::1", "fe80::1", "fe80::2%lo0", "fe80::3%lo0"}},
+ {"localhost.localdomain", []string{"fe80::3%lo0"}},
+ },
+ },
+ {
+ "testdata/case-hosts", // see golang.org/issue/12806
+ []staticHostEntry{
+ {"PreserveMe", []string{"127.0.0.1", "::1"}},
+ {"PreserveMe.local", []string{"127.0.0.1", "::1"}},
+ },
+ },
+}
+
+func TestLookupStaticHost(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+
+ for _, tt := range lookupStaticHostTests {
+ testHookHostsPath = tt.name
+ for _, ent := range tt.ents {
+ testStaticHost(t, tt.name, ent)
+ }
+ }
+}
+
+func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) {
+ ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
+ for _, in := range ins {
+ addrs, _ := lookupStaticHost(in)
+ if !reflect.DeepEqual(addrs, ent.out) {
+ t.Errorf("%s, lookupStaticHost(%s) = %v; want %v", hostsPath, in, addrs, ent.out)
+ }
+ }
+}
+
+var lookupStaticAddrTests = []struct {
+ name string
+ ents []staticHostEntry
+}{
+ {
+ "testdata/hosts",
+ []staticHostEntry{
+ {"255.255.255.255", []string{"broadcasthost"}},
+ {"127.0.0.2", []string{"odin"}},
+ {"127.0.0.3", []string{"odin"}},
+ {"::2", []string{"odin"}},
+ {"127.1.1.1", []string{"thor"}},
+ {"127.1.1.2", []string{"ullr", "ullrhost"}},
+ {"fe80::1%lo0", []string{"localhost"}},
+ },
+ },
+ {
+ "testdata/singleline-hosts", // see golang.org/issue/6646
+ []staticHostEntry{
+ {"127.0.0.2", []string{"odin"}},
+ },
+ },
+ {
+ "testdata/ipv4-hosts",
+ []staticHostEntry{
+ {"127.0.0.1", []string{"localhost"}},
+ {"127.0.0.2", []string{"localhost"}},
+ {"127.0.0.3", []string{"localhost", "localhost.localdomain"}},
+ },
+ },
+ {
+ "testdata/ipv6-hosts", // see golang.org/issue/8996
+ []staticHostEntry{
+ {"::1", []string{"localhost"}},
+ {"fe80::1", []string{"localhost"}},
+ {"fe80::2%lo0", []string{"localhost"}},
+ {"fe80::3%lo0", []string{"localhost", "localhost.localdomain"}},
+ },
+ },
+ {
+ "testdata/case-hosts", // see golang.org/issue/12806
+ []staticHostEntry{
+ {"127.0.0.1", []string{"PreserveMe", "PreserveMe.local"}},
+ {"::1", []string{"PreserveMe", "PreserveMe.local"}},
+ },
+ },
+}
+
+func TestLookupStaticAddr(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+
+ for _, tt := range lookupStaticAddrTests {
+ testHookHostsPath = tt.name
+ for _, ent := range tt.ents {
+ testStaticAddr(t, tt.name, ent)
+ }
+ }
+}
+
+func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) {
+ hosts := lookupStaticAddr(ent.in)
+ for i := range ent.out {
+ ent.out[i] = absDomainName(ent.out[i])
+ }
+ if !reflect.DeepEqual(hosts, ent.out) {
+ t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out)
+ }
+}
+
+func TestHostCacheModification(t *testing.T) {
+ // Ensure that programs can't modify the internals of the host cache.
+ // See https://golang.org/issues/14212.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+
+ testHookHostsPath = "testdata/ipv4-hosts"
+ ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}}
+ testStaticHost(t, testHookHostsPath, ent)
+ // Modify the addresses returned by lookupStaticHost.
+ addrs, _ := lookupStaticHost(ent.in)
+ for i := range addrs {
+ addrs[i] += "junk"
+ }
+ testStaticHost(t, testHookHostsPath, ent)
+
+ testHookHostsPath = "testdata/ipv6-hosts"
+ ent = staticHostEntry{"::1", []string{"localhost"}}
+ testStaticAddr(t, testHookHostsPath, ent)
+ // Modify the hosts returned by lookupStaticAddr.
+ hosts := lookupStaticAddr(ent.in)
+ for i := range hosts {
+ hosts[i] += "junk"
+ }
+ testStaticAddr(t, testHookHostsPath, ent)
+}
+
+var lookupStaticHostAliasesTest = []struct {
+ lookup, res string
+}{
+ // 127.0.0.1
+ {"test", "test"},
+ // 127.0.0.2
+ {"test2.example.com", "test2.example.com"},
+ {"2.test", "test2.example.com"},
+ // 127.0.0.3
+ {"test3.example.com", "3.test"},
+ {"3.test", "3.test"},
+ // 127.0.0.4
+ {"example.com", "example.com"},
+ // 127.0.0.5
+ {"test5.example.com", "test4.example.com"},
+ {"5.test", "test4.example.com"},
+ {"4.test", "test4.example.com"},
+ {"test4.example.com", "test4.example.com"},
+}
+
+func TestLookupStaticHostAliases(t *testing.T) {
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+
+ testHookHostsPath = "testdata/aliases"
+ for _, ent := range lookupStaticHostAliasesTest {
+ testLookupStaticHostAliases(t, ent.lookup, absDomainName(ent.res))
+ }
+}
+
+func testLookupStaticHostAliases(t *testing.T, lookup, lookupRes string) {
+ ins := []string{lookup, absDomainName(lookup), strings.ToLower(lookup), strings.ToUpper(lookup)}
+ for _, in := range ins {
+ _, res := lookupStaticHost(in)
+ if res != lookupRes {
+ t.Errorf("lookupStaticHost(%v): got %v, want %v", in, res, lookupRes)
+ }
+ }
+}
diff --git a/src/net/http/alpn_test.go b/src/net/http/alpn_test.go
new file mode 100644
index 0000000..a51038c
--- /dev/null
+++ b/src/net/http/alpn_test.go
@@ -0,0 +1,132 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io"
+ . "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestNextProtoUpgrade(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
+ if r.TLS != nil {
+ w.Write([]byte(r.TLS.NegotiatedProtocol))
+ }
+ if r.RemoteAddr == "" {
+ t.Error("request with no RemoteAddr")
+ }
+ if r.Body == nil {
+ t.Errorf("request with nil Body")
+ }
+ }))
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"unhandled-proto", "tls-0.9"},
+ }
+ ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
+ "tls-0.9": handleTLSProtocol09,
+ }
+ ts.StartTLS()
+ defer ts.Close()
+
+ // Normal request, without NPN.
+ {
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/,proto="; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+
+ // Request to an advertised but unhandled NPN protocol.
+ // Server will hang up.
+ {
+ certPool := x509.NewCertPool()
+ certPool.AddCert(ts.Certificate())
+ tr := &Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certPool,
+ NextProtos: []string{"unhandled-proto"},
+ },
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{
+ Transport: tr,
+ }
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ defer res.Body.Close()
+ var buf bytes.Buffer
+ res.Write(&buf)
+ t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
+ }
+ }
+
+ // Request using the "tls-0.9" protocol, which we register here.
+ // It is HTTP/0.9 over TLS.
+ {
+ c := ts.Client()
+ tlsConfig := c.Transport.(*Transport).TLSClientConfig
+ tlsConfig.NextProtos = []string{"tls-0.9"}
+ conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.Write([]byte("GET /foo\n"))
+ body, err := io.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "path=/foo,proto=tls-0.9"; string(body) != want {
+ t.Errorf("plain request = %q; want %q", body, want)
+ }
+ }
+}
+
+// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
+// TestNextProtoUpgrade test.
+func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
+ br := bufio.NewReader(conn)
+ line, err := br.ReadString('\n')
+ if err != nil {
+ return
+ }
+ line = strings.TrimSpace(line)
+ path := strings.TrimPrefix(line, "GET ")
+ if path == line {
+ return
+ }
+ req, _ := NewRequest("GET", path, nil)
+ req.Proto = "HTTP/0.9"
+ req.ProtoMajor = 0
+ req.ProtoMinor = 9
+ rw := &http09Writer{conn, make(Header)}
+ h.ServeHTTP(rw, req)
+}
+
+type http09Writer struct {
+ io.Writer
+ h Header
+}
+
+func (w http09Writer) Header() Header { return w.h }
+func (w http09Writer) WriteHeader(int) {} // no headers
diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go
new file mode 100644
index 0000000..1411f0b
--- /dev/null
+++ b/src/net/http/cgi/child.go
@@ -0,0 +1,222 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements CGI from the perspective of a child
+// process.
+
+package cgi
+
+import (
+ "bufio"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+)
+
+// Request returns the HTTP request as represented in the current
+// environment. This assumes the current program is being run
+// by a web server in a CGI environment.
+// The returned Request's Body is populated, if applicable.
+func Request() (*http.Request, error) {
+ r, err := RequestFromMap(envMap(os.Environ()))
+ if err != nil {
+ return nil, err
+ }
+ if r.ContentLength > 0 {
+ r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength))
+ }
+ return r, nil
+}
+
+func envMap(env []string) map[string]string {
+ m := make(map[string]string)
+ for _, kv := range env {
+ if k, v, ok := strings.Cut(kv, "="); ok {
+ m[k] = v
+ }
+ }
+ return m
+}
+
+// RequestFromMap creates an http.Request from CGI variables.
+// The returned Request's Body field is not populated.
+func RequestFromMap(params map[string]string) (*http.Request, error) {
+ r := new(http.Request)
+ r.Method = params["REQUEST_METHOD"]
+ if r.Method == "" {
+ return nil, errors.New("cgi: no REQUEST_METHOD in environment")
+ }
+
+ r.Proto = params["SERVER_PROTOCOL"]
+ var ok bool
+ r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto)
+ if !ok {
+ return nil, errors.New("cgi: invalid SERVER_PROTOCOL version")
+ }
+
+ r.Close = true
+ r.Trailer = http.Header{}
+ r.Header = http.Header{}
+
+ r.Host = params["HTTP_HOST"]
+
+ if lenstr := params["CONTENT_LENGTH"]; lenstr != "" {
+ clen, err := strconv.ParseInt(lenstr, 10, 64)
+ if err != nil {
+ return nil, errors.New("cgi: bad CONTENT_LENGTH in environment: " + lenstr)
+ }
+ r.ContentLength = clen
+ }
+
+ if ct := params["CONTENT_TYPE"]; ct != "" {
+ r.Header.Set("Content-Type", ct)
+ }
+
+ // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers
+ for k, v := range params {
+ if k == "HTTP_HOST" {
+ continue
+ }
+ if after, found := strings.CutPrefix(k, "HTTP_"); found {
+ r.Header.Add(strings.ReplaceAll(after, "_", "-"), v)
+ }
+ }
+
+ uriStr := params["REQUEST_URI"]
+ if uriStr == "" {
+ // Fall back to SCRIPT_NAME, PATH_INFO, and QUERY_STRING.
+ uriStr = params["SCRIPT_NAME"] + params["PATH_INFO"]
+ s := params["QUERY_STRING"]
+ if s != "" {
+ uriStr += "?" + s
+ }
+ }
+
+ // There's apparently a de-facto standard for this.
+ // https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+ if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
+ r.TLS = &tls.ConnectionState{HandshakeComplete: true}
+ }
+
+ if r.Host != "" {
+ // Hostname is provided, so we can reasonably construct a URL.
+ rawurl := r.Host + uriStr
+ if r.TLS == nil {
+ rawurl = "http://" + rawurl
+ } else {
+ rawurl = "https://" + rawurl
+ }
+ url, err := url.Parse(rawurl)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl)
+ }
+ r.URL = url
+ }
+ // Fallback logic if we don't have a Host header or the URL
+ // failed to parse
+ if r.URL == nil {
+ url, err := url.Parse(uriStr)
+ if err != nil {
+ return nil, errors.New("cgi: failed to parse REQUEST_URI into a URL: " + uriStr)
+ }
+ r.URL = url
+ }
+
+ // Request.RemoteAddr has its port set by Go's standard http
+ // server, so we do here too.
+ remotePort, _ := strconv.Atoi(params["REMOTE_PORT"]) // zero if unset or invalid
+ r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], strconv.Itoa(remotePort))
+
+ return r, nil
+}
+
+// Serve executes the provided Handler on the currently active CGI
+// request, if any. If there's no current CGI environment
+// an error is returned. The provided handler may be nil to use
+// http.DefaultServeMux.
+func Serve(handler http.Handler) error {
+ req, err := Request()
+ if err != nil {
+ return err
+ }
+ if req.Body == nil {
+ req.Body = http.NoBody
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ rw := &response{
+ req: req,
+ header: make(http.Header),
+ bufw: bufio.NewWriter(os.Stdout),
+ }
+ handler.ServeHTTP(rw, req)
+ rw.Write(nil) // make sure a response is sent
+ if err = rw.bufw.Flush(); err != nil {
+ return err
+ }
+ return nil
+}
+
+type response struct {
+ req *http.Request
+ header http.Header
+ code int
+ wroteHeader bool
+ wroteCGIHeader bool
+ bufw *bufio.Writer
+}
+
+func (r *response) Flush() {
+ r.bufw.Flush()
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(p []byte) (n int, err error) {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ if !r.wroteCGIHeader {
+ r.writeCGIHeader(p)
+ }
+ return r.bufw.Write(p)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.wroteHeader {
+ // Note: explicitly using Stderr, as Stdout is our HTTP output.
+ fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL)
+ return
+ }
+ r.wroteHeader = true
+ r.code = code
+}
+
+// writeCGIHeader finalizes the header sent to the client and writes it to the output.
+// p is not written by writeCGIHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly.
+func (r *response) writeCGIHeader(p []byte) {
+ if r.wroteCGIHeader {
+ return
+ }
+ r.wroteCGIHeader = true
+ fmt.Fprintf(r.bufw, "Status: %d %s\r\n", r.code, http.StatusText(r.code))
+ if _, hasType := r.header["Content-Type"]; !hasType {
+ r.header.Set("Content-Type", http.DetectContentType(p))
+ }
+ r.header.Write(r.bufw)
+ r.bufw.WriteString("\r\n")
+ r.bufw.Flush()
+}
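
For context on the child-side API above (a minimal sketch, not part of the diff): a CGI program that serves the current request with cgi.Serve. The handler body is illustrative only.

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/cgi"
)

func main() {
	// cgi.Serve builds the request from the CGI environment, runs the
	// handler, and writes the Status line, headers, and body to stdout.
	err := cgi.Serve(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		fmt.Fprintf(w, "Hello from %s\n", r.URL.Path)
	}))
	if err != nil {
		log.Fatal(err)
	}
}
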
diff --git a/src/net/http/cgi/child_test.go b/src/net/http/cgi/child_test.go
new file mode 100644
index 0000000..18cf789
--- /dev/null
+++ b/src/net/http/cgi/child_test.go
@@ -0,0 +1,208 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for CGI (the child process perspective)
+
+package cgi
+
+import (
+ "bufio"
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+func TestRequest(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.com",
+ "HTTP_REFERER": "elsewhere",
+ "HTTP_USER_AGENT": "goclient",
+ "HTTP_FOO_BAR": "baz",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ "CONTENT_TYPE": "text/xml",
+ "REMOTE_ADDR": "5.6.7.8",
+ "REMOTE_PORT": "54321",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if g, e := req.UserAgent(), "goclient"; e != g {
+ t.Errorf("expected UserAgent %q; got %q", e, g)
+ }
+ if g, e := req.Method, "GET"; e != g {
+ t.Errorf("expected Method %q; got %q", e, g)
+ }
+ if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g {
+ t.Errorf("expected Content-Type %q; got %q", e, g)
+ }
+ if g, e := req.ContentLength, int64(123); e != g {
+ t.Errorf("expected ContentLength %d; got %d", e, g)
+ }
+ if g, e := req.Referer(), "elsewhere"; e != g {
+ t.Errorf("expected Referer %q; got %q", e, g)
+ }
+ if req.Header == nil {
+ t.Fatalf("unexpected nil Header")
+ }
+ if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g {
+ t.Errorf("expected Foo-Bar %q; got %q", e, g)
+ }
+ if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+ if g, e := req.FormValue("a"), "b"; e != g {
+ t.Errorf("expected FormValue(a) %q; got %q", e, g)
+ }
+ if req.Trailer == nil {
+ t.Errorf("unexpected nil Trailer")
+ }
+ if req.TLS != nil {
+ t.Errorf("expected nil TLS")
+ }
+ if e, g := "5.6.7.8:54321", req.RemoteAddr; e != g {
+ t.Errorf("RemoteAddr: got %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithTLS(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "REQUEST_METHOD": "GET",
+ "HTTP_HOST": "example.com",
+ "HTTP_REFERER": "elsewhere",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_TYPE": "text/xml",
+ "HTTPS": "1",
+ "REMOTE_ADDR": "5.6.7.8",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if g, e := req.URL.String(), "https://example.com/path?a=b"; e != g {
+ t.Errorf("expected URL %q; got %q", e, g)
+ }
+ if req.TLS == nil {
+ t.Errorf("expected non-nil TLS")
+ }
+}
+
+func TestRequestWithoutHost(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "",
+ "REQUEST_METHOD": "GET",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "/path?a=b"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutRequestURI(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "example.com",
+ "REQUEST_METHOD": "GET",
+ "SCRIPT_NAME": "/dir/scriptname",
+ "PATH_INFO": "/p1/p2",
+ "QUERY_STRING": "a=1&b=2",
+ "CONTENT_LENGTH": "123",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if req.URL == nil {
+ t.Fatalf("unexpected nil URL")
+ }
+ if g, e := req.URL.String(), "http://example.com/dir/scriptname/p1/p2?a=1&b=2"; e != g {
+ t.Errorf("URL = %q; want %q", g, e)
+ }
+}
+
+func TestRequestWithoutRemotePort(t *testing.T) {
+ env := map[string]string{
+ "SERVER_PROTOCOL": "HTTP/1.1",
+ "HTTP_HOST": "example.com",
+ "REQUEST_METHOD": "GET",
+ "REQUEST_URI": "/path?a=b",
+ "CONTENT_LENGTH": "123",
+ "REMOTE_ADDR": "5.6.7.8",
+ }
+ req, err := RequestFromMap(env)
+ if err != nil {
+ t.Fatalf("RequestFromMap: %v", err)
+ }
+ if e, g := "5.6.7.8:0", req.RemoteAddr; e != g {
+ t.Errorf("RemoteAddr: got %q; want %q", g, e)
+ }
+}
+
+func TestResponse(t *testing.T) {
+ var tests = []struct {
+ name string
+ body string
+ wantCT string
+ }{
+ {
+ name: "no body",
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "html",
+ body: "<html><head><title>test page</title></head><body>This is a body</body></html>",
+ wantCT: "text/html; charset=utf-8",
+ },
+ {
+ name: "text",
+ body: strings.Repeat("gopher", 86),
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "jpg",
+ body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024),
+ wantCT: "image/jpeg",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var buf bytes.Buffer
+ resp := response{
+ req: httptest.NewRequest("GET", "/", nil),
+ header: http.Header{},
+ bufw: bufio.NewWriter(&buf),
+ }
+ n, err := resp.Write([]byte(tt.body))
+ if err != nil {
+ t.Errorf("Write: unexpected %v", err)
+ }
+ if want := len(tt.body); n != want {
+ t.Errorf("reported short Write: got %v want %v", n, want)
+ }
+ resp.writeCGIHeader(nil)
+ resp.Flush()
+ if got := resp.Header().Get("Content-Type"); got != tt.wantCT {
+ t.Errorf("wrong content-type: got %q, want %q", got, tt.wantCT)
+ }
+ if !bytes.HasSuffix(buf.Bytes(), []byte(tt.body)) {
+ t.Errorf("body was not correctly written")
+ }
+ })
+ }
+}
diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go
new file mode 100644
index 0000000..073952a
--- /dev/null
+++ b/src/net/http/cgi/host.go
@@ -0,0 +1,413 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements the host side of CGI (being the webserver
+// parent process).
+
+// Package cgi implements CGI (Common Gateway Interface) as specified
+// in RFC 3875.
+//
+// Note that using CGI means starting a new process to handle each
+// request, which is typically less efficient than using a
+// long-running server. This package is intended primarily for
+// compatibility with existing systems.
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/textproto"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+var trailingPort = regexp.MustCompile(`:([0-9]+)$`)
+
+var osDefaultInheritEnv = func() []string {
+ switch runtime.GOOS {
+ case "darwin", "ios":
+ return []string{"DYLD_LIBRARY_PATH"}
+ case "android", "linux", "freebsd", "netbsd", "openbsd":
+ return []string{"LD_LIBRARY_PATH"}
+ case "hpux":
+ return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}
+ case "irix":
+ return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}
+ case "illumos", "solaris":
+ return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}
+ case "windows":
+ return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}
+ }
+ return nil
+}()
+
+// Handler runs an executable in a subprocess with a CGI environment.
+type Handler struct {
+ Path string // path to the CGI executable
+ Root string // root URI prefix of handler or empty for "/"
+
+ // Dir specifies the CGI executable's working directory.
+ // If Dir is empty, the base directory of Path is used.
+ // If Path has no base directory, the current working
+ // directory is used.
+ Dir string
+
+ Env []string // extra environment variables to set, if any, as "key=value"
+ InheritEnv []string // environment variables to inherit from host, as "key"
+ Logger *log.Logger // optional log for errors or nil to use log.Print
+ Args []string // optional arguments to pass to child process
+ Stderr io.Writer // optional stderr for the child process; nil means os.Stderr
+
+ // PathLocationHandler specifies the root http Handler that
+ // should handle internal redirects when the CGI process
+ // returns a Location header value starting with a "/", as
+ // specified in RFC 3875 § 6.3.2. This will likely be
+ // http.DefaultServeMux.
+ //
+ // If nil, a CGI response with a local URI path is instead sent
+ // back to the client and not redirected internally.
+ PathLocationHandler http.Handler
+}
+
+func (h *Handler) stderr() io.Writer {
+ if h.Stderr != nil {
+ return h.Stderr
+ }
+ return os.Stderr
+}
+
+// removeLeadingDuplicates removes leading duplicates in environments, so
+// that an environment variable can be overridden like the following.
+//
+// cgi.Handler{
+// ...
+// Env: []string{"SCRIPT_FILENAME=foo.php"},
+// }
+func removeLeadingDuplicates(env []string) (ret []string) {
+ for i, e := range env {
+ found := false
+ if eq := strings.IndexByte(e, '='); eq != -1 {
+ keq := e[:eq+1] // "key="
+ for _, e2 := range env[i+1:] {
+ if strings.HasPrefix(e2, keq) {
+ found = true
+ break
+ }
+ }
+ }
+ if !found {
+ ret = append(ret, e)
+ }
+ }
+ return
+}
+
+func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ root := h.Root
+ if root == "" {
+ root = "/"
+ }
+
+ if len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" {
+ rw.WriteHeader(http.StatusBadRequest)
+ rw.Write([]byte("Chunked request bodies are not supported by CGI."))
+ return
+ }
+
+ pathInfo := req.URL.Path
+ if root != "/" && strings.HasPrefix(pathInfo, root) {
+ pathInfo = pathInfo[len(root):]
+ }
+
+ port := "80"
+ if matches := trailingPort.FindStringSubmatch(req.Host); len(matches) != 0 {
+ port = matches[1]
+ }
+
+ env := []string{
+ "SERVER_SOFTWARE=go",
+ "SERVER_PROTOCOL=HTTP/1.1",
+ "HTTP_HOST=" + req.Host,
+ "GATEWAY_INTERFACE=CGI/1.1",
+ "REQUEST_METHOD=" + req.Method,
+ "QUERY_STRING=" + req.URL.RawQuery,
+ "REQUEST_URI=" + req.URL.RequestURI(),
+ "PATH_INFO=" + pathInfo,
+ "SCRIPT_NAME=" + root,
+ "SCRIPT_FILENAME=" + h.Path,
+ "SERVER_PORT=" + port,
+ }
+
+ if remoteIP, remotePort, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ env = append(env, "REMOTE_ADDR="+remoteIP, "REMOTE_HOST="+remoteIP, "REMOTE_PORT="+remotePort)
+ } else {
+ // could not parse ip:port, let's use whole RemoteAddr and leave REMOTE_PORT undefined
+ env = append(env, "REMOTE_ADDR="+req.RemoteAddr, "REMOTE_HOST="+req.RemoteAddr)
+ }
+
+ if hostDomain, _, err := net.SplitHostPort(req.Host); err == nil {
+ env = append(env, "SERVER_NAME="+hostDomain)
+ } else {
+ env = append(env, "SERVER_NAME="+req.Host)
+ }
+
+ if req.TLS != nil {
+ env = append(env, "HTTPS=on")
+ }
+
+ for k, v := range req.Header {
+ k = strings.Map(upperCaseAndUnderscore, k)
+ if k == "PROXY" {
+ // See Issue 16405
+ continue
+ }
+ joinStr := ", "
+ if k == "COOKIE" {
+ joinStr = "; "
+ }
+ env = append(env, "HTTP_"+k+"="+strings.Join(v, joinStr))
+ }
+
+ if req.ContentLength > 0 {
+ env = append(env, fmt.Sprintf("CONTENT_LENGTH=%d", req.ContentLength))
+ }
+ if ctype := req.Header.Get("Content-Type"); ctype != "" {
+ env = append(env, "CONTENT_TYPE="+ctype)
+ }
+
+ envPath := os.Getenv("PATH")
+ if envPath == "" {
+ envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
+ }
+ env = append(env, "PATH="+envPath)
+
+ for _, e := range h.InheritEnv {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ for _, e := range osDefaultInheritEnv {
+ if v := os.Getenv(e); v != "" {
+ env = append(env, e+"="+v)
+ }
+ }
+
+ if h.Env != nil {
+ env = append(env, h.Env...)
+ }
+
+ env = removeLeadingDuplicates(env)
+
+ var cwd, path string
+ if h.Dir != "" {
+ path = h.Path
+ cwd = h.Dir
+ } else {
+ cwd, path = filepath.Split(h.Path)
+ }
+ if cwd == "" {
+ cwd = "."
+ }
+
+ internalError := func(err error) {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("CGI error: %v", err)
+ }
+
+ cmd := &exec.Cmd{
+ Path: path,
+ Args: append([]string{h.Path}, h.Args...),
+ Dir: cwd,
+ Env: env,
+ Stderr: h.stderr(),
+ }
+ if req.ContentLength != 0 {
+ cmd.Stdin = req.Body
+ }
+ stdoutRead, err := cmd.StdoutPipe()
+ if err != nil {
+ internalError(err)
+ return
+ }
+
+ err = cmd.Start()
+ if err != nil {
+ internalError(err)
+ return
+ }
+ if hook := testHookStartProcess; hook != nil {
+ hook(cmd.Process)
+ }
+ defer cmd.Wait()
+ defer stdoutRead.Close()
+
+ linebody := bufio.NewReaderSize(stdoutRead, 1024)
+ headers := make(http.Header)
+ statusCode := 0
+ headerLines := 0
+ sawBlankLine := false
+ for {
+ line, isPrefix, err := linebody.ReadLine()
+ if isPrefix {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: long header line from subprocess.")
+ return
+ }
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error reading headers: %v", err)
+ return
+ }
+ if len(line) == 0 {
+ sawBlankLine = true
+ break
+ }
+ headerLines++
+ header, val, ok := strings.Cut(string(line), ":")
+ if !ok {
+ h.printf("cgi: bogus header line: %s", string(line))
+ continue
+ }
+ if !httpguts.ValidHeaderFieldName(header) {
+ h.printf("cgi: invalid header name: %q", header)
+ continue
+ }
+ val = textproto.TrimString(val)
+ switch {
+ case header == "Status":
+ if len(val) < 3 {
+ h.printf("cgi: bogus status (short): %q", val)
+ return
+ }
+ code, err := strconv.Atoi(val[0:3])
+ if err != nil {
+ h.printf("cgi: bogus status: %q", val)
+ h.printf("cgi: line was %q", line)
+ return
+ }
+ statusCode = code
+ default:
+ headers.Add(header, val)
+ }
+ }
+ if headerLines == 0 || !sawBlankLine {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: no headers")
+ return
+ }
+
+ if loc := headers.Get("Location"); loc != "" {
+ if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil {
+ h.handleInternalRedirect(rw, req, loc)
+ return
+ }
+ if statusCode == 0 {
+ statusCode = http.StatusFound
+ }
+ }
+
+ if statusCode == 0 && headers.Get("Content-Type") == "" {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: missing required Content-Type in headers")
+ return
+ }
+
+ if statusCode == 0 {
+ statusCode = http.StatusOK
+ }
+
+ // Copy headers to rw's headers, after we've decided not to
+ // go into handleInternalRedirect, which won't want its rw
+ // headers to have been touched.
+ for k, vv := range headers {
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+
+ rw.WriteHeader(statusCode)
+
+ _, err = io.Copy(rw, linebody)
+ if err != nil {
+ h.printf("cgi: copy error: %v", err)
+ // And kill the child CGI process so we don't hang on
+ // the deferred cmd.Wait above if the error was just
+ // the client (rw) going away. If it was a read error
+ // (because the child died itself), then the extra
+ // kill of an already-dead process is harmless (the PID
+ // won't be reused until the Wait above).
+ cmd.Process.Kill()
+ }
+}
+
+func (h *Handler) printf(format string, v ...any) {
+ if h.Logger != nil {
+ h.Logger.Printf(format, v...)
+ } else {
+ log.Printf(format, v...)
+ }
+}
+
+func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) {
+ url, err := req.URL.Parse(path)
+ if err != nil {
+ rw.WriteHeader(http.StatusInternalServerError)
+ h.printf("cgi: error resolving local URI path %q: %v", path, err)
+ return
+ }
+ // TODO: RFC 3875 isn't clear if only GET is supported, but it
+ // suggests so: "Note that any message-body attached to the
+ // request (such as for a POST request) may not be available
+ // to the resource that is the target of the redirect." We
+ // should do some tests against Apache to see how it handles
+ // POST, HEAD, etc. Does the internal redirect get the same
+ // method or just GET? What about incoming headers?
+ // (e.g. Cookies) Which headers, if any, are copied into the
+ // second request?
+ newReq := &http.Request{
+ Method: "GET",
+ URL: url,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(http.Header),
+ Host: url.Host,
+ RemoteAddr: req.RemoteAddr,
+ TLS: req.TLS,
+ }
+ h.PathLocationHandler.ServeHTTP(rw, newReq)
+}
+
+func upperCaseAndUnderscore(r rune) rune {
+ switch {
+ case r >= 'a' && r <= 'z':
+ return r - ('a' - 'A')
+ case r == '-':
+ return '_'
+ case r == '=':
+ // Maybe not part of the CGI 'spec' but would mess up
+ // the environment in any case, as Go represents the
+ // environment as a slice of "key=value" strings.
+ return '_'
+ }
+ // TODO: other transformations in spec or practice?
+ return r
+}
+
+var testHookStartProcess func(*os.Process) // nil except for some tests
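For context on the host side shown above, here is a minimal usage sketch (not part of the diff): an http server mounts a cgi.Handler at a URL prefix and the Handler executes the script per request. The script path and mount point are hypothetical.

```go
package main

import (
	"log"
	"net/http"
	"net/http/cgi"
)

func main() {
	// Hypothetical script path and URL prefix; adjust for a real deployment.
	h := &cgi.Handler{
		Path: "/var/www/cgi-bin/hello.cgi", // executable run once per request
		Root: "/cgi-bin/hello.cgi",         // URL prefix that maps to the script
	}
	http.Handle("/cgi-bin/hello.cgi", h)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```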
diff --git a/src/net/http/cgi/host_test.go b/src/net/http/cgi/host_test.go
new file mode 100644
index 0000000..860e9b3
--- /dev/null
+++ b/src/net/http/cgi/host_test.go
@@ -0,0 +1,577 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for package cgi
+
+package cgi
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+)
+
+func newRequest(httpreq string) *http.Request {
+ buf := bufio.NewReader(strings.NewReader(httpreq))
+ req, err := http.ReadRequest(buf)
+ if err != nil {
+ panic("cgi: bogus http request in test: " + httpreq)
+ }
+ req.RemoteAddr = "1.2.3.4:1234"
+ return req
+}
+
+func runCgiTest(t *testing.T, h *Handler,
+ httpreq string,
+ expectedMap map[string]string, checks ...func(reqInfo map[string]string)) *httptest.ResponseRecorder {
+ rw := httptest.NewRecorder()
+ req := newRequest(httpreq)
+ h.ServeHTTP(rw, req)
+ runResponseChecks(t, rw, expectedMap, checks...)
+ return rw
+}
+
+func runResponseChecks(t *testing.T, rw *httptest.ResponseRecorder,
+ expectedMap map[string]string, checks ...func(reqInfo map[string]string)) {
+ // Make a map to hold the test map that the CGI returns.
+ m := make(map[string]string)
+ m["_body"] = rw.Body.String()
+ linesRead := 0
+readlines:
+ for {
+ line, err := rw.Body.ReadString('\n')
+ switch {
+ case err == io.EOF:
+ break readlines
+ case err != nil:
+ t.Fatalf("unexpected error reading from CGI: %v", err)
+ }
+ linesRead++
+ trimmedLine := strings.TrimRight(line, "\r\n")
+ k, v, ok := strings.Cut(trimmedLine, "=")
+ if !ok {
+ t.Fatalf("Unexpected response from invalid line number %v: %q; existing map=%v",
+ linesRead, line, m)
+ }
+ m[k] = v
+ }
+
+ for key, expected := range expectedMap {
+ got := m[key]
+ if key == "cwd" {
+ // For Windows. golang.org/issue/4645.
+ fi1, _ := os.Stat(got)
+ fi2, _ := os.Stat(expected)
+ if os.SameFile(fi1, fi2) {
+ got = expected
+ }
+ }
+ if got != expected {
+ t.Errorf("for key %q got %q; expected %q", key, got, expected)
+ }
+ }
+ for _, check := range checks {
+ check(m)
+ }
+}
+
+var cgiTested, cgiWorks bool
+
+func check(t *testing.T) {
+ if !cgiTested {
+ cgiTested = true
+ cgiWorks = exec.Command("./testdata/test.cgi").Run() == nil
+ }
+ if !cgiWorks {
+ // No Perl on Windows, needed by test.cgi
+ // TODO: make the child process be Go, not Perl.
+ t.Skip("Skipping test: test.cgi failed.")
+ }
+}
+
+func TestCGIBasicGet(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com:80",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REMOTE_PORT": "1234",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com:80\n\n", expectedMap)
+
+ if expected, got := "text/html", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+func TestCGIEnvIPv6(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "2000::3000",
+ "env-REMOTE_HOST": "2000::3000",
+ "env-REMOTE_PORT": "12345",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+
+ rw := httptest.NewRecorder()
+ req := newRequest("GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n")
+ req.RemoteAddr = "[2000::3000]:12345"
+ h.ServeHTTP(rw, req)
+ runResponseChecks(t, rw, expectedMap)
+}
+
+func TestCGIBasicGetAbsPath(t *testing.T) {
+ check(t)
+ pwd, err := os.Getwd()
+ if err != nil {
+ t.Fatalf("getwd error: %v", err)
+ }
+ h := &Handler{
+ Path: pwd + "/testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/test.cgi?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": pwd + "/testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfo(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "param-a": "b",
+ "env-PATH_INFO": "/extrapath",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/test.cgi/extrapath?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/test.cgi",
+ }
+ runCgiTest(t, h, "GET /test.cgi/extrapath?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestPathInfoDirRoot(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/myscript/",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/myscript/",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestDupHeaders(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-HTTP_COOKIE": "nom=NOM; yum=YUM",
+ "env-HTTP_X_FOO": "val1, val2",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
+ "Cookie: nom=NOM\n"+
+ "Cookie: yum=YUM\n"+
+ "X-Foo: val1\n"+
+ "X-Foo: val2\n"+
+ "Host: example.com\n\n",
+ expectedMap)
+}
+
+// Issue 16405: CGI+http.Transport differing uses of HTTP_PROXY.
+// Verify we don't set the HTTP_PROXY environment variable.
+// Hope nobody was depending on it. It's not a known header, though.
+func TestDropProxyHeader(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "env-REQUEST_URI": "/myscript/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-HTTP_X_FOO": "a",
+ }
+ runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+
+ "X-Foo: a\n"+
+ "Proxy: should_be_stripped\n"+
+ "Host: example.com\n\n",
+ expectedMap,
+ func(reqInfo map[string]string) {
+ if v, ok := reqInfo["env-HTTP_PROXY"]; ok {
+ t.Errorf("HTTP_PROXY = %q; should be absent", v)
+ }
+ })
+}
+
+func TestPathInfoNoRoot(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "",
+ }
+ expectedMap := map[string]string{
+ "env-PATH_INFO": "/bar",
+ "env-QUERY_STRING": "a=b",
+ "env-REQUEST_URI": "/bar?a=b",
+ "env-SCRIPT_FILENAME": "testdata/test.cgi",
+ "env-SCRIPT_NAME": "/",
+ }
+ runCgiTest(t, h, "GET /bar?a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestCGIBasicPost(t *testing.T) {
+ check(t)
+ postReq := `POST /test.cgi?a=b HTTP/1.0
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Content-Length: 15
+
+postfoo=postbar`
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI",
+ "param-postfoo": "postbar",
+ "env-REQUEST_METHOD": "POST",
+ "env-CONTENT_LENGTH": "15",
+ "env-REQUEST_URI": "/test.cgi?a=b",
+ }
+ runCgiTest(t, h, postReq, expectedMap)
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+// The CGI spec doesn't allow chunked requests.
+func TestCGIPostChunked(t *testing.T) {
+ check(t)
+ postReq := `POST /test.cgi?a=b HTTP/1.1
+Host: example.com
+Content-Type: application/x-www-form-urlencoded
+Transfer-Encoding: chunked
+
+` + chunk("postfoo") + chunk("=") + chunk("postbar") + chunk("")
+
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap := map[string]string{}
+ resp := runCgiTest(t, h, postReq, expectedMap)
+ if got, expected := resp.Code, http.StatusBadRequest; got != expected {
+ t.Fatalf("Expected %v response code from chunked request body; got %d",
+ expected, got)
+ }
+}
+
+func TestRedirect(t *testing.T) {
+ check(t)
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil)
+ if e, g := 302, rec.Code; e != g {
+ t.Errorf("expected status code %d; got %d", e, g)
+ }
+ if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g {
+ t.Errorf("expected Location header of %q; got %q", e, g)
+ }
+}
+
+func TestInternalRedirect(t *testing.T) {
+ check(t)
+ baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path)
+ fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr)
+ })
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ PathLocationHandler: baseHandler,
+ }
+ expectedMap := map[string]string{
+ "basepath": "/foo",
+ "remoteaddr": "1.2.3.4:1234",
+ }
+ runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+// TestCopyError tests that we kill the process if there's an error copying
+// its output. (for example, from the client having gone away)
+func TestCopyError(t *testing.T) {
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ ts := httptest.NewServer(h)
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?bigresponse=1", nil)
+ err = req.Write(conn)
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+
+ res, err := http.ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("ReadResponse: %v", err)
+ }
+
+ pidstr := res.Header.Get("X-CGI-Pid")
+ if pidstr == "" {
+ t.Fatalf("expected an X-CGI-Pid header in response")
+ }
+ pid, err := strconv.Atoi(pidstr)
+ if err != nil {
+ t.Fatalf("invalid X-CGI-Pid value")
+ }
+
+ var buf [5000]byte
+ n, err := io.ReadFull(res.Body, buf[:])
+ if err != nil {
+ t.Fatalf("ReadFull: %d bytes, %v", n, err)
+ }
+
+ childRunning := func() bool {
+ return isProcessRunning(pid)
+ }
+
+ if !childRunning() {
+ t.Fatalf("pre-conn.Close, expected child to be running")
+ }
+ conn.Close()
+
+ tries := 0
+ for tries < 25 && childRunning() {
+ time.Sleep(50 * time.Millisecond * time.Duration(tries))
+ tries++
+ }
+ if childRunning() {
+ t.Fatalf("post-conn.Close, expected child to be gone")
+ }
+}
+
+func TestDirUnix(t *testing.T) {
+ check(t)
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping test on %q", runtime.GOOS)
+ }
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ Dir: cwd,
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ cwd, _ = os.Getwd()
+ cwd = filepath.Join(cwd, "testdata")
+ h = &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func findPerl(t *testing.T) string {
+ t.Helper()
+ perl, err := exec.LookPath("perl")
+ if err != nil {
+ t.Skip("Skipping test: perl not found.")
+ }
+ perl, _ = filepath.Abs(perl)
+
+ cmd := exec.Command(perl, "-e", "print 123")
+ cmd.Env = []string{"PATH=/garbage"}
+ out, err := cmd.Output()
+ if err != nil || string(out) != "123" {
+ t.Skipf("Skipping test: %s is not functional", perl)
+ }
+ return perl
+}
+
+func TestDirWindows(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("Skipping windows specific test.")
+ }
+
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ perl := findPerl(t)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ // If Dir is not specified on Windows, the working directory
+ // should be the base directory of perl.
+ cwd, _ = filepath.Split(perl)
+ if cwd != "" && cwd[len(cwd)-1] == filepath.Separator {
+ cwd = cwd[:len(cwd)-1]
+ }
+ h = &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Args: []string{cgifile},
+ Env: []string{"SCRIPT_FILENAME=" + cgifile},
+ }
+ expectedMap = map[string]string{
+ "cwd": cwd,
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestEnvOverride(t *testing.T) {
+ check(t)
+ cgifile, _ := filepath.Abs("testdata/test.cgi")
+
+ perl := findPerl(t)
+
+ cwd, _ := os.Getwd()
+ h := &Handler{
+ Path: perl,
+ Root: "/test.cgi",
+ Dir: cwd,
+ Args: []string{cgifile},
+ Env: []string{
+ "SCRIPT_FILENAME=" + cgifile,
+ "REQUEST_URI=/foo/bar",
+ "PATH=/wibble"},
+ }
+ expectedMap := map[string]string{
+ "cwd": cwd,
+ "env-SCRIPT_FILENAME": cgifile,
+ "env-REQUEST_URI": "/foo/bar",
+ "env-PATH": "/wibble",
+ }
+ runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
+}
+
+func TestHandlerStderr(t *testing.T) {
+ check(t)
+ var stderr strings.Builder
+ h := &Handler{
+ Path: "testdata/test.cgi",
+ Root: "/test.cgi",
+ Stderr: &stderr,
+ }
+
+ rw := httptest.NewRecorder()
+ req := newRequest("GET /test.cgi?writestderr=1 HTTP/1.0\nHost: example.com\n\n")
+ h.ServeHTTP(rw, req)
+ if got, want := stderr.String(), "Hello, stderr!\n"; got != want {
+ t.Errorf("Stderr = %q; want %q", got, want)
+ }
+}
+
+func TestRemoveLeadingDuplicates(t *testing.T) {
+ tests := []struct {
+ env []string
+ want []string
+ }{
+ {
+ env: []string{"a=b", "b=c", "a=b2"},
+ want: []string{"b=c", "a=b2"},
+ },
+ {
+ env: []string{"a=b", "b=c", "d", "e=f"},
+ want: []string{"a=b", "b=c", "d", "e=f"},
+ },
+ }
+ for _, tt := range tests {
+ got := removeLeadingDuplicates(tt.env)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("removeLeadingDuplicates(%q) = %q; want %q", tt.env, got, tt.want)
+ }
+ }
+}
diff --git a/src/net/http/cgi/integration_test.go b/src/net/http/cgi/integration_test.go
new file mode 100644
index 0000000..ef2eaf7
--- /dev/null
+++ b/src/net/http/cgi/integration_test.go
@@ -0,0 +1,272 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests a Go CGI program running under a Go CGI host process.
+// Further, the two programs are the same binary, just checking
+// their environment to figure out what mode to run in.
+
+package cgi
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "strings"
+ "testing"
+ "time"
+)
+
+// This test is a CGI host (testing host.go) that runs its own binary
+// as a child process testing the other half of CGI (child.go).
+func TestHostingOurselves(t *testing.T) {
+ testenv.MustHaveExec(t)
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "test": "Hello CGI-in-CGI",
+ "param-a": "b",
+ "param-foo": "bar",
+ "env-GATEWAY_INTERFACE": "CGI/1.1",
+ "env-HTTP_HOST": "example.com",
+ "env-PATH_INFO": "",
+ "env-QUERY_STRING": "foo=bar&a=b",
+ "env-REMOTE_ADDR": "1.2.3.4",
+ "env-REMOTE_HOST": "1.2.3.4",
+ "env-REMOTE_PORT": "1234",
+ "env-REQUEST_METHOD": "GET",
+ "env-REQUEST_URI": "/test.go?foo=bar&a=b",
+ "env-SCRIPT_FILENAME": os.Args[0],
+ "env-SCRIPT_NAME": "/test.go",
+ "env-SERVER_NAME": "example.com",
+ "env-SERVER_PORT": "80",
+ "env-SERVER_SOFTWARE": "go",
+ }
+ replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap)
+
+ if expected, got := "text/plain; charset=utf-8", replay.Header().Get("Content-Type"); got != expected {
+ t.Errorf("got a Content-Type of %q; expected %q", got, expected)
+ }
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+type customWriterRecorder struct {
+ w io.Writer
+ *httptest.ResponseRecorder
+}
+
+func (r *customWriterRecorder) Write(p []byte) (n int, err error) {
+ return r.w.Write(p)
+}
+
+type limitWriter struct {
+ w io.Writer
+ n int
+}
+
+func (w *limitWriter) Write(p []byte) (n int, err error) {
+ if len(p) > w.n {
+ p = p[:w.n]
+ }
+ if len(p) > 0 {
+ n, err = w.w.Write(p)
+ w.n -= n
+ }
+ if w.n == 0 {
+ err = errors.New("past write limit")
+ }
+ return
+}
+
+// If there's an error copying the child's output to the parent, test
+// that we kill the child.
+func TestKillChildAfterCopyError(t *testing.T) {
+ testenv.MustHaveExec(t)
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ req, _ := http.NewRequest("GET", "http://example.com/test.cgi?write-forever=1", nil)
+ rec := httptest.NewRecorder()
+ var out bytes.Buffer
+ const writeLen = 50 << 10
+ rw := &customWriterRecorder{&limitWriter{&out, writeLen}, rec}
+
+ h.ServeHTTP(rw, req)
+ if out.Len() != writeLen || out.Bytes()[0] != 'a' {
+ t.Errorf("unexpected output: %q", out.Bytes())
+ }
+}
+
+// Test that a child handler writing only headers works.
+// golang.org/issue/7196
+func TestChildOnlyHeaders(t *testing.T) {
+ testenv.MustHaveExec(t)
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "_body": "",
+ }
+ replay := runCgiTest(t, h, "GET /test.go?no-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected {
+ t.Errorf("got a X-Test-Header of %q; expected %q", got, expected)
+ }
+}
+
+// Test that a child handler does not receive a nil Request Body.
+// golang.org/issue/39190
+func TestNilRequestBody(t *testing.T) {
+ testenv.MustHaveExec(t)
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "nil-request-body": "false",
+ }
+ _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\nContent-Length: 0\n\n", expectedMap)
+}
+
+func TestChildContentType(t *testing.T) {
+ testenv.MustHaveExec(t)
+
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ var tests = []struct {
+ name string
+ body string
+ wantCT string
+ }{
+ {
+ name: "no body",
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "html",
+ body: "<html><head><title>test page</title></head><body>This is a body</body></html>",
+ wantCT: "text/html; charset=utf-8",
+ },
+ {
+ name: "text",
+ body: strings.Repeat("gopher", 86),
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "jpg",
+ body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024),
+ wantCT: "image/jpeg",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ expectedMap := map[string]string{"_body": tt.body}
+ req := fmt.Sprintf("GET /test.go?exact-body=%s HTTP/1.0\nHost: example.com\n\n", url.QueryEscape(tt.body))
+ replay := runCgiTest(t, h, req, expectedMap)
+ if got := replay.Header().Get("Content-Type"); got != tt.wantCT {
+ t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT)
+ }
+ })
+ }
+}
+
+// golang.org/issue/7198
+func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") }
+func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") }
+func Test500WithEmptyHeaders(t *testing.T) { want500Test(t, "/empty-headers") }
+
+func want500Test(t *testing.T, path string) {
+ h := &Handler{
+ Path: os.Args[0],
+ Root: "/test.go",
+ Args: []string{"-test.run=TestBeChildCGIProcess"},
+ }
+ expectedMap := map[string]string{
+ "_body": "",
+ }
+ replay := runCgiTest(t, h, "GET "+path+" HTTP/1.0\nHost: example.com\n\n", expectedMap)
+ if replay.Code != 500 {
+ t.Errorf("Got code %d; want 500", replay.Code)
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// Note: not actually a test.
+func TestBeChildCGIProcess(t *testing.T) {
+ if os.Getenv("REQUEST_METHOD") == "" {
+ // Not in a CGI environment; skipping test.
+ return
+ }
+ switch os.Getenv("REQUEST_URI") {
+ case "/immediate-disconnect":
+ os.Exit(0)
+ case "/no-content-type":
+ fmt.Printf("Content-Length: 6\n\nHello\n")
+ os.Exit(0)
+ case "/empty-headers":
+ fmt.Printf("\nHello")
+ os.Exit(0)
+ }
+ Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ if req.FormValue("nil-request-body") == "1" {
+ fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil)
+ return
+ }
+ rw.Header().Set("X-Test-Header", "X-Test-Value")
+ req.ParseForm()
+ if req.FormValue("no-body") == "1" {
+ return
+ }
+ if eb, ok := req.Form["exact-body"]; ok {
+ io.WriteString(rw, eb[0])
+ return
+ }
+ if req.FormValue("write-forever") == "1" {
+ io.Copy(rw, neverEnding('a'))
+ for {
+ time.Sleep(5 * time.Second) // hang forever, until killed
+ }
+ }
+ fmt.Fprintf(rw, "test=Hello CGI-in-CGI\n")
+ for k, vv := range req.Form {
+ for _, v := range vv {
+ fmt.Fprintf(rw, "param-%s=%s\n", k, v)
+ }
+ }
+ for _, kv := range os.Environ() {
+ fmt.Fprintf(rw, "env-%s\n", kv)
+ }
+ }))
+ os.Exit(0)
+}
diff --git a/src/net/http/cgi/plan9_test.go b/src/net/http/cgi/plan9_test.go
new file mode 100644
index 0000000..b7ace3f
--- /dev/null
+++ b/src/net/http/cgi/plan9_test.go
@@ -0,0 +1,17 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build plan9
+
+package cgi
+
+import (
+ "os"
+ "strconv"
+)
+
+func isProcessRunning(pid int) bool {
+ _, err := os.Stat("/proc/" + strconv.Itoa(pid))
+ return err == nil
+}
diff --git a/src/net/http/cgi/posix_test.go b/src/net/http/cgi/posix_test.go
new file mode 100644
index 0000000..49b9470
--- /dev/null
+++ b/src/net/http/cgi/posix_test.go
@@ -0,0 +1,20 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !plan9
+
+package cgi
+
+import (
+ "os"
+ "syscall"
+)
+
+func isProcessRunning(pid int) bool {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return false
+ }
+ return p.Signal(syscall.Signal(0)) == nil
+}
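A side note on the POSIX probe above: on Unix, os.FindProcess always succeeds for any pid, so it is the Signal(0) call (which delivers no signal, only performs error checking) that actually reports whether the process exists. A small self-contained sketch of the same probe, checking the current process:

```go
package main

import (
	"fmt"
	"os"
	"syscall"
)

func main() {
	pid := os.Getpid() // probe ourselves, which is certainly running
	p, err := os.FindProcess(pid)
	if err != nil {
		fmt.Println("not running:", err)
		return
	}
	// Signal 0 performs error checking only; no signal is delivered.
	fmt.Println("running:", p.Signal(syscall.Signal(0)) == nil)
}
```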
diff --git a/src/net/http/cgi/testdata/test.cgi b/src/net/http/cgi/testdata/test.cgi
new file mode 100644
index 0000000..667fce2
--- /dev/null
+++ b/src/net/http/cgi/testdata/test.cgi
@@ -0,0 +1,95 @@
+#!/usr/bin/perl
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+#
+# Test script run as a child process under cgi_test.go
+
+use strict;
+use Cwd;
+
+binmode STDOUT;
+
+my $q = MiniCGI->new;
+my $params = $q->Vars;
+
+if ($params->{"loc"}) {
+ print "Location: $params->{loc}\r\n\r\n";
+ exit(0);
+}
+
+print "Content-Type: text/html\r\n";
+print "X-CGI-Pid: $$\r\n";
+print "X-Test-Header: X-Test-Value\r\n";
+print "\r\n";
+
+if ($params->{"writestderr"}) {
+ print STDERR "Hello, stderr!\n";
+}
+
+if ($params->{"bigresponse"}) {
+ # 17 MB, for OS X: golang.org/issue/4958
+ for (1..(17 * 1024)) {
+ print "A" x 1024, "\r\n";
+ }
+ exit 0;
+}
+
+print "test=Hello CGI\r\n";
+
+foreach my $k (sort keys %$params) {
+ print "param-$k=$params->{$k}\r\n";
+}
+
+foreach my $k (sort keys %ENV) {
+ my $clean_env = $ENV{$k};
+ $clean_env =~ s/[\n\r]//g;
+ print "env-$k=$clean_env\r\n";
+}
+
+# NOTE: msys perl returns /c/go/src/... not C:\go\....
+my $dir = getcwd();
+if ($^O eq 'MSWin32' || $^O eq 'msys' || $^O eq 'cygwin') {
+ if ($dir =~ /^.:/) {
+ $dir =~ s!/!\\!g;
+ } else {
+ my $cmd = $ENV{'COMSPEC'} || 'c:\\windows\\system32\\cmd.exe';
+ $cmd =~ s!\\!/!g;
+ $dir = `$cmd /c cd`;
+ chomp $dir;
+ }
+}
+print "cwd=$dir\r\n";
+
+# A minimal version of CGI.pm, for people without the perl-modules
+# package installed. (CGI.pm used to be part of the Perl core, but
+# some distros now bundle perl-base and perl-modules separately...)
+package MiniCGI;
+
+sub new {
+ my $class = shift;
+ return bless {}, $class;
+}
+
+sub Vars {
+ my $self = shift;
+ my $pairs;
+ if ($ENV{CONTENT_LENGTH}) {
+ $pairs = do { local $/; <STDIN> };
+ } else {
+ $pairs = $ENV{QUERY_STRING};
+ }
+ my $vars = {};
+ foreach my $kv (split(/&/, $pairs)) {
+ my ($k, $v) = split(/=/, $kv, 2);
+ $vars->{_urldecode($k)} = _urldecode($v);
+ }
+ return $vars;
+}
+
+sub _urldecode {
+ my $v = shift;
+ $v =~ tr/+/ /;
+ $v =~ s/%([a-fA-F0-9][a-fA-F0-9])/pack("C", hex($1))/eg;
+ return $v;
+}
diff --git a/src/net/http/client.go b/src/net/http/client.go
new file mode 100644
index 0000000..77a701b
--- /dev/null
+++ b/src/net/http/client.go
@@ -0,0 +1,1038 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client. See RFC 7230 through 7235.
+//
+// This is the high-level Client interface.
+// The low-level implementation is in transport.go.
+
+package http
+
+import (
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http/internal/ascii"
+ "net/url"
+ "reflect"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+// A Client is an HTTP client. Its zero value (DefaultClient) is a
+// usable client that uses DefaultTransport.
+//
+// The Client's Transport typically has internal state (cached TCP
+// connections), so Clients should be reused instead of created as
+// needed. Clients are safe for concurrent use by multiple goroutines.
+//
+// A Client is higher-level than a RoundTripper (such as Transport)
+// and additionally handles HTTP details such as cookies and
+// redirects.
+//
+// When following redirects, the Client will forward all headers set on the
+// initial Request except:
+//
+// • when forwarding sensitive headers like "Authorization",
+// "WWW-Authenticate", and "Cookie" to untrusted targets.
+// These headers will be ignored when following a redirect to a domain
+// that is not a subdomain match or exact match of the initial domain.
+// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com"
+// will forward the sensitive headers, but a redirect to "bar.com" will not.
+//
+// • when forwarding the "Cookie" header with a non-nil cookie Jar.
+// Since each redirect may mutate the state of the cookie jar,
+// a redirect may possibly alter a cookie set in the initial request.
+// When forwarding the "Cookie" header, any mutated cookies will be omitted,
+// with the expectation that the Jar will insert those mutated cookies
+// with the updated values (assuming the origin matches).
+// If Jar is nil, the initial cookies are forwarded without change.
+type Client struct {
+ // Transport specifies the mechanism by which individual
+ // HTTP requests are made.
+ // If nil, DefaultTransport is used.
+ Transport RoundTripper
+
+ // CheckRedirect specifies the policy for handling redirects.
+ // If CheckRedirect is not nil, the client calls it before
+ // following an HTTP redirect. The arguments req and via are
+ // the upcoming request and the requests made already, oldest
+ // first. If CheckRedirect returns an error, the Client's Get
+ // method returns both the previous Response (with its Body
+ // closed) and CheckRedirect's error (wrapped in a url.Error)
+ // instead of issuing the Request req.
+ // As a special case, if CheckRedirect returns ErrUseLastResponse,
+ // then the most recent response is returned with its body
+ // unclosed, along with a nil error.
+ //
+ // If CheckRedirect is nil, the Client uses its default policy,
+ // which is to stop after 10 consecutive requests.
+ CheckRedirect func(req *Request, via []*Request) error
+
+ // Jar specifies the cookie jar.
+ //
+ // The Jar is used to insert relevant cookies into every
+ // outbound Request and is updated with the cookie values
+ // of every inbound Response. The Jar is consulted for every
+ // redirect that the Client follows.
+ //
+ // If Jar is nil, cookies are only sent if they are explicitly
+ // set on the Request.
+ Jar CookieJar
+
+ // Timeout specifies a time limit for requests made by this
+ // Client. The timeout includes connection time, any
+ // redirects, and reading the response body. The timer remains
+ // running after Get, Head, Post, or Do return and will
+ // interrupt reading of the Response.Body.
+ //
+ // A Timeout of zero means no timeout.
+ //
+ // The Client cancels requests to the underlying Transport
+ // as if the Request's Context ended.
+ //
+ // For compatibility, the Client will also use the deprecated
+ // CancelRequest method on Transport if found. New
+ // RoundTripper implementations should use the Request's Context
+ // for cancellation instead of implementing CancelRequest.
+ Timeout time.Duration
+}
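As a usage sketch (not part of the diff): callers typically construct one Client and reuse it across requests, since the Transport caches connections. The 10-second timeout below is an arbitrary illustrative value.

```go
package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	// One Client, reused for all requests; the timeout covers dialing,
	// redirects, and reading the response body.
	client := &http.Client{Timeout: 10 * time.Second}

	resp, err := client.Get("https://example.com/")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```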
+
+// DefaultClient is the default Client and is used by Get, Head, and Post.
+var DefaultClient = &Client{}
+
+// RoundTripper is an interface representing the ability to execute a
+// single HTTP transaction, obtaining the Response for a given Request.
+//
+// A RoundTripper must be safe for concurrent use by multiple
+// goroutines.
+type RoundTripper interface {
+ // RoundTrip executes a single HTTP transaction, returning
+ // a Response for the provided Request.
+ //
+ // RoundTrip should not attempt to interpret the response. In
+ // particular, RoundTrip must return err == nil if it obtained
+ // a response, regardless of the response's HTTP status code.
+ // A non-nil err should be reserved for failure to obtain a
+ // response. Similarly, RoundTrip should not attempt to
+ // handle higher-level protocol details such as redirects,
+ // authentication, or cookies.
+ //
+ // RoundTrip should not modify the request, except for
+ // consuming and closing the Request's Body. RoundTrip may
+ // read fields of the request in a separate goroutine. Callers
+ // should not mutate or reuse the request until the Response's
+ // Body has been closed.
+ //
+ // RoundTrip must always close the body, including on errors,
+ // but depending on the implementation may do so in a separate
+ // goroutine even after RoundTrip returns. This means that
+ // callers wanting to reuse the body for subsequent requests
+ // must arrange to wait for the Close call before doing so.
+ //
+ // The Request's URL and Header fields must be initialized.
+ RoundTrip(*Request) (*Response, error)
+}
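To illustrate the contract documented above, a sketch (not from the diff) of a decorating RoundTripper: because RoundTrip should not modify the caller's request, the wrapper clones the request before adding a header. The type name and header are hypothetical.

```go
package main

import (
	"fmt"
	"net/http"
)

// headerRoundTripper is a hypothetical decorator that adds a static header.
type headerRoundTripper struct {
	base  http.RoundTripper
	key   string
	value string
}

func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Clone rather than mutate: RoundTrip must not modify the request.
	r2 := req.Clone(req.Context())
	r2.Header.Set(rt.key, rt.value)
	return rt.base.RoundTrip(r2)
}

func main() {
	client := &http.Client{
		Transport: headerRoundTripper{
			base:  http.DefaultTransport,
			key:   "X-Trace-Id",
			value: "abc123", // illustrative value
		},
	}
	resp, err := client.Get("https://example.com/")
	if err != nil {
		fmt.Println(err)
		return
	}
	resp.Body.Close()
	fmt.Println(resp.Status)
}
```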
+
+// refererForURL returns a referer without any authentication info or
+// an empty string if lastReq scheme is https and newReq scheme is http.
+// If the referer was explicitly set, then it will continue to be used.
+func refererForURL(lastReq, newReq *url.URL, explicitRef string) string {
+ // https://tools.ietf.org/html/rfc7231#section-5.5.2
+ // "Clients SHOULD NOT include a Referer header field in a
+ // (non-secure) HTTP request if the referring page was
+ // transferred with a secure protocol."
+ if lastReq.Scheme == "https" && newReq.Scheme == "http" {
+ return ""
+ }
+ if explicitRef != "" {
+ return explicitRef
+ }
+
+ referer := lastReq.String()
+ if lastReq.User != nil {
+ // This is not very efficient, but is the best we can
+ // do without:
+ // - introducing a new method on URL
+ // - creating a race condition
+ // - copying the URL struct manually, which would cause
+ // maintenance problems down the line
+ auth := lastReq.User.String() + "@"
+ referer = strings.Replace(referer, auth, "", 1)
+ }
+ return referer
+}
+
+// didTimeout is non-nil only if err != nil.
+func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
+ if c.Jar != nil {
+ for _, cookie := range c.Jar.Cookies(req.URL) {
+ req.AddCookie(cookie)
+ }
+ }
+ resp, didTimeout, err = send(req, c.transport(), deadline)
+ if err != nil {
+ return nil, didTimeout, err
+ }
+ if c.Jar != nil {
+ if rc := resp.Cookies(); len(rc) > 0 {
+ c.Jar.SetCookies(req.URL, rc)
+ }
+ }
+ return resp, nil, nil
+}
+
+func (c *Client) deadline() time.Time {
+ if c.Timeout > 0 {
+ return time.Now().Add(c.Timeout)
+ }
+ return time.Time{}
+}
+
+func (c *Client) transport() RoundTripper {
+ if c.Transport != nil {
+ return c.Transport
+ }
+ return DefaultTransport
+}
+
+// ErrSchemeMismatch is returned when a server returns an HTTP response to an HTTPS client.
+var ErrSchemeMismatch = errors.New("http: server gave HTTP response to HTTPS client")
+
+// send issues an HTTP request.
+// Caller should close resp.Body when done reading from it.
+func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
+ req := ireq // req is either the original request, or a modified fork
+
+ if rt == nil {
+ req.closeBody()
+ return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport")
+ }
+
+ if req.URL == nil {
+ req.closeBody()
+ return nil, alwaysFalse, errors.New("http: nil Request.URL")
+ }
+
+ if req.RequestURI != "" {
+ req.closeBody()
+ return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests")
+ }
+
+ // forkReq forks req into a shallow clone of ireq the first
+ // time it's called.
+ forkReq := func() {
+ if ireq == req {
+ req = new(Request)
+ *req = *ireq // shallow clone
+ }
+ }
+
+ // Most of the callers of send (Get, Post, et al) don't need
+ // Headers, leaving it uninitialized. We guarantee to the
+ // Transport that this has been initialized, though.
+ if req.Header == nil {
+ forkReq()
+ req.Header = make(Header)
+ }
+
+ if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" {
+ username := u.Username()
+ password, _ := u.Password()
+ forkReq()
+ req.Header = cloneOrMakeHeader(ireq.Header)
+ req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
+ }
+
+ if !deadline.IsZero() {
+ forkReq()
+ }
+ stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
+
+ resp, err = rt.RoundTrip(req)
+ if err != nil {
+ stopTimer()
+ if resp != nil {
+ log.Printf("RoundTripper returned a response & error; ignoring response")
+ }
+ if tlsErr, ok := err.(tls.RecordHeaderError); ok {
+ // If we get a bad TLS record header, check to see if the
+ // response looks like HTTP and give a more helpful error.
+ // See golang.org/issue/11111.
+ if string(tlsErr.RecordHeader[:]) == "HTTP/" {
+ err = ErrSchemeMismatch
+ }
+ }
+ return nil, didTimeout, err
+ }
+ if resp == nil {
+ return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt)
+ }
+ if resp.Body == nil {
+ // The documentation on the Body field says “The http Client and Transport
+ // guarantee that Body is always non-nil, even on responses without a body
+ // or responses with a zero-length body.” Unfortunately, we didn't document
+ // that same constraint for arbitrary RoundTripper implementations, and
+ // RoundTripper implementations in the wild (mostly in tests) assume that
+ // they can use a nil Body to mean an empty one (similar to Request.Body).
+ // (See https://golang.org/issue/38095.)
+ //
+ // If the ContentLength allows the Body to be empty, fill in an empty one
+ // here to ensure that it is non-nil.
+ if resp.ContentLength > 0 && req.Method != "HEAD" {
+ return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength)
+ }
+ resp.Body = io.NopCloser(strings.NewReader(""))
+ }
+ if !deadline.IsZero() {
+ resp.Body = &cancelTimerBody{
+ stop: stopTimer,
+ rc: resp.Body,
+ reqDidTimeout: didTimeout,
+ }
+ }
+ return resp, nil, nil
+}
+
+// timeBeforeContextDeadline reports whether the non-zero Time t is
+// before ctx's deadline, if any. If ctx does not have a deadline, it
+// always reports true (the deadline is considered infinite).
+func timeBeforeContextDeadline(t time.Time, ctx context.Context) bool {
+ d, ok := ctx.Deadline()
+ if !ok {
+ return true
+ }
+ return t.Before(d)
+}
+
+// knownRoundTripperImpl reports whether rt is a RoundTripper that's
+// maintained by the Go team and known to implement the latest
+// optional semantics (notably contexts). The Request is used
+// to check whether this particular request is using an alternate protocol,
+// in which case we need to check the RoundTripper for that protocol.
+func knownRoundTripperImpl(rt RoundTripper, req *Request) bool {
+ switch t := rt.(type) {
+ case *Transport:
+ if altRT := t.alternateRoundTripper(req); altRT != nil {
+ return knownRoundTripperImpl(altRT, req)
+ }
+ return true
+ case *http2Transport, http2noDialH2RoundTripper:
+ return true
+ }
+ // There's a very minor chance of a false positive with this.
+ // Instead of detecting our golang.org/x/net/http2.Transport,
+ // it might detect a Transport type in a different http2
+ // package. But I know of none, and the only problem would be
+ // some temporarily leaked goroutines if the transport didn't
+ // support contexts. So this is a good enough heuristic:
+ if reflect.TypeOf(rt).String() == "*http2.Transport" {
+ return true
+ }
+ return false
+}
+
+// setRequestCancel sets req.Cancel and adds a deadline context to req
+// if deadline is non-zero. The RoundTripper's type is used to
+// determine whether the legacy CancelRequest behavior should be used.
+//
+// As background, there are three ways to cancel a request:
+// First was Transport.CancelRequest. (deprecated)
+// Second was Request.Cancel.
+// Third was Request.Context.
+// This function populates the second and third, and uses the first if it really needs to.
+func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) {
+ if deadline.IsZero() {
+ return nop, alwaysFalse
+ }
+ knownTransport := knownRoundTripperImpl(rt, req)
+ oldCtx := req.Context()
+
+ if req.Cancel == nil && knownTransport {
+ // If they already had a Request.Context that's
+ // expiring sooner, do nothing:
+ if !timeBeforeContextDeadline(deadline, oldCtx) {
+ return nop, alwaysFalse
+ }
+
+ var cancelCtx func()
+ req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline)
+ return cancelCtx, func() bool { return time.Now().After(deadline) }
+ }
+ initialReqCancel := req.Cancel // the user's original Request.Cancel, if any
+
+ var cancelCtx func()
+ if timeBeforeContextDeadline(deadline, oldCtx) {
+ req.ctx, cancelCtx = context.WithDeadline(oldCtx, deadline)
+ }
+
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ doCancel := func() {
+ // The second way in the func comment above:
+ close(cancel)
+ // The first way, used only for RoundTripper
+ // implementations written before Go 1.5 or Go 1.6.
+ type canceler interface{ CancelRequest(*Request) }
+ if v, ok := rt.(canceler); ok {
+ v.CancelRequest(req)
+ }
+ }
+
+ stopTimerCh := make(chan struct{})
+ var once sync.Once
+ stopTimer = func() {
+ once.Do(func() {
+ close(stopTimerCh)
+ if cancelCtx != nil {
+ cancelCtx()
+ }
+ })
+ }
+
+ timer := time.NewTimer(time.Until(deadline))
+ var timedOut atomic.Bool
+
+ go func() {
+ select {
+ case <-initialReqCancel:
+ doCancel()
+ timer.Stop()
+ case <-timer.C:
+ timedOut.Store(true)
+ doCancel()
+ case <-stopTimerCh:
+ timer.Stop()
+ }
+ }()
+
+ return stopTimer, timedOut.Load
+}
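The modern equivalent of the deadline plumbing above, seen from the caller's side, is to attach a context with a deadline to the request; this sketch uses only the third mechanism (the Request's Context) and an arbitrary 5-second timeout.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "GET", "https://example.com/", nil)
	if err != nil {
		fmt.Println(err)
		return
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed (possibly deadline exceeded):", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```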
+
+// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
+// "To receive authorization, the client sends the userid and password,
+// separated by a single colon (":") character, within a base64
+// encoded string in the credentials."
+// It is not meant to be urlencoded.
+func basicAuth(username, password string) string {
+ auth := username + ":" + password
+ return base64.StdEncoding.EncodeToString([]byte(auth))
+}
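For reference, the worked example from RFC 2617: the user-id "Aladdin" with password "open sesame" encodes to "QWxhZGRpbjpvcGVuIHNlc2FtZQ==". A quick standalone check of the same encoding the helper above performs:

```go
package main

import (
	"encoding/base64"
	"fmt"
)

func main() {
	auth := "Aladdin" + ":" + "open sesame"
	enc := base64.StdEncoding.EncodeToString([]byte(auth))
	// Prints: Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==
	fmt.Println("Authorization: Basic " + enc)
}
```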
+
+// Get issues a GET to the specified URL. If the response is one of
+// the following redirect codes, Get follows the redirect, up to a
+// maximum of 10 redirects:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
+//
+// An error is returned if there were too many redirects or if there
+// was an HTTP protocol error. A non-2xx response doesn't cause an
+// error. Any returned error will be of type *url.Error. The url.Error
+// value's Timeout method will report true if the request timed out.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// Get is a wrapper around DefaultClient.Get.
+//
+// To make a request with custom headers, use NewRequest and
+// DefaultClient.Do.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and DefaultClient.Do.
+func Get(url string) (resp *Response, err error) {
+ return DefaultClient.Get(url)
+}
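A minimal usage sketch of the package-level helper, including the body-close discipline the documentation above asks for:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	resp, err := http.Get("https://example.com/")
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close() // always close, even for non-2xx responses

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(resp.Status, len(body), "bytes")
}
```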
+
+// Get issues a GET to the specified URL. If the response is one of the
+// following redirect codes, Get follows the redirect after calling the
+// Client's CheckRedirect function:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
+//
+// An error is returned if the Client's CheckRedirect function fails
+// or if there was an HTTP protocol error. A non-2xx response doesn't
+// cause an error. Any returned error will be of type *url.Error. The
+// url.Error value's Timeout method will report true if the request
+// timed out.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// To make a request with custom headers, use NewRequest and Client.Do.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and Client.Do.
+func (c *Client) Get(url string) (resp *Response, err error) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.Do(req)
+}
+
+func alwaysFalse() bool { return false }
+
+// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to
+// control how redirects are processed. If returned, the next request
+// is not sent and the most recent response is returned with its body
+// unclosed.
+var ErrUseLastResponse = errors.New("net/http: use last response")
+
+// checkRedirect calls either the user's configured CheckRedirect
+// function, or the default.
+func (c *Client) checkRedirect(req *Request, via []*Request) error {
+ fn := c.CheckRedirect
+ if fn == nil {
+ fn = defaultCheckRedirect
+ }
+ return fn(req, via)
+}
+
+// redirectBehavior describes what should happen when the
+// client encounters a 3xx status code from the server.
+func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect, includeBody bool) {
+ switch resp.StatusCode {
+ case 301, 302, 303:
+ redirectMethod = reqMethod
+ shouldRedirect = true
+ includeBody = false
+
+ // RFC 2616 allowed automatic redirection only with GET and
+ // HEAD requests. RFC 7231 lifts this restriction, but we still
+ // restrict other methods to GET to maintain compatibility.
+ // See Issue 18570.
+ if reqMethod != "GET" && reqMethod != "HEAD" {
+ redirectMethod = "GET"
+ }
+ case 307, 308:
+ redirectMethod = reqMethod
+ shouldRedirect = true
+ includeBody = true
+
+ if ireq.GetBody == nil && ireq.outgoingLength() != 0 {
+ // We had a request body, and 307/308 require
+ // re-sending it, but GetBody is not defined. So just
+ // return this response to the user instead of an
+ // error, like we did in Go 1.7 and earlier.
+ shouldRedirect = false
+ }
+ }
+ return redirectMethod, shouldRedirect, includeBody
+}
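A small test-style sketch of the 301/302/303 rule above: a POST that receives a 303 is retried as a GET with no body. The in-process server and paths are made up for the example; the redirected hop's method is reported over a channel.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

func main() {
	methodCh := make(chan string, 1)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/start" {
			http.Redirect(w, r, "/done", http.StatusSeeOther) // 303
			return
		}
		methodCh <- r.Method // method used on the redirected hop
	}))
	defer ts.Close()

	resp, err := http.Post(ts.URL+"/start", "text/plain", strings.NewReader("payload"))
	if err != nil {
		fmt.Println(err)
		return
	}
	resp.Body.Close()
	fmt.Println("redirected hop used method:", <-methodCh) // expected: GET
}
```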
+
+// urlErrorOp returns the (*url.Error).Op value to use for the
+// provided (*Request).Method value.
+func urlErrorOp(method string) string {
+ if method == "" {
+ return "Get"
+ }
+ if lowerMethod, ok := ascii.ToLower(method); ok {
+ return method[:1] + lowerMethod[1:]
+ }
+ return method
+}
+
+// Do sends an HTTP request and returns an HTTP response, following
+// policy (such as redirects, cookies, auth) as configured on the
+// client.
+//
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or failure to speak HTTP (such as a network
+// connectivity problem). A non-2xx status code doesn't cause an
+// error.
+//
+// If the returned error is nil, the Response will contain a non-nil
+// Body which the user is expected to close. If the Body is not both
+// read to EOF and closed, the Client's underlying RoundTripper
+// (typically Transport) may not be able to re-use a persistent TCP
+// connection to the server for a subsequent "keep-alive" request.
+//
+// The request Body, if non-nil, will be closed by the underlying
+// Transport, even on errors.
+//
+// On error, any Response can be ignored. A non-nil Response with a
+// non-nil error only occurs when CheckRedirect fails, and even then
+// the returned Response.Body is already closed.
+//
+// Generally Get, Post, or PostForm will be used instead of Do.
+//
+// If the server replies with a redirect, the Client first uses the
+// CheckRedirect function to determine whether the redirect should be
+// followed. If permitted, a 301, 302, or 303 redirect causes
+// subsequent requests to use HTTP method GET
+// (or HEAD if the original request was HEAD), with no body.
+// A 307 or 308 redirect preserves the original HTTP method and body,
+// provided that the Request.GetBody function is defined.
+// The NewRequest function automatically sets GetBody for common
+// standard library body types.
+//
+// Any returned error will be of type *url.Error. The url.Error
+// value's Timeout method will report true if the request timed out.
+func (c *Client) Do(req *Request) (*Response, error) {
+ return c.do(req)
+}
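And a sketch of Do with a custom header, which the Get helper cannot express directly; it also drains the body before closing, which (per the documentation above) helps the Transport reuse the connection. The Accept value is illustrative.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	req, err := http.NewRequest("GET", "https://example.com/", nil)
	if err != nil {
		fmt.Println(err)
		return
	}
	req.Header.Set("Accept", "application/json") // illustrative header

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close()

	// Read to EOF before Close so the keep-alive connection can be reused.
	io.Copy(io.Discard, resp.Body)
	fmt.Println(resp.Status)
}
```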
+
+var testHookClientDoResult func(retres *Response, reterr error)
+
+func (c *Client) do(req *Request) (retres *Response, reterr error) {
+ if testHookClientDoResult != nil {
+ defer func() { testHookClientDoResult(retres, reterr) }()
+ }
+ if req.URL == nil {
+ req.closeBody()
+ return nil, &url.Error{
+ Op: urlErrorOp(req.Method),
+ Err: errors.New("http: nil Request.URL"),
+ }
+ }
+
+ var (
+ deadline = c.deadline()
+ reqs []*Request
+ resp *Response
+ copyHeaders = c.makeHeadersCopier(req)
+ reqBodyClosed = false // have we closed the current req.Body?
+
+ // Redirect behavior:
+ redirectMethod string
+ includeBody bool
+ )
+ uerr := func(err error) error {
+ // the body may have been closed already by c.send()
+ if !reqBodyClosed {
+ req.closeBody()
+ }
+ var urlStr string
+ if resp != nil && resp.Request != nil {
+ urlStr = stripPassword(resp.Request.URL)
+ } else {
+ urlStr = stripPassword(req.URL)
+ }
+ return &url.Error{
+ Op: urlErrorOp(reqs[0].Method),
+ URL: urlStr,
+ Err: err,
+ }
+ }
+ for {
+ // For all but the first request, create the next
+ // request hop and replace req.
+ if len(reqs) > 0 {
+ loc := resp.Header.Get("Location")
+ if loc == "" {
+ // While most 3xx responses include a Location, it is not
+ // required and 3xx responses without a Location have been
+ // observed in the wild. See issues #17773 and #49281.
+ return resp, nil
+ }
+ u, err := req.URL.Parse(loc)
+ if err != nil {
+ resp.closeBody()
+ return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err))
+ }
+ host := ""
+ if req.Host != "" && req.Host != req.URL.Host {
+ // If the caller specified a custom Host header and the
+ // redirect location is relative, preserve the Host header
+ // through the redirect. See issue #22233.
+ if u, _ := url.Parse(loc); u != nil && !u.IsAbs() {
+ host = req.Host
+ }
+ }
+ ireq := reqs[0]
+ req = &Request{
+ Method: redirectMethod,
+ Response: resp,
+ URL: u,
+ Header: make(Header),
+ Host: host,
+ Cancel: ireq.Cancel,
+ ctx: ireq.ctx,
+ }
+ if includeBody && ireq.GetBody != nil {
+ req.Body, err = ireq.GetBody()
+ if err != nil {
+ resp.closeBody()
+ return nil, uerr(err)
+ }
+ req.ContentLength = ireq.ContentLength
+ }
+
+ // Copy original headers before setting the Referer,
+ // in case the user set Referer on their first request.
+ // If they really want to override, they can do it in
+ // their CheckRedirect func.
+ copyHeaders(req)
+
+ // Add the Referer header from the most recent
+ // request URL to the new one, if it's not https->http:
+ if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL, req.Header.Get("Referer")); ref != "" {
+ req.Header.Set("Referer", ref)
+ }
+ err = c.checkRedirect(req, reqs)
+
+ // Sentinel error to let users select the
+ // previous response, without closing its
+ // body. See Issue 10069.
+ if err == ErrUseLastResponse {
+ return resp, nil
+ }
+
+ // Close the previous response's body. But
+ // read at least some of the body so if it's
+ // small the underlying TCP connection will be
+ // re-used. No need to check for errors: if it
+ // fails, the Transport won't reuse it anyway.
+ const maxBodySlurpSize = 2 << 10
+ if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize {
+ io.CopyN(io.Discard, resp.Body, maxBodySlurpSize)
+ }
+ resp.Body.Close()
+
+ if err != nil {
+ // Special case for Go 1 compatibility: return both the response
+ // and an error if the CheckRedirect function failed.
+ // See https://golang.org/issue/3795
+ // The resp.Body has already been closed.
+ ue := uerr(err)
+ ue.(*url.Error).URL = loc
+ return resp, ue
+ }
+ }
+
+ reqs = append(reqs, req)
+ var err error
+ var didTimeout func() bool
+ if resp, didTimeout, err = c.send(req, deadline); err != nil {
+ // c.send() always closes req.Body
+ reqBodyClosed = true
+ if !deadline.IsZero() && didTimeout() {
+ err = &httpError{
+ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
+ timeout: true,
+ }
+ }
+ return nil, uerr(err)
+ }
+
+ var shouldRedirect bool
+ redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0])
+ if !shouldRedirect {
+ return resp, nil
+ }
+
+ req.closeBody()
+ }
+}
+
+// makeHeadersCopier makes a function that copies headers from the
+// initial Request, ireq. For every redirect, this function must be called
+// so that it can copy headers into the upcoming Request.
+func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) {
+ // The headers to copy are from the very initial request.
+ // We use a closured callback to keep a reference to these original headers.
+ var (
+ ireqhdr = cloneOrMakeHeader(ireq.Header)
+ icookies map[string][]*Cookie
+ )
+ if c.Jar != nil && ireq.Header.Get("Cookie") != "" {
+ icookies = make(map[string][]*Cookie)
+ for _, c := range ireq.Cookies() {
+ icookies[c.Name] = append(icookies[c.Name], c)
+ }
+ }
+
+ preq := ireq // The previous request
+ return func(req *Request) {
+ // If Jar is present and there was some initial cookies provided
+ // via the request header, then we may need to alter the initial
+ // cookies as we follow redirects since each redirect may end up
+ // modifying a pre-existing cookie.
+ //
+ // Since cookies already set in the request header do not contain
+ // information about the original domain and path, the logic below
+ // assumes any new set cookies override the original cookie
+ // regardless of domain or path.
+ //
+ // See https://golang.org/issue/17494
+ if c.Jar != nil && icookies != nil {
+ var changed bool
+ resp := req.Response // The response that caused the upcoming redirect
+ for _, c := range resp.Cookies() {
+ if _, ok := icookies[c.Name]; ok {
+ delete(icookies, c.Name)
+ changed = true
+ }
+ }
+ if changed {
+ ireqhdr.Del("Cookie")
+ var ss []string
+ for _, cs := range icookies {
+ for _, c := range cs {
+ ss = append(ss, c.Name+"="+c.Value)
+ }
+ }
+ sort.Strings(ss) // Ensure deterministic headers
+ ireqhdr.Set("Cookie", strings.Join(ss, "; "))
+ }
+ }
+
+ // Copy the initial request's Header values
+ // (at least the safe ones).
+ for k, vv := range ireqhdr {
+ if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) {
+ req.Header[k] = vv
+ }
+ }
+
+ preq = req // Update previous Request with the current request
+ }
+}
+
+func defaultCheckRedirect(req *Request, via []*Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
+ }
+ return nil
+}
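Beyond raising or lowering the default 10-redirect limit, a common pattern is to stop redirect handling entirely and inspect the 3xx response directly via ErrUseLastResponse; a sketch with a hypothetical URL:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	client := &http.Client{
		// Never follow redirects; return the 3xx response as-is.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	resp, err := client.Get("https://example.com/old-path") // hypothetical URL
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close()
	fmt.Println(resp.StatusCode, "Location:", resp.Header.Get("Location"))
}
```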
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close resp.Body when done reading from it.
+//
+// If the provided body is an io.Closer, it is closed after the
+// request.
+//
+// Post is a wrapper around DefaultClient.Post.
+//
+// To set custom headers, use NewRequest and DefaultClient.Do.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and DefaultClient.Do.
+func Post(url, contentType string, body io.Reader) (resp *Response, err error) {
+ return DefaultClient.Post(url, contentType, body)
+}
+
+// Post issues a POST to the specified URL.
+//
+// Caller should close resp.Body when done reading from it.
+//
+// If the provided body is an io.Closer, it is closed after the
+// request.
+//
+// To set custom headers, use NewRequest and Client.Do.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and Client.Do.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) {
+ req, err := NewRequest("POST", url, body)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", contentType)
+ return c.Do(req)
+}
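+
+// A minimal usage sketch (URL and payload are hypothetical), posting a JSON
+// body with an explicit Content-Type:
+//
+//	resp, err := client.Post("https://example.com/api/items", "application/json",
+//		strings.NewReader(`{"name":"gopher"}`))
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer resp.Body.Close()
+//	body, _ := io.ReadAll(resp.Body)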
+
+// PostForm issues a POST to the specified URL, with data's keys and
+// values URL-encoded as the request body.
+//
+// The Content-Type header is set to application/x-www-form-urlencoded.
+// To set other headers, use NewRequest and DefaultClient.Do.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// PostForm is a wrapper around DefaultClient.PostForm.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and DefaultClient.Do.
+func PostForm(url string, data url.Values) (resp *Response, err error) {
+ return DefaultClient.PostForm(url, data)
+}
+
+// PostForm issues a POST to the specified URL,
+// with data's keys and values URL-encoded as the request body.
+//
+// The Content-Type header is set to application/x-www-form-urlencoded.
+// To set other headers, use NewRequest and Client.Do.
+//
+// When err is nil, resp always contains a non-nil resp.Body.
+// Caller should close resp.Body when done reading from it.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and Client.Do.
+func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
+ return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
+}
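+
+// A minimal usage sketch (URL and form values are hypothetical):
+//
+//	resp, err := client.PostForm("https://example.com/login",
+//		url.Values{"user": {"gopher"}, "lang": {"go"}})
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer resp.Body.Close()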
+
+// Head issues a HEAD to the specified URL. If the response is one of
+// the following redirect codes, Head follows the redirect, up to a
+// maximum of 10 redirects:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
+//
+// Head is a wrapper around DefaultClient.Head.
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and DefaultClient.Do.
+func Head(url string) (resp *Response, err error) {
+ return DefaultClient.Head(url)
+}
+
+// Head issues a HEAD to the specified URL. If the response is one of the
+// following redirect codes, Head follows the redirect after calling the
+// Client's CheckRedirect function:
+//
+// 301 (Moved Permanently)
+// 302 (Found)
+// 303 (See Other)
+// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
+//
+// To make a request with a specified context.Context, use NewRequestWithContext
+// and Client.Do.
+func (c *Client) Head(url string) (resp *Response, err error) {
+ req, err := NewRequest("HEAD", url, nil)
+ if err != nil {
+ return nil, err
+ }
+ return c.Do(req)
+}
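+
+// A minimal usage sketch (URL is hypothetical), inspecting headers without
+// downloading the body:
+//
+//	resp, err := client.Head("https://example.com/large-file")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	resp.Body.Close()
+//	fmt.Println(resp.Header.Get("Content-Length"))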
+
+// CloseIdleConnections closes any connections on its Transport that
+// were established by previous requests but are now
+// sitting idle in a "keep-alive" state. It does not interrupt any
+// connections currently in use.
+//
+// If the Client's Transport does not have a CloseIdleConnections method
+// then this method does nothing.
+func (c *Client) CloseIdleConnections() {
+ type closeIdler interface {
+ CloseIdleConnections()
+ }
+ if tr, ok := c.transport().(closeIdler); ok {
+ tr.CloseIdleConnections()
+ }
+}
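+
+// A minimal usage sketch: release idle keep-alive connections after a batch
+// of requests (client and urls are hypothetical):
+//
+//	for _, u := range urls {
+//		resp, err := client.Get(u)
+//		if err != nil {
+//			continue
+//		}
+//		io.Copy(io.Discard, resp.Body)
+//		resp.Body.Close()
+//	}
+//	client.CloseIdleConnections()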
+
+// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
+// 1. On Close, the stop func is called.
+// 2. On Read failure (other than io.EOF), if reqDidTimeout is true, the
+// error is wrapped and marked as a net.Error that hit its timeout.
+type cancelTimerBody struct {
+ stop func() // stops the time.Timer waiting to cancel the request
+ rc io.ReadCloser
+ reqDidTimeout func() bool
+}
+
+func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
+ n, err = b.rc.Read(p)
+ if err == nil {
+ return n, nil
+ }
+ if err == io.EOF {
+ return n, err
+ }
+ if b.reqDidTimeout() {
+ err = &httpError{
+ err: err.Error() + " (Client.Timeout or context cancellation while reading body)",
+ timeout: true,
+ }
+ }
+ return n, err
+}
+
+func (b *cancelTimerBody) Close() error {
+ err := b.rc.Close()
+ b.stop()
+ return err
+}
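+
+// A rough sketch of what callers observe (timeout value and URL are
+// hypothetical): when Client.Timeout fires while the body is still being
+// read, the error returned from the read reports itself as a timeout.
+//
+//	client := &http.Client{Timeout: 2 * time.Second}
+//	resp, err := client.Get("https://example.com/slow")
+//	if err == nil {
+//		_, err = io.ReadAll(resp.Body)
+//		resp.Body.Close()
+//	}
+//	var ne net.Error
+//	if errors.As(err, &ne) && ne.Timeout() {
+//		// request or body read exceeded Client.Timeout
+//	}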
+
+func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool {
+ switch CanonicalHeaderKey(headerKey) {
+ case "Authorization", "Www-Authenticate", "Cookie", "Cookie2":
+ // Permit sending auth/cookie headers from "foo.com"
+ // to "sub.foo.com".
+
+ // Note that we don't send all cookies to subdomains
+ // automatically. This function is only used for
+ // Cookies set explicitly on the initial outgoing
+ // client request. Cookies automatically added via the
+ // CookieJar mechanism continue to follow each
+ // cookie's scope as set by Set-Cookie. But for
+ // outgoing requests with the Cookie header set
+ // directly, we don't know their scope, so we assume
+ // it's for *.domain.com.
+
+ ihost := idnaASCIIFromURL(initial)
+ dhost := idnaASCIIFromURL(dest)
+ return isDomainOrSubdomain(dhost, ihost)
+ }
+ // All other headers are copied:
+ return true
+}
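+
+// A few illustrative outcomes (hosts are hypothetical), where foo, subFoo and
+// bar are *url.URL values for https://foo.com, https://sub.foo.com and
+// https://bar.com:
+//
+//	shouldCopyHeaderOnRedirect("X-Trace", foo, bar)          // true: not a sensitive header
+//	shouldCopyHeaderOnRedirect("Authorization", foo, subFoo) // true: subdomain of foo.com
+//	shouldCopyHeaderOnRedirect("Authorization", foo, bar)    // false: unrelated domain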
+
+// isDomainOrSubdomain reports whether sub is a subdomain (or exact
+// match) of the parent domain.
+//
+// Both domains must already be in canonical form.
+func isDomainOrSubdomain(sub, parent string) bool {
+ if sub == parent {
+ return true
+ }
+ // If sub contains a :, it's probably an IPv6 address (and is definitely not a hostname).
+ // Don't check the suffix in this case, to avoid matching the contents of an IPv6 zone.
+ // For example, "::1%.www.example.com" is not a subdomain of "www.example.com".
+ if strings.ContainsAny(sub, ":%") {
+ return false
+ }
+ // If sub is "foo.example.com" and parent is "example.com",
+ // that means sub must end in "."+parent.
+ // Do it without allocating.
+ if !strings.HasSuffix(sub, parent) {
+ return false
+ }
+ return sub[len(sub)-len(parent)-1] == '.'
+}
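+
+// A few illustrative cases (hostnames are hypothetical):
+//
+//	isDomainOrSubdomain("foo.com", "foo.com")      // true: exact match
+//	isDomainOrSubdomain("sub.foo.com", "foo.com")  // true: ends in ".foo.com"
+//	isDomainOrSubdomain("notfoo.com", "foo.com")   // false: no "." boundary
+//	isDomainOrSubdomain("::1%.foo.com", "foo.com") // false: IPv6 zone, not a hostname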
+
+func stripPassword(u *url.URL) string {
+ _, passSet := u.User.Password()
+ if passSet {
+ return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1)
+ }
+ return u.String()
+}
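+
+// A minimal sketch (URL is hypothetical): the password, but not the username,
+// is redacted before the URL appears in an error message.
+//
+//	u, _ := url.Parse("http://user:secret@example.com/")
+//	fmt.Println(stripPassword(u)) // http://user:***@example.com/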
diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go
new file mode 100644
index 0000000..fc1d791
--- /dev/null
+++ b/src/net/http/client_test.go
@@ -0,0 +1,2144 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for client.go
+
+package http_test
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/url"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Last-Modified", "sometime")
+ fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
+})
+
+// pedanticReadAll works like io.ReadAll but additionally
+// verifies that r obeys the documented io.Reader contract.
+func pedanticReadAll(r io.Reader) (b []byte, err error) {
+ var bufa [64]byte
+ buf := bufa[:]
+ for {
+ n, err := r.Read(buf)
+ if n == 0 && err == nil {
+ return nil, fmt.Errorf("Read: n=0 with err=nil")
+ }
+ b = append(b, buf[:n]...)
+ if err == io.EOF {
+ n, err := r.Read(buf)
+ if n != 0 || err != io.EOF {
+ return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n, err)
+ }
+ return b, nil
+ }
+ if err != nil {
+ return b, err
+ }
+ }
+}
+
+type chanWriter chan string
+
+func (w chanWriter) Write(p []byte) (n int, err error) {
+ w <- string(p)
+ return len(p), nil
+}
+
+func TestClient(t *testing.T) { run(t, testClient) }
+func testClient(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, robotsTxtHandler).ts
+
+ c := ts.Client()
+ r, err := c.Get(ts.URL)
+ var b []byte
+ if err == nil {
+ b, err = pedanticReadAll(r.Body)
+ r.Body.Close()
+ }
+ if err != nil {
+ t.Error(err)
+ } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
+ t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
+ }
+}
+
+func TestClientHead(t *testing.T) { run(t, testClientHead) }
+func testClientHead(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, robotsTxtHandler)
+ r, err := cst.c.Head(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := r.Header["Last-Modified"]; !ok {
+ t.Error("Last-Modified header not found.")
+ }
+}
+
+type recordingTransport struct {
+ req *Request
+}
+
+func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ t.req = req
+ return nil, errors.New("dummy impl")
+}
+
+func TestGetRequestFormat(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+ url := "http://dummy.faketld/"
+ client.Get(url) // Note: doesn't hit network
+ if tr.req.Method != "GET" {
+ t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
+ }
+ if tr.req.Header == nil {
+ t.Errorf("expected non-nil request Header")
+ }
+}
+
+func TestPostRequestFormat(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ url := "http://dummy.faketld/"
+ json := `{"key":"value"}`
+ b := strings.NewReader(json)
+ client.Post(url, "application/json", b) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ if g, e := tr.req.ContentLength, int64(len(json)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+}
+
+func TestPostFormRequestFormat(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ urlStr := "http://dummy.faketld/"
+ form := make(url.Values)
+ form.Set("foo", "bar")
+ form.Add("foo", "bar2")
+ form.Set("bar", "baz")
+ client.PostForm(urlStr, form) // Note: doesn't hit network
+
+ if tr.req.Method != "POST" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "POST")
+ }
+ if tr.req.URL.String() != urlStr {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
+ t.Errorf("got Content-Type %q, want %q", g, e)
+ }
+ if tr.req.Close {
+ t.Error("got Close true, want false")
+ }
+ // Depending on map iteration, body can be either of these.
+ expectedBody := "foo=bar&foo=bar2&bar=baz"
+ expectedBody1 := "bar=baz&foo=bar&foo=bar2"
+ if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
+ t.Errorf("got ContentLength %d, want %d", g, e)
+ }
+ bodyb, err := io.ReadAll(tr.req.Body)
+ if err != nil {
+ t.Fatalf("ReadAll on req.Body: %v", err)
+ }
+ if g := string(bodyb); g != expectedBody && g != expectedBody1 {
+ t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
+ }
+}
+
+func TestClientRedirects(t *testing.T) { run(t, testClientRedirects) }
+func testClientRedirects(t *testing.T, mode testMode) {
+ var ts *httptest.Server
+ ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ n, _ := strconv.Atoi(r.FormValue("n"))
+ // Test Referer header. (7 is an arbitrary position to test at)
+ if n == 7 {
+ if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
+ t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
+ }
+ }
+ if n < 15 {
+ Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect)
+ return
+ }
+ fmt.Fprintf(w, "n=%d", n)
+ })).ts
+
+ c := ts.Client()
+ _, err := c.Get(ts.URL)
+ if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Get, expected error %q, got %q", e, g)
+ }
+
+ // HEAD request should also have the ability to follow redirects.
+ _, err = c.Head(ts.URL)
+ if e, g := `Head "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Head, expected error %q, got %q", e, g)
+ }
+
+ // Do should also follow redirects.
+ greq, _ := NewRequest("GET", ts.URL, nil)
+ _, err = c.Do(greq)
+ if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Do, expected error %q, got %q", e, g)
+ }
+
+ // Requests with an empty Method should also redirect (Issue 12705)
+ greq.Method = ""
+ _, err = c.Do(greq)
+ if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
+ }
+
+ var checkErr error
+ var lastVia []*Request
+ var lastReq *Request
+ c.CheckRedirect = func(req *Request, via []*Request) error {
+ lastReq = req
+ lastVia = via
+ return checkErr
+ }
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ res.Body.Close()
+ finalURL := res.Request.URL.String()
+ if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with custom client, expected error %q, got %q", e, g)
+ }
+ if !strings.HasSuffix(finalURL, "/?n=15") {
+ t.Errorf("expected final url to end in /?n=15; got url %q", finalURL)
+ }
+ if e, g := 15, len(lastVia); e != g {
+ t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
+ }
+
+ // Test that Request.Cancel is propagated between requests (Issue 14053)
+ creq, _ := NewRequest("HEAD", ts.URL, nil)
+ cancel := make(chan struct{})
+ creq.Cancel = cancel
+ if _, err := c.Do(creq); err != nil {
+ t.Fatal(err)
+ }
+ if lastReq == nil {
+ t.Fatal("didn't see redirect")
+ }
+ if lastReq.Cancel != cancel {
+ t.Errorf("expected lastReq to have the cancel channel set on the initial req")
+ }
+
+ checkErr = errors.New("no redirects allowed")
+ res, err = c.Get(ts.URL)
+ if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
+ t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
+ }
+ if res == nil {
+ t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)")
+ }
+ res.Body.Close()
+ if res.Header.Get("Location") == "" {
+ t.Errorf("no Location header in Response")
+ }
+}
+
+// Tests that Client redirects' contexts are derived from the original request's context.
+func TestClientRedirectsContext(t *testing.T) { run(t, testClientRedirectsContext) }
+func testClientRedirectsContext(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Redirect(w, r, "/", StatusTemporaryRedirect)
+ })).ts
+
+ ctx, cancel := context.WithCancel(context.Background())
+ c := ts.Client()
+ c.CheckRedirect = func(req *Request, via []*Request) error {
+ cancel()
+ select {
+ case <-req.Context().Done():
+ return nil
+ case <-time.After(5 * time.Second):
+ return errors.New("redirected request's context never expired after root request canceled")
+ }
+ }
+ req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
+ _, err := c.Do(req)
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got error %T; want *url.Error", err)
+ }
+ if ue.Err != context.Canceled {
+ t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
+ }
+}
+
+type redirectTest struct {
+ suffix string
+ want int // response code
+ redirectBody string
+}
+
+func TestPostRedirects(t *testing.T) {
+ postRedirectTests := []redirectTest{
+ {"/", 200, "first"},
+ {"/?code=301&next=302", 200, "c301"},
+ {"/?code=302&next=302", 200, "c302"},
+ {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348
+ {"/?code=304", 304, "c304"},
+ {"/?code=305", 305, "c305"},
+ {"/?code=307&next=303,308,302", 200, "c307"},
+ {"/?code=308&next=302,301", 200, "c308"},
+ {"/?code=404", 404, "c404"},
+ }
+
+ wantSegments := []string{
+ `POST / "first"`,
+ `POST /?code=301&next=302 "c301"`,
+ `GET /?code=302 ""`,
+ `GET / ""`,
+ `POST /?code=302&next=302 "c302"`,
+ `GET /?code=302 ""`,
+ `GET / ""`,
+ `POST /?code=303&next=301 "c303wc301"`,
+ `GET /?code=301 ""`,
+ `GET / ""`,
+ `POST /?code=304 "c304"`,
+ `POST /?code=305 "c305"`,
+ `POST /?code=307&next=303,308,302 "c307"`,
+ `POST /?code=303&next=308,302 "c307"`,
+ `GET /?code=308&next=302 ""`,
+ `GET /?code=302 "c307"`,
+ `GET / ""`,
+ `POST /?code=308&next=302,301 "c308"`,
+ `POST /?code=302&next=301 "c308"`,
+ `GET /?code=301 ""`,
+ `GET / ""`,
+ `POST /?code=404 "c404"`,
+ }
+ want := strings.Join(wantSegments, "\n")
+ run(t, func(t *testing.T, mode testMode) {
+ testRedirectsByMethod(t, mode, "POST", postRedirectTests, want)
+ })
+}
+
+func TestDeleteRedirects(t *testing.T) {
+ deleteRedirectTests := []redirectTest{
+ {"/", 200, "first"},
+ {"/?code=301&next=302,308", 200, "c301"},
+ {"/?code=302&next=302", 200, "c302"},
+ {"/?code=303", 200, "c303"},
+ {"/?code=307&next=301,308,303,302,304", 304, "c307"},
+ {"/?code=308&next=307", 200, "c308"},
+ {"/?code=404", 404, "c404"},
+ }
+
+ wantSegments := []string{
+ `DELETE / "first"`,
+ `DELETE /?code=301&next=302,308 "c301"`,
+ `GET /?code=302&next=308 ""`,
+ `GET /?code=308 ""`,
+ `GET / "c301"`,
+ `DELETE /?code=302&next=302 "c302"`,
+ `GET /?code=302 ""`,
+ `GET / ""`,
+ `DELETE /?code=303 "c303"`,
+ `GET / ""`,
+ `DELETE /?code=307&next=301,308,303,302,304 "c307"`,
+ `DELETE /?code=301&next=308,303,302,304 "c307"`,
+ `GET /?code=308&next=303,302,304 ""`,
+ `GET /?code=303&next=302,304 "c307"`,
+ `GET /?code=302&next=304 ""`,
+ `GET /?code=304 ""`,
+ `DELETE /?code=308&next=307 "c308"`,
+ `DELETE /?code=307 "c308"`,
+ `DELETE / "c308"`,
+ `DELETE /?code=404 "c404"`,
+ }
+ want := strings.Join(wantSegments, "\n")
+ run(t, func(t *testing.T, mode testMode) {
+ testRedirectsByMethod(t, mode, "DELETE", deleteRedirectTests, want)
+ })
+}
+
+func testRedirectsByMethod(t *testing.T, mode testMode, method string, table []redirectTest, want string) {
+ var log struct {
+ sync.Mutex
+ bytes.Buffer
+ }
+ var ts *httptest.Server
+ ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ log.Lock()
+ slurp, _ := io.ReadAll(r.Body)
+ fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp)
+ if cl := r.Header.Get("Content-Length"); r.Method == "GET" && len(slurp) == 0 && (r.ContentLength != 0 || cl != "") {
+ fmt.Fprintf(&log.Buffer, " (but with body=%T, content-length = %v, %q)", r.Body, r.ContentLength, cl)
+ }
+ log.WriteByte('\n')
+ log.Unlock()
+ urlQuery := r.URL.Query()
+ if v := urlQuery.Get("code"); v != "" {
+ location := ts.URL
+ if final := urlQuery.Get("next"); final != "" {
+ first, rest, _ := strings.Cut(final, ",")
+ location = fmt.Sprintf("%s?code=%s", location, first)
+ if rest != "" {
+ location = fmt.Sprintf("%s&next=%s", location, rest)
+ }
+ }
+ code, _ := strconv.Atoi(v)
+ if code/100 == 3 {
+ w.Header().Set("Location", location)
+ }
+ w.WriteHeader(code)
+ }
+ })).ts
+
+ c := ts.Client()
+ for _, tt := range table {
+ content := tt.redirectBody
+ req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
+ req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(content)), nil }
+ res, err := c.Do(req)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
+ }
+ }
+ log.Lock()
+ got := log.String()
+ log.Unlock()
+
+ got = strings.TrimSpace(got)
+ want = strings.TrimSpace(want)
+
+ if got != want {
+ got, want, lines := removeCommonLines(got, want)
+ t.Errorf("Log differs after %d common lines.\n\nGot:\n%s\n\nWant:\n%s\n", lines, got, want)
+ }
+}
+
+func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) {
+ for {
+ nl := strings.IndexByte(a, '\n')
+ if nl < 0 {
+ return a, b, commonLines
+ }
+ line := a[:nl+1]
+ if !strings.HasPrefix(b, line) {
+ return a, b, commonLines
+ }
+ commonLines++
+ a = a[len(line):]
+ b = b[len(line):]
+ }
+}
+
+func TestClientRedirectUseResponse(t *testing.T) { run(t, testClientRedirectUseResponse) }
+func testClientRedirectUseResponse(t *testing.T, mode testMode) {
+ const body = "Hello, world."
+ var ts *httptest.Server
+ ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.Path, "/other") {
+ io.WriteString(w, "wrong body")
+ } else {
+ w.Header().Set("Location", ts.URL+"/other")
+ w.WriteHeader(StatusFound)
+ io.WriteString(w, body)
+ }
+ })).ts
+
+ c := ts.Client()
+ c.CheckRedirect = func(req *Request, via []*Request) error {
+ if req.Response == nil {
+ t.Error("expected non-nil Request.Response")
+ }
+ return ErrUseLastResponse
+ }
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != StatusFound {
+ t.Errorf("status = %d; want %d", res.StatusCode, StatusFound)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != body {
+ t.Errorf("body = %q; want %q", slurp, body)
+ }
+}
+
+// Issues 17773 and 49281: don't follow a 3xx if the response doesn't
+// have a Location header.
+func TestClientRedirectNoLocation(t *testing.T) { run(t, testClientRedirectNoLocation) }
+func testClientRedirectNoLocation(t *testing.T, mode testMode) {
+ for _, code := range []int{301, 308} {
+ t.Run(fmt.Sprint(code), func(t *testing.T) {
+ setParallel(t)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Foo", "Bar")
+ w.WriteHeader(code)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != code {
+ t.Errorf("status = %d; want %d", res.StatusCode, code)
+ }
+ if got := res.Header.Get("Foo"); got != "Bar" {
+ t.Errorf("Foo header = %q; want Bar", got)
+ }
+ })
+ }
+}
+
+// Don't follow a 307/308 if we can't resend the request body.
+func TestClientRedirect308NoGetBody(t *testing.T) { run(t, testClientRedirect308NoGetBody) }
+func testClientRedirect308NoGetBody(t *testing.T, mode testMode) {
+ const fakeURL = "https://localhost:1234/" // won't be hit
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Location", fakeURL)
+ w.WriteHeader(308)
+ })).ts
+ req, err := NewRequest("POST", ts.URL, strings.NewReader("some body"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c := ts.Client()
+ req.GetBody = nil // so it can't rewind.
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != 308 {
+ t.Errorf("status = %d; want %d", res.StatusCode, 308)
+ }
+ if got := res.Header.Get("Location"); got != fakeURL {
+ t.Errorf("Location header = %q; want %q", got, fakeURL)
+ }
+}
+
+var expectedCookies = []*Cookie{
+ {Name: "ChocolateChip", Value: "tasty"},
+ {Name: "First", Value: "Hit"},
+ {Name: "Second", Value: "Hit"},
+}
+
+var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ for _, cookie := range r.Cookies() {
+ SetCookie(w, cookie)
+ }
+ if r.URL.Path == "/" {
+ SetCookie(w, expectedCookies[1])
+ Redirect(w, r, "/second", StatusMovedPermanently)
+ } else {
+ SetCookie(w, expectedCookies[2])
+ w.Write([]byte("hello"))
+ }
+})
+
+func TestClientSendsCookieFromJar(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+ client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
+ us := "http://dummy.faketld/"
+ u, _ := url.Parse(us)
+ client.Jar.SetCookies(u, expectedCookies)
+
+ client.Get(us) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.Head(us) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.Post(us, "text/plain", strings.NewReader("body")) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ client.PostForm(us, url.Values{}) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ := NewRequest("GET", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+
+ req, _ = NewRequest("POST", us, nil)
+ client.Do(req) // Note: doesn't hit network
+ matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
+}
+
+// Just enough correctness for our redirect tests. Uses the URL.Host as the
+// scope of all cookies.
+type TestJar struct {
+ m sync.Mutex
+ perURL map[string][]*Cookie
+}
+
+func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.m.Lock()
+ defer j.m.Unlock()
+ if j.perURL == nil {
+ j.perURL = make(map[string][]*Cookie)
+ }
+ j.perURL[u.Host] = cookies
+}
+
+func (j *TestJar) Cookies(u *url.URL) []*Cookie {
+ j.m.Lock()
+ defer j.m.Unlock()
+ return j.perURL[u.Host]
+}
+
+func TestRedirectCookiesJar(t *testing.T) { run(t, testRedirectCookiesJar) }
+func testRedirectCookiesJar(t *testing.T, mode testMode) {
+ var ts *httptest.Server
+ ts = newClientServerTest(t, mode, echoCookiesRedirectHandler).ts
+ c := ts.Client()
+ c.Jar = new(TestJar)
+ u, _ := url.Parse(ts.URL)
+ c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ resp.Body.Close()
+ matchReturnedCookies(t, expectedCookies, resp.Cookies())
+}
+
+func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
+ if len(given) != len(expected) {
+ t.Logf("Received cookies: %v", given)
+ t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
+ }
+ for _, ec := range expected {
+ foundC := false
+ for _, c := range given {
+ if ec.Name == c.Name && ec.Value == c.Value {
+ foundC = true
+ break
+ }
+ }
+ if !foundC {
+ t.Errorf("Missing cookie %v", ec)
+ }
+ }
+}
+
+func TestJarCalls(t *testing.T) { run(t, testJarCalls, []testMode{http1Mode}) }
+func testJarCalls(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ pathSuffix := r.RequestURI[1:]
+ if r.RequestURI == "/nosetcookie" {
+ return // don't set cookies for this path
+ }
+ SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
+ if r.RequestURI == "/" {
+ Redirect(w, r, "http://secondhost.fake/secondpath", 302)
+ }
+ })).ts
+ jar := new(RecordingJar)
+ c := ts.Client()
+ c.Jar = jar
+ c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ts.Listener.Addr().String())
+ }
+ _, err := c.Get("http://firsthost.fake/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Get("http://firsthost.fake/nosetcookie")
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := jar.log.String()
+ want := `Cookies("http://firsthost.fake/")
+SetCookie("http://firsthost.fake/", [name=val])
+Cookies("http://secondhost.fake/secondpath")
+SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
+Cookies("http://firsthost.fake/nosetcookie")
+`
+ if got != want {
+ t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+// RecordingJar keeps a log of calls made to it, without
+// tracking any cookies.
+type RecordingJar struct {
+ mu sync.Mutex
+ log bytes.Buffer
+}
+
+func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
+ j.logf("SetCookie(%q, %v)\n", u, cookies)
+}
+
+func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
+ j.logf("Cookies(%q)\n", u)
+ return nil
+}
+
+func (j *RecordingJar) logf(format string, args ...any) {
+ j.mu.Lock()
+ defer j.mu.Unlock()
+ fmt.Fprintf(&j.log, format, args...)
+}
+
+func TestStreamingGet(t *testing.T) { run(t, testStreamingGet) }
+func testStreamingGet(t *testing.T, mode testMode) {
+ say := make(chan string)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ for str := range say {
+ w.Write([]byte(str))
+ w.(Flusher).Flush()
+ }
+ }))
+
+ c := cst.c
+ res, err := c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var buf [10]byte
+ for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
+ say <- str
+ n, err := io.ReadFull(res.Body, buf[0:len(str)])
+ if err != nil {
+ t.Fatalf("ReadFull on %q: %v", str, err)
+ }
+ if n != len(str) {
+ t.Fatalf("Receiving %q, only read %d bytes", str, n)
+ }
+ got := string(buf[0:n])
+ if got != str {
+ t.Fatalf("Expected %q, got %q", str, got)
+ }
+ }
+ close(say)
+ _, err = io.ReadFull(res.Body, buf[0:1])
+ if err != io.EOF {
+ t.Fatalf("at end expected EOF, got %v", err)
+ }
+}
+
+type writeCountingConn struct {
+ net.Conn
+ count *int
+}
+
+func (c *writeCountingConn) Write(p []byte) (int, error) {
+ *c.count++
+ return c.Conn.Write(p)
+}
+
+// TestClientWrites verifies that client requests are buffered and we
+// don't send a TCP packet per line of the http request + body.
+func TestClientWrites(t *testing.T) { run(t, testClientWrites, []testMode{http1Mode}) }
+func testClientWrites(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ })).ts
+
+ writes := 0
+ dialer := func(netz string, addr string) (net.Conn, error) {
+ c, err := net.Dial(netz, addr)
+ if err == nil {
+ c = &writeCountingConn{c, &writes}
+ }
+ return c, err
+ }
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = dialer
+
+ _, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Get request did %d Write calls, want 1", writes)
+ }
+
+ writes = 0
+ _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if writes != 1 {
+ t.Errorf("Post request did %d Write calls, want 1", writes)
+ }
+}
+
+func TestClientInsecureTransport(t *testing.T) {
+ run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode})
+}
+func testClientInsecureTransport(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ })).ts
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+ defer ts.Close()
+
+ // TODO(bradfitz): add tests for skipping hostname checks too?
+ // would require a new cert for testing, and probably
+ // redundant with these tests.
+ c := ts.Client()
+ for _, insecure := range []bool{true, false} {
+ c.Transport.(*Transport).TLSClientConfig = &tls.Config{
+ InsecureSkipVerify: insecure,
+ }
+ res, err := c.Get(ts.URL)
+ if (err == nil) != insecure {
+ t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
+ }
+ if res != nil {
+ res.Body.Close()
+ }
+ }
+
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+}
+
+func TestClientErrorWithRequestURI(t *testing.T) {
+ defer afterTest(t)
+ req, _ := NewRequest("GET", "http://localhost:1234/", nil)
+ req.RequestURI = "/this/field/is/illegal/and/should/error/"
+ _, err := DefaultClient.Do(req)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "RequestURI") {
+ t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
+ }
+}
+
+func TestClientWithCorrectTLSServerName(t *testing.T) {
+ run(t, testClientWithCorrectTLSServerName, []testMode{https1Mode, http2Mode})
+}
+func testClientWithCorrectTLSServerName(t *testing.T, mode testMode) {
+ const serverName = "example.com"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS.ServerName != serverName {
+ t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
+ }
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).TLSClientConfig.ServerName = serverName
+ if _, err := c.Get(ts.URL); err != nil {
+ t.Fatalf("expected successful TLS connection, got error: %v", err)
+ }
+}
+
+func TestClientWithIncorrectTLSServerName(t *testing.T) {
+ run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode})
+}
+func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+
+ c := ts.Client()
+ c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver"
+ _, err := c.Get(ts.URL)
+ if err == nil {
+ t.Fatalf("expected an error")
+ }
+ if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
+ t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
+ }
+ select {
+ case v := <-errc:
+ if !strings.Contains(v, "TLS handshake error") {
+ t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v)
+ }
+ case <-time.After(5 * time.Second):
+ t.Errorf("timeout waiting for logged error")
+ }
+}
+
+// Test for golang.org/issue/5829; the Transport should respect TLSClientConfig.ServerName
+// when not empty.
+//
+// tls.Config.ServerName (non-empty, set to "example.com") takes
+// precedence over "some-other-host.tld" which previously incorrectly
+// took precedence. We don't actually connect to (or even resolve)
+// "some-other-host.tld", though, because of the Transport.Dial hook.
+//
+// The httptest.Server has a cert with "example.com" as its name.
+func TestTransportUsesTLSConfigServerName(t *testing.T) {
+ run(t, testTransportUsesTLSConfigServerName, []testMode{https1Mode, http2Mode})
+}
+func testTransportUsesTLSConfigServerName(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+ res, err := c.Get("https://some-other-host.tld/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+func TestResponseSetsTLSConnectionState(t *testing.T) {
+ run(t, testResponseSetsTLSConnectionState, []testMode{https1Mode})
+}
+func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello"))
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA}
+ tr.TLSClientConfig.MaxVersion = tls.VersionTLS12 // to get to pick the cipher suite
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+ res, err := c.Get("https://example.com/")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.TLS == nil {
+ t.Fatal("Response didn't set TLS Connection State.")
+ }
+ if got, want := res.TLS.CipherSuite, tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA; got != want {
+ t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
+ }
+}
+
+// Check that an HTTPS client can interpret a particular TLS error
+// to determine that the server is speaking HTTP.
+// See golang.org/issue/11111.
+func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
+ run(t, testHTTPSClientDetectsHTTPServer, []testMode{http1Mode})
+}
+func testHTTPSClientDetectsHTTPServer(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+ ts.Config.ErrorLog = quietLog
+
+ _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
+ if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
+ t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
+ }
+}
+
+// Verify Response.ContentLength is populated. https://golang.org/issue/4126
+func TestClientHeadContentLength(t *testing.T) { run(t, testClientHeadContentLength) }
+func testClientHeadContentLength(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if v := r.FormValue("cl"); v != "" {
+ w.Header().Set("Content-Length", v)
+ }
+ }))
+ tests := []struct {
+ suffix string
+ want int64
+ }{
+ {"/?cl=1234", 1234},
+ {"/?cl=0", 0},
+ {"", -1},
+ }
+ for _, tt := range tests {
+ req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != tt.want {
+ t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
+ }
+ bs, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != 0 {
+ t.Errorf("Unexpected content: %q", bs)
+ }
+ }
+}
+
+func TestEmptyPasswordAuth(t *testing.T) { run(t, testEmptyPasswordAuth) }
+func testEmptyPasswordAuth(t *testing.T, mode testMode) {
+ gopher := "gopher"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ auth := r.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Basic ") {
+ encoded := auth[6:]
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expected := gopher + ":"
+ s := string(decoded)
+ if expected != s {
+ t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
+ }
+ } else {
+ t.Errorf("Invalid auth %q", auth)
+ }
+ })).ts
+ defer ts.Close()
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.URL.User = url.User(gopher)
+ c := ts.Client()
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+}
+
+func TestBasicAuth(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ url := "http://My%20User:My%20Pass@dummy.faketld/"
+ expected := "My User:My Pass"
+ client.Get(url)
+
+ if tr.req.Method != "GET" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "GET")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ auth := tr.req.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Basic ") {
+ encoded := auth[6:]
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s := string(decoded)
+ if expected != s {
+ t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
+ }
+ } else {
+ t.Errorf("Invalid auth %q", auth)
+ }
+}
+
+func TestBasicAuthHeadersPreserved(t *testing.T) {
+ defer afterTest(t)
+ tr := &recordingTransport{}
+ client := &Client{Transport: tr}
+
+ // If an Authorization header is provided, the username in the URL should not override it.
+ url := "http://My%20User@dummy.faketld/"
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.SetBasicAuth("My User", "My Pass")
+ expected := "My User:My Pass"
+ client.Do(req)
+
+ if tr.req.Method != "GET" {
+ t.Errorf("got method %q, want %q", tr.req.Method, "GET")
+ }
+ if tr.req.URL.String() != url {
+ t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
+ }
+ if tr.req.Header == nil {
+ t.Fatalf("expected non-nil request Header")
+ }
+ auth := tr.req.Header.Get("Authorization")
+ if strings.HasPrefix(auth, "Basic ") {
+ encoded := auth[6:]
+ decoded, err := base64.StdEncoding.DecodeString(encoded)
+ if err != nil {
+ t.Fatal(err)
+ }
+ s := string(decoded)
+ if expected != s {
+ t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
+ }
+ } else {
+ t.Errorf("Invalid auth %q", auth)
+ }
+}
+
+func TestStripPasswordFromError(t *testing.T) {
+ client := &Client{Transport: &recordingTransport{}}
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ }{
+ {
+ desc: "Strip password from error message",
+ in: "http://user:password@dummy.faketld/",
+ out: `Get "http://user:***@dummy.faketld/": dummy impl`,
+ },
+ {
+ desc: "Don't Strip password from domain name",
+ in: "http://user:password@password.faketld/",
+ out: `Get "http://user:***@password.faketld/": dummy impl`,
+ },
+ {
+ desc: "Don't Strip password from path",
+ in: "http://user:password@dummy.faketld/password",
+ out: `Get "http://user:***@dummy.faketld/password": dummy impl`,
+ },
+ {
+ desc: "Strip escaped password",
+ in: "http://user:pa%2Fssword@dummy.faketld/",
+ out: `Get "http://user:***@dummy.faketld/": dummy impl`,
+ },
+ }
+ for _, tC := range testCases {
+ t.Run(tC.desc, func(t *testing.T) {
+ _, err := client.Get(tC.in)
+ if err.Error() != tC.out {
+ t.Errorf("Unexpected output for %q: expected %q, actual %q",
+ tC.in, tC.out, err.Error())
+ }
+ })
+ }
+}
+
+func TestClientTimeout(t *testing.T) { run(t, testClientTimeout) }
+func testClientTimeout(t *testing.T, mode testMode) {
+ var (
+ mu sync.Mutex
+ nonce string // a unique per-request string
+ sawSlowNonce bool // true if the handler saw /slow?nonce=<nonce>
+ )
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _ = r.ParseForm()
+ if r.URL.Path == "/" {
+ Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound)
+ return
+ }
+ if r.URL.Path == "/slow" {
+ mu.Lock()
+ if r.Form.Get("nonce") == nonce {
+ sawSlowNonce = true
+ } else {
+ t.Logf("mismatched nonce: received %s, want %s", r.Form.Get("nonce"), nonce)
+ }
+ mu.Unlock()
+
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ <-r.Context().Done()
+ return
+ }
+ }))
+
+ // Try to trigger a timeout after reading part of the response body.
+ // The initial timeout is empirically usually long enough on a decently fast
+ // machine, but if we undershoot we'll retry with exponentially longer
+ // timeouts until the test either passes or times out completely.
+ // This keeps the test reasonably fast in the typical case but allows it to
+ // also eventually succeed on arbitrarily slow machines.
+ timeout := 10 * time.Millisecond
+ nextNonce := 0
+ for ; ; timeout *= 2 {
+ if timeout <= 0 {
+ // The only way we can feasibly hit this while the test is running is if
+ // the request fails without actually waiting for the timeout to occur.
+ t.Fatalf("timeout overflow")
+ }
+ if deadline, ok := t.Deadline(); ok && !time.Now().Add(timeout).Before(deadline) {
+ t.Fatalf("failed to produce expected timeout before test deadline")
+ }
+ t.Logf("attempting test with timeout %v", timeout)
+ cst.c.Timeout = timeout
+
+ mu.Lock()
+ nonce = fmt.Sprint(nextNonce)
+ nextNonce++
+ sawSlowNonce = false
+ mu.Unlock()
+ res, err := cst.c.Get(cst.ts.URL + "/?nonce=" + nonce)
+ if err != nil {
+ if strings.Contains(err.Error(), "Client.Timeout") {
+ // Timed out before handler could respond.
+ t.Logf("timeout before response received")
+ continue
+ }
+ if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 43120)
+ }
+ t.Fatal(err)
+ }
+
+ mu.Lock()
+ ok := sawSlowNonce
+ mu.Unlock()
+ if !ok {
+ t.Fatal("handler never got /slow request, but client returned response")
+ }
+
+ _, err = io.ReadAll(res.Body)
+ res.Body.Close()
+
+ if err == nil {
+ t.Fatal("expected error from ReadAll")
+ }
+ ne, ok := err.(net.Error)
+ if !ok {
+ t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
+ } else if !ne.Timeout() {
+ t.Errorf("net.Error.Timeout = false; want true")
+ }
+ if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") {
+ if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 43120)
+ }
+ t.Errorf("error string = %q; missing timeout substring", got)
+ }
+
+ break
+ }
+}
+
+// Client.Timeout firing before getting to the body
+func TestClientTimeout_Headers(t *testing.T) { run(t, testClientTimeout_Headers) }
+func testClientTimeout_Headers(t *testing.T, mode testMode) {
+ donec := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-donec
+ }), optQuietLog)
+ // Note that we use a channel send here and not a close.
+ // The race detector doesn't know that we're waiting for a timeout
+ // and thinks that the waitgroup inside httptest.Server is added to concurrently
+ // with us closing it. If we timed out immediately, we could close the testserver
+ // before we entered the handler. We're not timing out immediately and there's
+ // no way we would be done before we entered the handler, but the race detector
+ // doesn't know this, so synchronize explicitly.
+ defer func() { donec <- true }()
+
+ cst.c.Timeout = 5 * time.Millisecond
+ res, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("got response from Get; expected error")
+ }
+ if _, ok := err.(*url.Error); !ok {
+ t.Fatalf("Got error of type %T; want *url.Error", err)
+ }
+ ne, ok := err.(net.Error)
+ if !ok {
+ t.Fatalf("Got error of type %T; want some net.Error", err)
+ }
+ if !ne.Timeout() {
+ t.Error("net.Error.Timeout = false; want true")
+ }
+ if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
+ if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 43120)
+ }
+ t.Errorf("error string = %q; missing timeout substring", got)
+ }
+}
+
+// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be
+// returned.
+func TestClientTimeoutCancel(t *testing.T) { run(t, testClientTimeoutCancel) }
+func testClientTimeoutCancel(t *testing.T, mode testMode) {
+ testDone := make(chan struct{})
+ ctx, cancel := context.WithCancel(context.Background())
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ <-testDone
+ }))
+ defer close(testDone)
+
+ cst.c.Timeout = 1 * time.Hour
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Cancel = ctx.Done()
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cancel()
+ _, err = io.Copy(io.Discard, res.Body)
+ if err != ExportErrRequestCanceled {
+ t.Fatalf("error = %v; want errRequestCanceled", err)
+ }
+}
+
+// Issue 49366: if Client.Timeout is set but not hit, no error should be returned.
+func TestClientTimeoutDoesNotExpire(t *testing.T) { run(t, testClientTimeoutDoesNotExpire) }
+func testClientTimeoutDoesNotExpire(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("body"))
+ }))
+
+ cst.c.Timeout = 1 * time.Hour
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err = io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err)
+ }
+ if err = res.Body.Close(); err != nil {
+ t.Fatalf("res.Body.Close() = %v, want nil", err)
+ }
+}
+
+func TestClientRedirectEatsBody_h1(t *testing.T) { run(t, testClientRedirectEatsBody) }
+func testClientRedirectEatsBody(t *testing.T, mode testMode) {
+ saw := make(chan string, 2)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ saw <- r.RemoteAddr
+ if r.URL.Path == "/" {
+ Redirect(w, r, "/foo", StatusFound) // which includes a body
+ }
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var first string
+ select {
+ case first = <-saw:
+ default:
+ t.Fatal("server didn't see a request")
+ }
+
+ var second string
+ select {
+ case second = <-saw:
+ default:
+ t.Fatal("server didn't see a second request")
+ }
+
+ if first != second {
+ t.Fatal("server saw different client ports before & after the redirect")
+ }
+}
+
+// eofReaderFunc is an io.Reader that calls itself on each Read and then returns io.EOF.
+type eofReaderFunc func()
+
+func (f eofReaderFunc) Read(p []byte) (n int, err error) {
+ f()
+ return 0, io.EOF
+}
+
+func TestReferer(t *testing.T) {
+ tests := []struct {
+ lastReq, newReq, explicitRef string // from -> to URLs, explicitly set Referer value
+ want string
+ }{
+ // don't send user:
+ {lastReq: "http://gopher@test.com", newReq: "http://link.com", want: "http://test.com"},
+ {lastReq: "https://gopher@test.com", newReq: "https://link.com", want: "https://test.com"},
+
+ // don't send a user and password:
+ {lastReq: "http://gopher:go@test.com", newReq: "http://link.com", want: "http://test.com"},
+ {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", want: "https://test.com"},
+
+ // nothing to do:
+ {lastReq: "http://test.com", newReq: "http://link.com", want: "http://test.com"},
+ {lastReq: "https://test.com", newReq: "https://link.com", want: "https://test.com"},
+
+ // https to http doesn't send a referer:
+ {lastReq: "https://test.com", newReq: "http://link.com", want: ""},
+ {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", want: ""},
+
+ // https to http should remove an existing referer:
+ {lastReq: "https://test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""},
+ {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""},
+
+ // don't override an existing referer:
+ {lastReq: "https://test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"},
+ {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"},
+ }
+ for _, tt := range tests {
+ l, err := url.Parse(tt.lastReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ n, err := url.Parse(tt.newReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r := ExportRefererForURL(l, n, tt.explicitRef)
+ if r != tt.want {
+ t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
+ }
+ }
+}
+
+// issue15577Tripper returns a Response with a redirect response
+// header and doesn't populate its Response.Request field.
+type issue15577Tripper struct{}
+
+func (issue15577Tripper) RoundTrip(*Request) (*Response, error) {
+ resp := &Response{
+ StatusCode: 303,
+ Header: map[string][]string{"Location": {"http://www.example.com/"}},
+ Body: io.NopCloser(strings.NewReader("")),
+ }
+ return resp, nil
+}
+
+// Issue 15577: don't assume the roundtripper's response populates its Request field.
+func TestClientRedirectResponseWithoutRequest(t *testing.T) {
+ c := &Client{
+ CheckRedirect: func(*Request, []*Request) error { return fmt.Errorf("no redirects!") },
+ Transport: issue15577Tripper{},
+ }
+ // Check that this doesn't crash:
+ c.Get("http://dummy.tld")
+}
+
+// Issue 4800: copy (some) headers when Client follows a redirect.
+// Issue 35104: Since both URLs have the same host (localhost)
+// but different ports, sensitive headers like Cookie and Authorization
+// are preserved.
+func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) }
+func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) {
+ const (
+ ua = "some-agent/1.2"
+ xfoo = "foo-val"
+ )
+ var ts2URL string
+ ts1 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ want := Header{
+ "User-Agent": []string{ua},
+ "X-Foo": []string{xfoo},
+ "Referer": []string{ts2URL},
+ "Accept-Encoding": []string{"gzip"},
+ "Cookie": []string{"foo=bar"},
+ "Authorization": []string{"secretpassword"},
+ }
+ if !reflect.DeepEqual(r.Header, want) {
+ t.Errorf("Request.Header = %#v; want %#v", r.Header, want)
+ }
+ if t.Failed() {
+ w.Header().Set("Result", "got errors")
+ } else {
+ w.Header().Set("Result", "ok")
+ }
+ })).ts
+ ts2 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Redirect(w, r, ts1.URL, StatusFound)
+ })).ts
+ ts2URL = ts2.URL
+
+ c := ts1.Client()
+ c.CheckRedirect = func(r *Request, via []*Request) error {
+ want := Header{
+ "User-Agent": []string{ua},
+ "X-Foo": []string{xfoo},
+ "Referer": []string{ts2URL},
+ "Cookie": []string{"foo=bar"},
+ "Authorization": []string{"secretpassword"},
+ }
+ if !reflect.DeepEqual(r.Header, want) {
+ t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
+ }
+ return nil
+ }
+
+ req, _ := NewRequest("GET", ts2.URL, nil)
+ req.Header.Add("User-Agent", ua)
+ req.Header.Add("X-Foo", xfoo)
+ req.Header.Add("Cookie", "foo=bar")
+ req.Header.Add("Authorization", "secretpassword")
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ if got := res.Header.Get("Result"); got != "ok" {
+ t.Errorf("result = %q; want ok", got)
+ }
+}
+
+// Issue 22233: copy host when Client follows a relative redirect.
+func TestClientCopyHostOnRedirect(t *testing.T) { run(t, testClientCopyHostOnRedirect) }
+func testClientCopyHostOnRedirect(t *testing.T, mode testMode) {
+ // Virtual hostname: should not receive any request.
+ virtual := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Errorf("Virtual host received request %v", r.URL)
+ w.WriteHeader(403)
+ io.WriteString(w, "should not see this response")
+ })).ts
+ defer virtual.Close()
+ virtualHost := strings.TrimPrefix(virtual.URL, "http://")
+ virtualHost = strings.TrimPrefix(virtualHost, "https://")
+ t.Logf("Virtual host is %v", virtualHost)
+
+ // Actual host: serves the redirect chain and the final request.
+ const wantBody = "response body"
+ var tsURL string
+ var tsHost string
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.URL.Path {
+ case "/":
+ // Relative redirect.
+ if r.Host != virtualHost {
+ t.Errorf("Serving /: Request.Host = %#v; want %#v", r.Host, virtualHost)
+ w.WriteHeader(404)
+ return
+ }
+ w.Header().Set("Location", "/hop")
+ w.WriteHeader(302)
+ case "/hop":
+ // Absolute redirect.
+ if r.Host != virtualHost {
+ t.Errorf("Serving /hop: Request.Host = %#v; want %#v", r.Host, virtualHost)
+ w.WriteHeader(404)
+ return
+ }
+ w.Header().Set("Location", tsURL+"/final")
+ w.WriteHeader(302)
+ case "/final":
+ if r.Host != tsHost {
+ t.Errorf("Serving /final: Request.Host = %#v; want %#v", r.Host, tsHost)
+ w.WriteHeader(404)
+ return
+ }
+ w.WriteHeader(200)
+ io.WriteString(w, wantBody)
+ default:
+ t.Errorf("Serving unexpected path %q", r.URL.Path)
+ w.WriteHeader(404)
+ }
+ })).ts
+ tsURL = ts.URL
+ tsHost = strings.TrimPrefix(ts.URL, "http://")
+ tsHost = strings.TrimPrefix(tsHost, "https://")
+ t.Logf("Server host is %v", tsHost)
+
+ c := ts.Client()
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Host = virtualHost
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatal(resp.Status)
+ }
+ if got, err := io.ReadAll(resp.Body); err != nil || string(got) != wantBody {
+ t.Errorf("body = %q; want %q", got, wantBody)
+ }
+}
+
+// Issue 17494: cookies should be altered when Client follows redirects.
+func TestClientAltersCookiesOnRedirect(t *testing.T) { run(t, testClientAltersCookiesOnRedirect) }
+func testClientAltersCookiesOnRedirect(t *testing.T, mode testMode) {
+ cookieMap := func(cs []*Cookie) map[string][]string {
+ m := make(map[string][]string)
+ for _, c := range cs {
+ m[c.Name] = append(m[c.Name], c.Value)
+ }
+ return m
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var want map[string][]string
+ got := cookieMap(r.Cookies())
+
+ c, _ := r.Cookie("Cycle")
+ switch c.Value {
+ case "0":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie2": {"OldValue2"},
+ "Cookie3": {"OldValue3a", "OldValue3b"},
+ "Cookie4": {"OldValue4"},
+ "Cycle": {"0"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header
+ Redirect(w, r, "/", StatusFound)
+ case "1":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"OldValue3a", "OldValue3b"},
+ "Cookie4": {"OldValue4"},
+ "Cycle": {"1"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header
+ SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar
+ Redirect(w, r, "/", StatusFound)
+ case "2":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"NewValue3"},
+ "Cookie4": {"NewValue4"},
+ "Cycle": {"2"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar
+ Redirect(w, r, "/", StatusFound)
+ case "3":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"NewValue3"},
+ "Cookie4": {"NewValue4"},
+ "Cookie5": {"NewValue5"},
+ "Cycle": {"3"},
+ }
+ // Don't redirect to ensure the loop ends.
+ default:
+ t.Errorf("unexpected redirect cycle")
+ return
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want)
+ }
+ })).ts
+
+ jar, _ := cookiejar.New(nil)
+ c := ts.Client()
+ c.Jar = jar
+
+ u, _ := url.Parse(ts.URL)
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"})
+ req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"})
+ req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"})
+ req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"})
+ req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"})
+ jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}})
+ jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}})
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+}
+
+// Part of Issue 4800
+func TestShouldCopyHeaderOnRedirect(t *testing.T) {
+ tests := []struct {
+ header string
+ initialURL string
+ destURL string
+ want bool
+ }{
+ {"User-Agent", "http://foo.com/", "http://bar.com/", true},
+ {"X-Foo", "http://foo.com/", "http://bar.com/", true},
+
+ // Sensitive headers:
+ {"cookie", "http://foo.com/", "http://bar.com/", false},
+ {"cookie2", "http://foo.com/", "http://bar.com/", false},
+ {"authorization", "http://foo.com/", "http://bar.com/", false},
+ {"authorization", "http://foo.com/", "https://foo.com/", true},
+ {"authorization", "http://foo.com:1234/", "http://foo.com:4321/", true},
+ {"www-authenticate", "http://foo.com/", "http://bar.com/", false},
+ {"authorization", "http://foo.com/", "http://[::1%25.foo.com]/", false},
+
+ // But subdomains should work:
+ {"www-authenticate", "http://foo.com/", "http://foo.com/", true},
+ {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true},
+ {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false},
+ {"www-authenticate", "http://foo.com/", "https://foo.com/", true},
+ {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true},
+ {"www-authenticate", "http://foo.com:80/", "http://sub.foo.com/", true},
+ {"www-authenticate", "http://foo.com:443/", "https://foo.com/", true},
+ {"www-authenticate", "http://foo.com:443/", "https://sub.foo.com/", true},
+ {"www-authenticate", "http://foo.com:1234/", "http://foo.com/", true},
+
+ {"authorization", "http://foo.com/", "http://foo.com/", true},
+ {"authorization", "http://foo.com/", "http://sub.foo.com/", true},
+ {"authorization", "http://foo.com/", "http://notfoo.com/", false},
+ {"authorization", "http://foo.com/", "https://foo.com/", true},
+ {"authorization", "http://foo.com:80/", "http://foo.com/", true},
+ {"authorization", "http://foo.com:80/", "http://sub.foo.com/", true},
+ {"authorization", "http://foo.com:443/", "https://foo.com/", true},
+ {"authorization", "http://foo.com:443/", "https://sub.foo.com/", true},
+ {"authorization", "http://foo.com:1234/", "http://foo.com/", true},
+ }
+ for i, tt := range tests {
+ u0, err := url.Parse(tt.initialURL)
+ if err != nil {
+ t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err)
+ continue
+ }
+ u1, err := url.Parse(tt.destURL)
+ if err != nil {
+ t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err)
+ continue
+ }
+ got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1)
+ if got != tt.want {
+ t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v",
+ i, tt.header, tt.initialURL, tt.destURL, got, tt.want)
+ }
+ }
+}
+
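+// redirectDomainMatch is a rough sketch of the host comparison the table
+// above exercises: sensitive headers such as Authorization and Cookie are
+// forwarded only when the redirect target host equals the original host or
+// is a subdomain of it. Ports and schemes are ignored; only hostnames are
+// compared. It mirrors the behavior under test rather than calling the
+// unexported function itself.
+func redirectDomainMatch(destHost, initialHost string) bool {
+ if destHost == initialHost {
+ return true
+ }
+ // An IPv6 literal (contains ':') or a host carrying a zone ('%') is never
+ // treated as a subdomain; see the [::1%25.foo.com] row above.
+ if strings.ContainsAny(destHost, ":%") {
+ return false
+ }
+ // "sub.foo.com" is a subdomain of "foo.com"; "notfoo.com" is not.
+ return strings.HasSuffix(destHost, "."+initialHost)
+}
+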
+func TestClientRedirectTypes(t *testing.T) { run(t, testClientRedirectTypes) }
+func testClientRedirectTypes(t *testing.T, mode testMode) {
+ tests := [...]struct {
+ method string
+ serverStatus int
+ wantMethod string // desired subsequent client method
+ }{
+ 0: {method: "POST", serverStatus: 301, wantMethod: "GET"},
+ 1: {method: "POST", serverStatus: 302, wantMethod: "GET"},
+ 2: {method: "POST", serverStatus: 303, wantMethod: "GET"},
+ 3: {method: "POST", serverStatus: 307, wantMethod: "POST"},
+ 4: {method: "POST", serverStatus: 308, wantMethod: "POST"},
+
+ 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"},
+ 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"},
+ 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"},
+ 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"},
+ 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"},
+
+ 10: {method: "GET", serverStatus: 301, wantMethod: "GET"},
+ 11: {method: "GET", serverStatus: 302, wantMethod: "GET"},
+ 12: {method: "GET", serverStatus: 303, wantMethod: "GET"},
+ 13: {method: "GET", serverStatus: 307, wantMethod: "GET"},
+ 14: {method: "GET", serverStatus: 308, wantMethod: "GET"},
+
+ 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"},
+ 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"},
+ 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"},
+ 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"},
+ 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"},
+
+ 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"},
+ 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"},
+ 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"},
+ 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"},
+ 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"},
+
+ 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"},
+ 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"},
+ 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"},
+ 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"},
+ 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"},
+ }
+
+ handlerc := make(chan HandlerFunc, 1)
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ h := <-handlerc
+ h(rw, req)
+ })).ts
+
+ c := ts.Client()
+ for i, tt := range tests {
+ handlerc <- func(w ResponseWriter, r *Request) {
+ w.Header().Set("Location", ts.URL)
+ w.WriteHeader(tt.serverStatus)
+ }
+
+ req, err := NewRequest(tt.method, ts.URL, nil)
+ if err != nil {
+ t.Errorf("#%d: NewRequest: %v", i, err)
+ continue
+ }
+
+ c.CheckRedirect = func(req *Request, via []*Request) error {
+ if got, want := req.Method, tt.wantMethod; got != want {
+ return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
+ }
+ handlerc <- func(rw ResponseWriter, req *Request) {
+ // TODO: Check that the body is valid when we add 307 and 308 support.
+ }
+ return nil
+ }
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Errorf("#%d: Response: %v", i, err)
+ continue
+ }
+
+ res.Body.Close()
+ }
+}
+
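+// methodAfterRedirect spells out the rule the table above encodes: a 301,
+// 302, or 303 response rewrites any method other than GET and HEAD to GET,
+// while 307 and 308 preserve the original method. (Sketch only; the client's
+// real redirect handling also decides whether the request body can be
+// resent.)
+func methodAfterRedirect(method string, status int) string {
+ switch status {
+ case 301, 302, 303:
+ if method != "GET" && method != "HEAD" {
+ return "GET"
+ }
+ }
+ return method
+}
+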
+// issue18239Body is an io.ReadCloser for TestTransportBodyReadError.
+// Its Read returns readErr and increments *readCalls atomically.
+// Its Close returns nil and increments *closeCalls atomically.
+type issue18239Body struct {
+ readCalls *int32
+ closeCalls *int32
+ readErr error
+}
+
+func (b issue18239Body) Read([]byte) (int, error) {
+ atomic.AddInt32(b.readCalls, 1)
+ return 0, b.readErr
+}
+
+func (b issue18239Body) Close() error {
+ atomic.AddInt32(b.closeCalls, 1)
+ return nil
+}
+
+// Issue 18239: make sure the Transport doesn't retry requests with bodies
+// if Request.GetBody is not defined.
+func TestTransportBodyReadError(t *testing.T) { run(t, testTransportBodyReadError) }
+func testTransportBodyReadError(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/ping" {
+ return
+ }
+ buf := make([]byte, 1)
+ n, err := r.Body.Read(buf)
+ w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err))
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ // Do one initial successful request to create an idle TCP connection
+ // for the subsequent request to reuse. (The Transport only retries
+ // requests on reused connections.)
+ res, err := c.Get(ts.URL + "/ping")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+
+ var readCallsAtomic int32
+ var closeCallsAtomic int32 // atomic
+ someErr := errors.New("some body read error")
+ body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr}
+
+ req, err := NewRequest("POST", ts.URL, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req = req.WithT(t)
+ _, err = tr.RoundTrip(req)
+ if err != someErr {
+ t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
+ }
+
+ // And verify that our Body wasn't used multiple times, which would
+ // indicate retries (as it buggily was during part of Go 1.8's dev cycle).
+ readCalls := atomic.LoadInt32(&readCallsAtomic)
+ closeCalls := atomic.LoadInt32(&closeCallsAtomic)
+ if readCalls != 1 {
+ t.Errorf("read calls = %d; want 1", readCalls)
+ }
+ if closeCalls != 1 {
+ t.Errorf("close calls = %d; want 1", closeCalls)
+ }
+}
+
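+// In contrast to the test above, a request whose body can be replayed may be
+// retried on a reused connection when Request.GetBody is set. NewRequest sets
+// it automatically for *strings.Reader, *bytes.Reader, and *bytes.Buffer
+// bodies; the sketch below sets it by hand only to make the mechanism
+// explicit (the helper name is illustrative and not used elsewhere).
+func newReplayableRequest(urlStr, payload string) (*Request, error) {
+ req, err := NewRequest("POST", urlStr, strings.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.GetBody = func() (io.ReadCloser, error) {
+ // Return a fresh reader over the same payload for each retry attempt.
+ return io.NopCloser(strings.NewReader(payload)), nil
+ }
+ return req, nil
+}
+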
+type roundTripperWithoutCloseIdle struct{}
+
+func (roundTripperWithoutCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") }
+
+type roundTripperWithCloseIdle func() // the underlying func is used as CloseIdleConnections
+
+func (roundTripperWithCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") }
+func (f roundTripperWithCloseIdle) CloseIdleConnections() { f() }
+
+func TestClientCloseIdleConnections(t *testing.T) {
+ c := &Client{Transport: roundTripperWithoutCloseIdle{}}
+ c.CloseIdleConnections() // verify we don't crash at least
+
+ closed := false
+ var tr RoundTripper = roundTripperWithCloseIdle(func() {
+ closed = true
+ })
+ c = &Client{Transport: tr}
+ c.CloseIdleConnections()
+ if !closed {
+ t.Error("not closed")
+ }
+}
+
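+// Client.CloseIdleConnections relies on the optional-interface pattern
+// sketched below; the two RoundTripper types above exercise the branches
+// with and without the method.
+func closeIdleIfSupported(rt RoundTripper) {
+ type closeIdler interface {
+ CloseIdleConnections()
+ }
+ if tr, ok := rt.(closeIdler); ok {
+ tr.CloseIdleConnections()
+ }
+}
+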
+func TestClientPropagatesTimeoutToContext(t *testing.T) {
+ errDial := errors.New("not actually dialing")
+ c := &Client{
+ Timeout: 5 * time.Second,
+ Transport: &Transport{
+ DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
+ deadline, ok := ctx.Deadline()
+ if !ok {
+ t.Error("no deadline")
+ } else {
+ t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
+ }
+ return nil, errDial
+ },
+ },
+ }
+ c.Get("https://example.tld/")
+}
+
+// Issue 33545: lock in the behavior promised by Client.Do's
+// docs about request cancellation vs timing out.
+func TestClientDoCanceledVsTimeout(t *testing.T) { run(t, testClientDoCanceledVsTimeout) }
+func testClientDoCanceledVsTimeout(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello, World!"))
+ }))
+
+ cases := []string{"timeout", "canceled"}
+
+ for _, name := range cases {
+ t.Run(name, func(t *testing.T) {
+ var ctx context.Context
+ var cancel func()
+ if name == "timeout" {
+ ctx, cancel = context.WithTimeout(context.Background(), -time.Nanosecond)
+ } else {
+ ctx, cancel = context.WithCancel(context.Background())
+ cancel()
+ }
+ defer cancel()
+
+ req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ _, err := cst.c.Do(req)
+ if err == nil {
+ t.Fatal("Unexpectedly got a nil error")
+ }
+
+ ue := err.(*url.Error)
+
+ var wantIsTimeout bool
+ var wantErr error = context.Canceled
+ if name == "timeout" {
+ wantErr = context.DeadlineExceeded
+ wantIsTimeout = true
+ }
+ if g, w := ue.Timeout(), wantIsTimeout; g != w {
+ t.Fatalf("url.Timeout() = %t, want %t", g, w)
+ }
+ if g, w := ue.Err, wantErr; g != w {
+ t.Errorf("url.Error.Err = %v; want %v", g, w)
+ }
+ })
+ }
+}
+
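+// The distinction locked in above lets callers classify a Do failure with
+// errors.Is instead of unwrapping the *url.Error by hand; a minimal sketch:
+func classifyDoError(err error) string {
+ switch {
+ case errors.Is(err, context.DeadlineExceeded):
+ return "timeout"
+ case errors.Is(err, context.Canceled):
+ return "canceled"
+ default:
+ return "other"
+ }
+}
+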
+type nilBodyRoundTripper struct{}
+
+func (nilBodyRoundTripper) RoundTrip(req *Request) (*Response, error) {
+ return &Response{
+ StatusCode: StatusOK,
+ Status: StatusText(StatusOK),
+ Body: nil,
+ Request: req,
+ }, nil
+}
+
+func TestClientPopulatesNilResponseBody(t *testing.T) {
+ c := &Client{Transport: nilBodyRoundTripper{}}
+
+ resp, err := c.Get("http://localhost/anything")
+ if err != nil {
+ t.Fatalf("Client.Get rejected Response with nil Body: %v", err)
+ }
+
+ if resp.Body == nil {
+ t.Fatalf("Client failed to provide a non-nil Body as documented")
+ }
+ defer func() {
+ if err := resp.Body.Close(); err != nil {
+ t.Fatalf("error from Close on substitute Response.Body: %v", err)
+ }
+ }()
+
+ if b, err := io.ReadAll(resp.Body); err != nil {
+ t.Errorf("read error from substitute Response.Body: %v", err)
+ } else if len(b) != 0 {
+ t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b)
+ }
+}
+
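+// noBodyRoundTripper shows the alternative to relying on the Client's
+// substitution: a RoundTripper can return the NoBody sentinel itself when it
+// has nothing to send. (Illustrative counterpart to nilBodyRoundTripper
+// above; not used by any test.)
+type noBodyRoundTripper struct{}
+
+func (noBodyRoundTripper) RoundTrip(req *Request) (*Response, error) {
+ return &Response{
+ StatusCode: StatusOK,
+ Status: StatusText(StatusOK),
+ Body: NoBody,
+ Request: req,
+ }, nil
+}
+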
+// Issue 40382: Client calls Close multiple times on Request.Body.
+func TestClientCallsCloseOnlyOnce(t *testing.T) { run(t, testClientCallsCloseOnlyOnce) }
+func testClientCallsCloseOnlyOnce(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ }))
+
+ // Issue occurred non-deterministically: needed to occur after a successful
+ // write (into TCP buffer) but before end of body.
+ for i := 0; i < 50 && !t.Failed(); i++ {
+ body := &issue40382Body{t: t, n: 300000}
+ req, err := NewRequest(MethodPost, cst.ts.URL, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ }
+}
+
+// issue40382Body is an io.ReadCloser for TestClientCallsCloseOnlyOnce.
+// Its Read reads n bytes before returning io.EOF.
+// Its Close returns nil but fails the test if called more than once.
+type issue40382Body struct {
+ t *testing.T
+ n int
+ closeCallsAtomic int32
+}
+
+func (b *issue40382Body) Read(p []byte) (int, error) {
+ switch {
+ case b.n == 0:
+ return 0, io.EOF
+ case b.n < len(p):
+ p = p[:b.n]
+ fallthrough
+ default:
+ for i := range p {
+ p[i] = 'x'
+ }
+ b.n -= len(p)
+ return len(p), nil
+ }
+}
+
+func (b *issue40382Body) Close() error {
+ if atomic.AddInt32(&b.closeCallsAtomic, 1) == 2 {
+ b.t.Error("Body closed more than once")
+ }
+ return nil
+}
+
+func TestProbeZeroLengthBody(t *testing.T) { run(t, testProbeZeroLengthBody) }
+func testProbeZeroLengthBody(t *testing.T, mode testMode) {
+ reqc := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(reqc)
+ if _, err := io.Copy(w, r.Body); err != nil {
+ t.Errorf("error copying request body: %v", err)
+ }
+ }))
+
+ bodyr, bodyw := io.Pipe()
+ var gotBody string
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, _ := NewRequest("GET", cst.ts.URL, bodyr)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody = string(b)
+ }()
+
+ select {
+ case <-reqc:
+ // Request should be sent after trying to probe the request body for 200ms.
+ case <-time.After(60 * time.Second):
+ t.Errorf("request not sent after 60s")
+ }
+
+ // Write the request body and wait for the request to complete.
+ const content = "body"
+ bodyw.Write([]byte(content))
+ bodyw.Close()
+ wg.Wait()
+ if gotBody != content {
+ t.Fatalf("server got body %q, want %q", gotBody, content)
+ }
+}
diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go
new file mode 100644
index 0000000..5832153
--- /dev/null
+++ b/src/net/http/clientserver_test.go
@@ -0,0 +1,1760 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
+
+package http_test
+
+import (
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/rand"
+ "crypto/sha1"
+ "crypto/tls"
+ "fmt"
+ "hash"
+ "io"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/httputil"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+type testMode string
+
+const (
+ http1Mode = testMode("h1") // HTTP/1.1
+ https1Mode = testMode("https1") // HTTPS/1.1
+ http2Mode = testMode("h2") // HTTP/2
+)
+
+type testNotParallelOpt struct{}
+
+var (
+ testNotParallel = testNotParallelOpt{}
+)
+
+type TBRun[T any] interface {
+ testing.TB
+ Run(string, func(T)) bool
+}
+
+// run runs a client/server test in a variety of test configurations.
+//
+// Tests execute in HTTP/1.1 and HTTP/2 modes by default.
+// To run in a different set of configurations, pass a []testMode option.
+//
+// Tests call t.Parallel() by default.
+// To disable parallel execution, pass the testNotParallel option.
+func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
+ t.Helper()
+ modes := []testMode{http1Mode, http2Mode}
+ parallel := true
+ for _, opt := range opts {
+ switch opt := opt.(type) {
+ case []testMode:
+ modes = opt
+ case testNotParallelOpt:
+ parallel = false
+ default:
+ t.Fatalf("unknown option type %T", opt)
+ }
+ }
+ if t, ok := any(t).(*testing.T); ok && parallel {
+ setParallel(t)
+ }
+ for _, mode := range modes {
+ t.Run(string(mode), func(t T) {
+ t.Helper()
+ if t, ok := any(t).(*testing.T); ok && parallel {
+ setParallel(t)
+ }
+ t.Cleanup(func() {
+ afterTest(t)
+ })
+ f(t, mode)
+ })
+ }
+}
+
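+// A typical caller pairs an exported Test function with an unexported body
+// that takes a mode, optionally narrowing the modes or disabling parallelism,
+// e.g. (names are illustrative):
+//
+//	func TestSomething(t *testing.T) { run(t, testSomething, []testMode{http1Mode}, testNotParallel) }
+//	func testSomething(t *testing.T, mode testMode) { /* body */ }
+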
+type clientServerTest struct {
+ t testing.TB
+ h2 bool
+ h Handler
+ ts *httptest.Server
+ tr *Transport
+ c *Client
+}
+
+func (t *clientServerTest) close() {
+ t.tr.CloseIdleConnections()
+ t.ts.Close()
+}
+
+func (t *clientServerTest) getURL(u string) string {
+ res, err := t.c.Get(u)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ return string(slurp)
+}
+
+func (t *clientServerTest) scheme() string {
+ if t.h2 {
+ return "https"
+ }
+ return "http"
+}
+
+var optQuietLog = func(ts *httptest.Server) {
+ ts.Config.ErrorLog = quietLog
+}
+
+func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
+ return func(ts *httptest.Server) {
+ ts.Config.ErrorLog = lg
+ }
+}
+
+// newClientServerTest creates and starts an httptest.Server.
+//
+// The mode parameter selects the implementation to test:
+// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
+// the 'run' function, which will start a subtest for each tested mode.
+//
+// The vararg opts parameter can include functions to configure the
+// test server or transport.
+//
+// func(*httptest.Server) // run before starting the server
+// func(*http.Transport)
+func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
+ if mode == http2Mode {
+ CondSkipHTTP2(t)
+ }
+ cst := &clientServerTest{
+ t: t,
+ h2: mode == http2Mode,
+ h: h,
+ }
+ cst.ts = httptest.NewUnstartedServer(h)
+
+ var transportFuncs []func(*Transport)
+ for _, opt := range opts {
+ switch opt := opt.(type) {
+ case func(*Transport):
+ transportFuncs = append(transportFuncs, opt)
+ case func(*httptest.Server):
+ opt(cst.ts)
+ default:
+ t.Fatalf("unhandled option type %T", opt)
+ }
+ }
+
+ if cst.ts.Config.ErrorLog == nil {
+ cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
+ }
+
+ switch mode {
+ case http1Mode:
+ cst.ts.Start()
+ case https1Mode:
+ cst.ts.StartTLS()
+ case http2Mode:
+ ExportHttp2ConfigureServer(cst.ts.Config, nil)
+ cst.ts.TLS = cst.ts.Config.TLSConfig
+ cst.ts.StartTLS()
+ default:
+ t.Fatalf("unknown test mode %v", mode)
+ }
+ cst.c = cst.ts.Client()
+ cst.tr = cst.c.Transport.(*Transport)
+ if mode == http2Mode {
+ if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
+ t.Fatal(err)
+ }
+ }
+ for _, f := range transportFuncs {
+ f(cst.tr)
+ }
+ t.Cleanup(func() {
+ cst.close()
+ })
+ return cst
+}
+
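+// For example, a test that needs a quiet error log and a tweaked Transport
+// can pass both option forms in one call (the handler name is illustrative):
+//
+//	cst := newClientServerTest(t, mode, handler,
+//		optQuietLog,
+//		func(tr *Transport) { tr.DisableKeepAlives = true },
+//	)
+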
+type testLogWriter struct {
+ t testing.TB
+}
+
+func (w testLogWriter) Write(b []byte) (int, error) {
+ w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
+ return len(b), nil
+}
+
+// Testing the newClientServerTest helper itself.
+func TestNewClientServerTest(t *testing.T) {
+ run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testNewClientServerTest(t *testing.T, mode testMode) {
+ var got struct {
+ sync.Mutex
+ proto string
+ hasTLS bool
+ }
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ got.Lock()
+ defer got.Unlock()
+ got.proto = r.Proto
+ got.hasTLS = r.TLS != nil
+ })
+ cst := newClientServerTest(t, mode, h)
+ if _, err := cst.c.Head(cst.ts.URL); err != nil {
+ t.Fatal(err)
+ }
+ var wantProto string
+ var wantTLS bool
+ switch mode {
+ case http1Mode:
+ wantProto = "HTTP/1.1"
+ wantTLS = false
+ case https1Mode:
+ wantProto = "HTTP/1.1"
+ wantTLS = true
+ case http2Mode:
+ wantProto = "HTTP/2.0"
+ wantTLS = true
+ }
+ if got.proto != wantProto {
+ t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
+ }
+ if got.hasTLS != wantTLS {
+ t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
+ }
+}
+
+func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
+func testChunkedResponseHeaders(t *testing.T, mode testMode) {
+ log.SetOutput(io.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
+ fmt.Fprintf(w, "I am a chunked response.")
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ defer res.Body.Close()
+ if g, e := res.ContentLength, int64(-1); g != e {
+ t.Errorf("expected ContentLength of %d; got %d", e, g)
+ }
+ wantTE := []string{"chunked"}
+ if mode == http2Mode {
+ wantTE = nil
+ }
+ if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
+ t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
+ }
+ if got, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length: %q", got)
+ }
+}
+
+type reqFunc func(c *Client, url string) (*Response, error)
+
+// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
+// against each other.
+type h12Compare struct {
+ Handler func(ResponseWriter, *Request) // required
+ ReqFunc reqFunc // optional
+ CheckResponse func(proto string, res *Response) // optional
+ EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
+ Opts []any
+}
+
+func (tt h12Compare) reqFunc() reqFunc {
+ if tt.ReqFunc == nil {
+ return (*Client).Get
+ }
+ return tt.ReqFunc
+}
+
+func (tt h12Compare) run(t *testing.T) {
+ setParallel(t)
+ cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst1.close()
+ cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst2.close()
+
+ res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/1 request: %v", err)
+ return
+ }
+ res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/2 request: %v", err)
+ return
+ }
+
+ if fn := tt.EarlyCheckResponse; fn != nil {
+ fn("HTTP/1.1", res1)
+ fn("HTTP/2.0", res2)
+ }
+
+ tt.normalizeRes(t, res1, "HTTP/1.1")
+ tt.normalizeRes(t, res2, "HTTP/2.0")
+ res1body, res2body := res1.Body, res2.Body
+
+ eres1 := mostlyCopy(res1)
+ eres2 := mostlyCopy(res2)
+ if !reflect.DeepEqual(eres1, eres2) {
+ t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
+ cst1.ts.URL, eres1, cst2.ts.URL, eres2)
+ }
+ if !reflect.DeepEqual(res1body, res2body) {
+ t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
+ }
+ if fn := tt.CheckResponse; fn != nil {
+ res1.Body, res2.Body = res1body, res2body
+ fn("HTTP/1.1", res1)
+ fn("HTTP/2.0", res2)
+ }
+}
+
+func mostlyCopy(r *Response) *Response {
+ c := *r
+ c.Body = nil
+ c.TransferEncoding = nil
+ c.TLS = nil
+ c.Request = nil
+ return &c
+}
+
+type slurpResult struct {
+ io.ReadCloser
+ body []byte
+ err error
+}
+
+func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
+
+func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
+ if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
+ res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
+ } else {
+ t.Errorf("got %q response; want %q", res.Proto, wantProto)
+ }
+ slurp, err := io.ReadAll(res.Body)
+
+ res.Body.Close()
+ res.Body = slurpResult{
+ ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
+ body: slurp,
+ err: err,
+ }
+ for i, v := range res.Header["Date"] {
+ res.Header["Date"][i] = strings.Repeat("x", len(v))
+ }
+ if res.Request == nil {
+ t.Errorf("for %s, no request", wantProto)
+ }
+ if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
+ t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
+ }
+}
+
+// Issue 13532
+func TestH12_HeadContentLengthNoBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthSmallBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small")
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthLargeBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ chunk := strings.Repeat("x", 512<<10)
+ for i := 0; i < 10; i++ {
+ io.WriteString(w, chunk)
+ }
+ },
+ }.run(t)
+}
+
+func TestH12_200NoBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
+}
+
+func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
+func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
+func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
+
+func testH12_noBody(t *testing.T, status int) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.WriteHeader(status)
+ }}.run(t)
+}
+
+func TestH12_SmallBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small body")
+ }}.run(t)
+}
+
+func TestH12_ExplicitContentLength(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushBeforeBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushMidBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "foo")
+ w.(Flusher).Flush()
+ io.WriteString(w, "bar")
+ }}.run(t)
+}
+
+func TestH12_Head_ExplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ w.Header().Set("Content-Length", "1235")
+ },
+ }.run(t)
+}
+
+func TestH12_Head_ImplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ io.WriteString(w, "foo")
+ },
+ }.run(t)
+}
+
+func TestH12_HandlerWritesTooLittle(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "12") // one byte short
+ },
+ CheckResponse: func(proto string, res *Response) {
+ sr, ok := res.Body.(slurpResult)
+ if !ok {
+ t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
+ return
+ }
+ if sr.err != io.ErrUnexpectedEOF {
+ t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
+ }
+ if string(sr.body) != "12" {
+ t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
+ }
+ },
+ }.run(t)
+}
+
+// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
+// writing more than they declared. This test does not test whether
+// the transport deals with too much data, though, since the server
+// doesn't make it possible to send bogus data. For those tests, see
+// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
+// (for HTTP/2).
+func TestH12_HandlerWritesTooMuch(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ w.(Flusher).Flush()
+ io.WriteString(w, "123")
+ w.(Flusher).Flush()
+ n, err := io.WriteString(w, "x") // too many
+ if n > 0 || err == nil {
+ t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
+ }
+ },
+ }.run(t)
+}
+
+// Verify that both our HTTP/1 and HTTP/2 transports request and auto-decompress gzip.
+// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
+func TestH12_AutoGzip(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
+ t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
+ }
+ w.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(w)
+ io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
+ gz.Close()
+ },
+ }.run(t)
+}
+
+func TestH12_AutoGzip_Disabled(t *testing.T) {
+ h12Compare{
+ Opts: []any{
+ func(tr *Transport) { tr.DisableCompression = true },
+ },
+ Handler: func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
+ if ae := r.Header.Get("Accept-Encoding"); ae != "" {
+ t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
+ }
+ },
+ }.run(t)
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses(t *testing.T) { run(t, test304Responses) }
+func test304Responses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNotModified)
+ _, err := w.Write([]byte("illegal body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestH12_ServerEmptyContentLength(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Type"] = []string{""}
+ io.WriteString(w, "<html><body>hi</body></html>")
+ },
+ }.run(t)
+}
+
+func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
+}
+
+func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return nil }, 0)
+}
+
+func TestH12_RequestContentLength_Unknown(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
+}
+
+func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
+ fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
+ },
+ ReqFunc: func(c *Client, url string) (*Response, error) {
+ return c.Post(url, "text/plain", bodyfn())
+ },
+ CheckResponse: func(proto string, res *Response) {
+ if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
+ t.Errorf("Proto %q got length %q; want %q", proto, got, want)
+ }
+ },
+ }.run(t)
+}
+
+// Tests that closing the Request.Cancel channel also works while the
+// response body is still being read. Issue 13159.
+func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
+func testCancelRequestMidBody(t *testing.T, mode testMode) {
+ unblock := make(chan bool)
+ didFlush := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "Hello")
+ w.(Flusher).Flush()
+ didFlush <- true
+ <-unblock
+ io.WriteString(w, ", world.")
+ }))
+ defer close(unblock)
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ <-didFlush
+
+ // Read a bit before we cancel. (Issue 13626)
+ // We should have "Hello" at least sitting there.
+ firstRead := make([]byte, 10)
+ n, err := res.Body.Read(firstRead)
+ if err != nil {
+ t.Fatal(err)
+ }
+ firstRead = firstRead[:n]
+
+ close(cancel)
+
+ rest, err := io.ReadAll(res.Body)
+ all := string(firstRead) + string(rest)
+ if all != "Hello" {
+ t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
+ }
+ if err != ExportErrRequestCanceled {
+ t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
+ }
+}
+
+// Tests that clients can send trailers to a server and that the server can read them.
+func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
+func testTrailersClientToServer(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var decl []string
+ for k := range r.Trailer {
+ decl = append(decl, k)
+ }
+ sort.Strings(decl)
+
+ slurp, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Server reading request body: %v", err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Server read request body %q; want foo", slurp)
+ }
+ if r.Trailer == nil {
+ io.WriteString(w, "nil Trailer")
+ } else {
+ fmt.Fprintf(w, "decl: %v, vals: %s, %s",
+ decl,
+ r.Trailer.Get("Client-Trailer-A"),
+ r.Trailer.Get("Client-Trailer-B"))
+ }
+ }))
+
+ var req *Request
+ req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-A"] = []string{"valuea"}
+ }),
+ strings.NewReader("foo"),
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-B"] = []string{"valueb"}
+ }),
+ ))
+ req.Trailer = Header{
+ "Client-Trailer-A": nil, // to be set later
+ "Client-Trailer-B": nil, // to be set later
+ }
+ req.ContentLength = -1
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
+ t.Error(err)
+ }
+}
+
+// Tests that servers send trailers to a client and that the client can read them.
+func TestTrailersServerToClient(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTrailersServerToClient(t, mode, false)
+ })
+}
+func TestTrailersServerToClientFlush(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTrailersServerToClient(t, mode, true)
+ })
+}
+
+func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
+ const body = "Some body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
+ w.Header().Add("Trailer", "Server-Trailer-C")
+
+ io.WriteString(w, body)
+ if flush {
+ w.(Flusher).Flush()
+ }
+
+ // How handlers set Trailers: declare them ahead of time
+ // with the Trailer header, and then mutate those keys in
+ // Header() later, after the response has been written
+ // (we wrote to w above).
+ w.Header().Set("Server-Trailer-A", "valuea")
+ w.Header().Set("Server-Trailer-C", "valuec") // skipping B
+ w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wantHeader := Header{
+ "Content-Type": {"text/plain; charset=utf-8"},
+ }
+ wantLen := -1
+ if mode == http2Mode && !flush {
+ // In HTTP/1.1, any use of trailers forces HTTP/1.1
+ // chunking and a flush at the first write. That's
+ // unnecessary with HTTP/2's framing, so the server
+ // is able to calculate the length while still sending
+ // trailers afterwards.
+ wantLen = len(body)
+ wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
+ }
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
+ }
+
+ delete(res.Header, "Date") // irrelevant for test
+ if !reflect.DeepEqual(res.Header, wantHeader) {
+ t.Errorf("Header = %v; want %v", res.Header, wantHeader)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": nil,
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": nil,
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer before body read = %v; want %v", got, want)
+ }
+
+ if err := wantBody(res, nil, body); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": {"valuea"},
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": {"valuec"},
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer after body read = %v; want %v", got, want)
+ }
+}
+
+// Don't allow a Body.Read after Body.Close. Issue 13648.
+func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
+func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
+ const body = "Some body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, body)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ data, err := io.ReadAll(res.Body)
+ if len(data) != 0 || err == nil {
+ t.Fatalf("ReadAll returned %q, %v; want error", data, err)
+ }
+}
+
+func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
+func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
+ const reqBody = "some request body"
+ const resBody = "some response body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var wg sync.WaitGroup
+ wg.Add(2)
+ didRead := make(chan bool, 1)
+ // Read in one goroutine.
+ go func() {
+ defer wg.Done()
+ data, err := io.ReadAll(r.Body)
+ if string(data) != reqBody {
+ t.Errorf("Handler read %q; want %q", data, reqBody)
+ }
+ if err != nil {
+ t.Errorf("Handler Read: %v", err)
+ }
+ didRead <- true
+ }()
+ // Write in another goroutine.
+ go func() {
+ defer wg.Done()
+ if mode != http2Mode {
+ // our HTTP/1 implementation intentionally
+ // doesn't permit writes during read (mostly
+ // due to it being undefined); if that is ever
+ // relaxed, change this.
+ <-didRead
+ }
+ io.WriteString(w, resBody)
+ }()
+ wg.Wait()
+ }))
+ req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
+ req.Header.Add("Expect", "100-continue") // just to complicate things
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ data, err := io.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(data) != resBody {
+ t.Errorf("read %q; want %q", data, resBody)
+ }
+}
+
+func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
+func testConnectRequest(t *testing.T, mode testMode) {
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotc <- r
+ }))
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ req *Request
+ want string
+ }{
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ },
+ want: u.Host,
+ },
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ Host: "example.com:123",
+ },
+ want: "example.com:123",
+ },
+ }
+
+ for i, tt := range tests {
+ res, err := cst.c.Do(tt.req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ res.Body.Close()
+ req := <-gotc
+ if req.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", req.Method)
+ }
+ if req.Host != tt.want {
+ t.Errorf("Host = %q; want %q", req.Host, tt.want)
+ }
+ if req.URL.Host != tt.want {
+ t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
+ }
+ }
+}
+
+func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
+func testTransportUserAgent(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["User-Agent"])
+ }))
+
+ either := func(a, b string) string {
+ if mode == http2Mode {
+ return b
+ }
+ return a
+ }
+
+ tests := []struct {
+ setup func(*Request)
+ want string
+ }{
+ {
+ func(r *Request) {},
+ either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
+ `["foo/1.2.3"]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
+ `["single"]`,
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "") },
+ `[]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = nil },
+ `[]`,
+ },
+ }
+ for i, tt := range tests {
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ tt.setup(req)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("%d. read body = %v", i, err)
+ continue
+ }
+ if string(slurp) != tt.want {
+ t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
+ }
+ }
+}
+
+func TestStarRequestMethod(t *testing.T) {
+ for _, method := range []string{"FOO", "OPTIONS"} {
+ t.Run(method, func(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testStarRequest(t, method, mode)
+ })
+ })
+ }
+}
+func testStarRequest(t *testing.T, method string, mode testMode) {
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo", "bar")
+ gotc <- r
+ w.(Flusher).Flush()
+ }))
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ u.Path = "*"
+
+ req := &Request{
+ Method: method,
+ Header: Header{},
+ URL: u,
+ }
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("RoundTrip = %v", err)
+ }
+ res.Body.Close()
+
+ wantFoo := "bar"
+ wantLen := int64(-1)
+ if method == "OPTIONS" {
+ wantFoo = ""
+ wantLen = 0
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("status code = %v; want %d", res.Status, 200)
+ }
+ if res.ContentLength != wantLen {
+ t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
+ }
+ if got := res.Header.Get("foo"); got != wantFoo {
+ t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
+ }
+ select {
+ case req = <-gotc:
+ default:
+ req = nil
+ }
+ if req == nil {
+ if method != "OPTIONS" {
+ t.Fatalf("handler never got request")
+ }
+ return
+ }
+ if req.Method != method {
+ t.Errorf("method = %q; want %q", req.Method, method)
+ }
+ if req.URL.Path != "*" {
+ t.Errorf("URL.Path = %q; want *", req.URL.Path)
+ }
+ if req.RequestURI != "*" {
+ t.Errorf("RequestURI = %q; want *", req.RequestURI)
+ }
+}
+
+// Issue 13957
+func TestTransportDiscardsUnneededConns(t *testing.T) {
+ run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
+}
+func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ var numOpen, numClose int32 // atomic
+
+ tlsConfig := &tls.Config{InsecureSkipVerify: true}
+ tr := &Transport{
+ TLSClientConfig: tlsConfig,
+ DialTLS: func(_, addr string) (net.Conn, error) {
+ time.Sleep(10 * time.Millisecond)
+ rc, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ atomic.AddInt32(&numOpen, 1)
+ c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
+ return tls.Client(c, tlsConfig), nil
+ },
+ }
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatal(err)
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+
+ const N = 10
+ gotBody := make(chan string, N)
+ var wg sync.WaitGroup
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resp, err := c.Get(cst.ts.URL)
+ if err != nil {
+ // Try to work around spurious connection reset on loaded system.
+ // See golang.org/issue/33585 and golang.org/issue/36797.
+ time.Sleep(10 * time.Millisecond)
+ resp, err = c.Get(cst.ts.URL)
+ if err != nil {
+ t.Errorf("Get: %v", err)
+ return
+ }
+ }
+ defer resp.Body.Close()
+ slurp, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody <- string(slurp)
+ }()
+ }
+ wg.Wait()
+ close(gotBody)
+
+ var last string
+ for got := range gotBody {
+ if last == "" {
+ last = got
+ continue
+ }
+ if got != last {
+ t.Errorf("Response body changed: %q -> %q", last, got)
+ }
+ }
+
+ var open, close int32
+ for i := 0; i < 150; i++ {
+ open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
+ if open < 1 {
+ t.Fatalf("open = %d; want at least", open)
+ }
+ if close == open-1 {
+ // Success
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
+}
+
+// tests that Transport doesn't retain a pointer to the provided request.
+func TestTransportGCRequest(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
+ t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
+ })
+}
+func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.ReadAll(r.Body)
+ if body {
+ io.WriteString(w, "Hello.")
+ }
+ }))
+
+ didGC := make(chan struct{})
+ (func() {
+ body := strings.NewReader("some body")
+ req, _ := NewRequest("POST", cst.ts.URL, body)
+ runtime.SetFinalizer(req, func(*Request) { close(didGC) })
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.ReadAll(res.Body); err != nil {
+ t.Fatal(err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatal(err)
+ }
+ })()
+ timeout := time.NewTimer(5 * time.Second)
+ defer timeout.Stop()
+ for {
+ select {
+ case <-didGC:
+ return
+ case <-time.After(100 * time.Millisecond):
+ runtime.GC()
+ case <-timeout.C:
+ t.Fatal("never saw GC of request")
+ }
+ }
+}
+
+func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
+func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
+ }), optQuietLog)
+ cst.tr.DisableKeepAlives = true
+
+ tests := []struct {
+ key, val string
+ ok bool
+ }{
+ {"Foo", "capital-key", true}, // verify h2 allows capital keys
+ {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
+ {"Foo", "two\nlines", false}, // \n byte in value not allowed
+ {"bogus\nkey", "v", false}, // \n byte also not allowed in key
+ {"A space", "v", false}, // spaces in keys not allowed
+ {"имя", "v", false}, // key must be ascii
+ {"name", "валю", true}, // value may be non-ascii
+ {"", "v", false}, // key must be non-empty
+ {"k", "", true}, // value may be empty
+ }
+ for _, tt := range tests {
+ dialedc := make(chan bool, 1)
+ cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
+ dialedc <- true
+ return net.Dial(netw, addr)
+ }
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Header[tt.key] = []string{tt.val}
+ res, err := cst.c.Do(req)
+ var body []byte
+ if err == nil {
+ body, _ = io.ReadAll(res.Body)
+ res.Body.Close()
+ }
+ var dialed bool
+ select {
+ case <-dialedc:
+ dialed = true
+ default:
+ }
+
+ if !tt.ok && dialed {
+ t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
+ } else if (err == nil) != tt.ok {
+ t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
+ }
+ }
+}
+
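+// validHeaderKeyValue is a rough sketch of the rules the table above
+// exercises: keys must be non-empty printable ASCII with no spaces or
+// control bytes, while values may contain non-ASCII bytes but never NUL,
+// CR, or LF. The transport's real checks (in the vendored httpguts package)
+// are stricter about which ASCII punctuation a key may contain.
+func validHeaderKeyValue(key, val string) bool {
+ if key == "" {
+ return false
+ }
+ for i := 0; i < len(key); i++ {
+ if c := key[i]; c <= ' ' || c >= 0x7f {
+ return false
+ }
+ }
+ for i := 0; i < len(val); i++ {
+ switch val[i] {
+ case 0, '\r', '\n':
+ return false
+ }
+ }
+ return true
+}
+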
+func TestInterruptWithPanic(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
+ t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
+ t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
+ }, testNotParallel)
+}
+func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
+ const msg = "hello"
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ var errorLog lockedBytesBuffer
+ gotHeaders := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, msg)
+ w.(Flusher).Flush()
+
+ select {
+ case <-gotHeaders:
+ case <-testDone:
+ }
+ panic(panicValue)
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(&errorLog, "", 0)
+ })
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotHeaders <- true
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if string(slurp) != msg {
+ t.Errorf("client read %q; want %q", slurp, msg)
+ }
+ if err == nil {
+ t.Errorf("client read all successfully; want some error")
+ }
+ logOutput := func() string {
+ errorLog.Lock()
+ defer errorLog.Unlock()
+ return errorLog.String()
+ }
+ wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
+
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ gotLog := logOutput()
+ if !wantStackLogged {
+ if gotLog == "" {
+ return true
+ }
+ t.Fatalf("want no log output; got: %s", gotLog)
+ }
+ if gotLog == "" {
+ if d > 0 {
+ t.Logf("wanted a stack trace logged; got nothing after %v", d)
+ }
+ return false
+ }
+ if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
+ if d > 0 {
+ t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
+ }
+ return false
+ }
+ return true
+ })
+}
+
+type lockedBytesBuffer struct {
+ sync.Mutex
+ bytes.Buffer
+}
+
+func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
+ b.Lock()
+ defer b.Unlock()
+ return b.Buffer.Write(p)
+}
+
+// Issue 15366
+func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ h := w.Header()
+ h.Set("Content-Encoding", "gzip")
+ h.Set("Content-Length", "23")
+ io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
+ },
+ EarlyCheckResponse: func(proto string, res *Response) {
+ if !res.Uncompressed {
+ t.Errorf("%s: expected Uncompressed to be set", proto)
+ }
+ dump, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ t.Errorf("%s: DumpResponse: %v", proto, err)
+ return
+ }
+ if strings.Contains(string(dump), "Connection: close") {
+ t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
+ }
+ if !strings.Contains(string(dump), "FOO") {
+ t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
+ }
+ },
+ }.run(t)
+}
+
+// Issue 14607
+func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
+func testCloseIdleConnections(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ }))
+ get := func() string {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ v := res.Header.Get("X-Addr")
+ if v == "" {
+ t.Fatal("didn't get X-Addr")
+ }
+ return v
+ }
+ a1 := get()
+ cst.tr.CloseIdleConnections()
+ a2 := get()
+ if a1 == a2 {
+ t.Errorf("didn't close connection")
+ }
+}
+
+type noteCloseConn struct {
+ net.Conn
+ closeFunc func()
+}
+
+func (x noteCloseConn) Close() error {
+ x.closeFunc()
+ return x.Conn.Close()
+}
+
+type testErrorReader struct{ t *testing.T }
+
+func (r testErrorReader) Read(p []byte) (n int, err error) {
+ r.t.Error("unexpected Read call")
+ return 0, io.EOF
+}
+
+func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
+func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusUnauthorized)
+ }))
+
+ // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
+ cst.tr.ExpectContinueTimeout = 10 * time.Second
+
+ req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0 // so transport is tempted to sniff it
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusUnauthorized {
+ t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
+ }
+}
+
+func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
+func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Foo", "Bar")
+ w.Header().Set("Trailer:Foo", "Baz")
+ w.(Flusher).Flush()
+ w.Header().Add("Trailer:Foo", "Baz2")
+ w.Header().Set("Trailer:Bar", "Quux")
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ delete(res.Header, "Date")
+ delete(res.Header, "Content-Type")
+
+ if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
+ t.Errorf("Header = %#v; want %#v", res.Header, want)
+ }
+ if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
+ }
+}
+
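+// The "Trailer:Foo" keys above use the documented TrailerPrefix mechanism:
+// header map keys beginning with that prefix, set after the headers have
+// been written, are sent as trailers without being declared in advance, e.g.:
+//
+//	w.Header().Set(TrailerPrefix+"Checksum", sum) // sum is illustrative
+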
+func TestBadResponseAfterReadingBody(t *testing.T) {
+ run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
+}
+func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := io.Copy(io.Discard, r.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ fmt.Fprintln(c, "some bogus crap")
+ }))
+
+ closes := 0
+ res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("expected an error to be returned from Post")
+ }
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+}
+
+func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
+func testWriteHeader0(t *testing.T, mode testMode) {
+ gotpanic := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(gotpanic)
+ defer func() {
+ if e := recover(); e != nil {
+ got := fmt.Sprintf("%T, %v", e, e)
+ want := "string, invalid WriteHeader code 0"
+ if got != want {
+ t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
+ }
+ gotpanic <- true
+
+ // Set an explicit 503. This also tests that the WriteHeader call panics
+ // before it recorded that an explicit value was set and that bogus
+ // value wasn't stuck.
+ w.WriteHeader(503)
+ }
+ }()
+ w.WriteHeader(0)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 503 {
+ t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
+ }
+ if !<-gotpanic {
+ t.Error("expected panic in handler")
+ }
+}
+
+// Issue 23010: don't be super strict checking WriteHeader's code if
+// it's not even valid to call WriteHeader then anyway.
+func TestWriteHeaderNoCodeCheck(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testWriteHeaderAfterWrite(t, mode, false)
+ })
+}
+func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
+ testWriteHeaderAfterWrite(t, http1Mode, true)
+}
+func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
+ var errorLog lockedBytesBuffer
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if hijack {
+ conn, _, _ := w.(Hijacker).Hijack()
+ defer conn.Close()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
+ w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
+ conn.Write([]byte("bar"))
+ return
+ }
+ io.WriteString(w, "foo")
+ w.(Flusher).Flush()
+ w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
+ io.WriteString(w, "bar")
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(&errorLog, "", 0)
+ })
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(body), "foobar"; got != want {
+ t.Errorf("got = %q; want %q", got, want)
+ }
+
+ // Also check the stderr output:
+ if mode == http2Mode {
+ // TODO: also emit this log message for HTTP/2?
+ // We historically haven't, so don't check.
+ return
+ }
+ gotLog := strings.TrimSpace(errorLog.String())
+ wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
+ if hijack {
+ wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
+ }
+ if !strings.HasPrefix(gotLog, wantLog) {
+ t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
+ }
+}
+
+func TestBidiStreamReverseProxy(t *testing.T) {
+ run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
+}
+func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
+ backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if _, err := io.Copy(w, r.Body); err != nil {
+ log.Printf("bidi backend copy: %v", err)
+ }
+ }))
+
+ backURL, err := url.Parse(backend.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rp := httputil.NewSingleHostReverseProxy(backURL)
+ rp.Transport = backend.tr
+ proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ rp.ServeHTTP(w, r)
+ }))
+
+ bodyRes := make(chan any, 1) // error or hash.Hash
+ pr, pw := io.Pipe()
+ req, _ := NewRequest("PUT", proxy.ts.URL, pr)
+ const size = 4 << 20
+ go func() {
+ h := sha1.New()
+ _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
+ go pw.Close()
+ if err != nil {
+ bodyRes <- err
+ } else {
+ bodyRes <- h
+ }
+ }()
+ res, err := backend.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ hgot := sha1.New()
+ n, err := io.Copy(hgot, res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != size {
+ t.Fatalf("got %d bytes; want %d", n, size)
+ }
+ select {
+ case v := <-bodyRes:
+ switch v := v.(type) {
+ default:
+ t.Fatalf("body copy: %v", v)
+ case hash.Hash:
+ if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
+ t.Errorf("written bytes didn't match received bytes")
+ }
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("timeout")
+ }
+
+}
+
+// Always use HTTP/1.1 for WebSocket upgrades.
+func TestH12_WebSocketUpgrade(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ h := w.Header()
+ h.Set("Foo", "bar")
+ },
+ ReqFunc: func(c *Client, url string) (*Response, error) {
+ req, _ := NewRequest("GET", url, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "WebSocket")
+ return c.Do(req)
+ },
+ EarlyCheckResponse: func(proto string, res *Response) {
+ if res.Proto != "HTTP/1.1" {
+ t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
+ }
+ res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
+ },
+ }.run(t)
+}
+
+func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
+func testIdentityTransferEncoding(t *testing.T, mode testMode) {
+ const body = "body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotBody, _ := io.ReadAll(r.Body)
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got request body = %q; want %q", got, want)
+ }
+ w.Header().Set("Transfer-Encoding", "identity")
+ w.WriteHeader(StatusOK)
+ w.(Flusher).Flush()
+ io.WriteString(w, body)
+ }))
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ gotBody, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got response body = %q; want %q", got, want)
+ }
+}
+
+func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
+func testEarlyHintsRequest(t *testing.T, mode testMode) {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ h := w.Header()
+
+ h.Add("Content-Length", "123") // must be ignored
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ wg.Wait()
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ w.Write([]byte("Hello"))
+ }))
+
+ checkLinkHeaders := func(t *testing.T, expected, got []string) {
+ t.Helper()
+
+ if len(expected) != len(got) {
+ t.Errorf("got %d expected %d", len(got), len(expected))
+ }
+
+ for i := range expected {
+ if expected[i] != got[i] {
+ t.Errorf("got %q expected %q", got[i], expected[i])
+ }
+ }
+ }
+
+ checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
+ t.Helper()
+
+ for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
+ if v, ok := header[h]; ok {
+ t.Errorf("%s is %q; must not be sent", h, v)
+ }
+ }
+ }
+
+ var respCounter uint8
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ switch respCounter {
+ case 0:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
+ checkExcludedHeaders(t, header)
+
+ wg.Done()
+ case 1:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
+ checkExcludedHeaders(t, header)
+
+ default:
+ t.Error("Unexpected 1xx response")
+ }
+
+ respCounter++
+
+ return nil
+ },
+ }
+ req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
+ if cl := res.Header.Get("Content-Length"); cl != "123" {
+ t.Errorf("Content-Length is %q; want 123", cl)
+ }
+
+ body, _ := io.ReadAll(res.Body)
+ if string(body) != "Hello" {
+ t.Errorf("Read body %q; want Hello", body)
+ }
+}
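
Aside (not part of the patch): the Early Hints test above reduces to two moving parts — a handler that calls WriteHeader(StatusEarlyHints) before its final response, and a client that observes interim responses through httptrace.ClientTrace.Got1xxResponse. A minimal, self-contained sketch using only exported API; the server comes from httptest and the header values are invented:

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"net/textproto"
)

func main() {
	// Server: emit one 103 Early Hints interim response, then the real body.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Add("Link", "</style.css>; rel=preload; as=style")
		w.WriteHeader(http.StatusEarlyHints) // interim response; headers are sent now
		w.Write([]byte("Hello"))             // triggers the implicit final 200
	}))
	defer ts.Close()

	// Client: Got1xxResponse fires for each interim (1xx) response.
	trace := &httptrace.ClientTrace{
		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
			fmt.Println("interim:", code, header.Get("Link"))
			return nil
		},
	}
	req, _ := http.NewRequest("GET", ts.URL, nil)
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	defer res.Body.Close()
	body, _ := io.ReadAll(res.Body)
	fmt.Println(res.StatusCode, string(body))
}
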
diff --git a/src/net/http/clone.go b/src/net/http/clone.go
new file mode 100644
index 0000000..3a3375b
--- /dev/null
+++ b/src/net/http/clone.go
@@ -0,0 +1,74 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+)
+
+func cloneURLValues(v url.Values) url.Values {
+ if v == nil {
+ return nil
+ }
+ // http.Header and url.Values have the same representation, so temporarily
+ // treat it like http.Header, which does have a clone:
+ return url.Values(Header(v).Clone())
+}
+
+func cloneURL(u *url.URL) *url.URL {
+ if u == nil {
+ return nil
+ }
+ u2 := new(url.URL)
+ *u2 = *u
+ if u.User != nil {
+ u2.User = new(url.Userinfo)
+ *u2.User = *u.User
+ }
+ return u2
+}
+
+func cloneMultipartForm(f *multipart.Form) *multipart.Form {
+ if f == nil {
+ return nil
+ }
+ f2 := &multipart.Form{
+ Value: (map[string][]string)(Header(f.Value).Clone()),
+ }
+ if f.File != nil {
+ m := make(map[string][]*multipart.FileHeader)
+ for k, vv := range f.File {
+ vv2 := make([]*multipart.FileHeader, len(vv))
+ for i, v := range vv {
+ vv2[i] = cloneMultipartFileHeader(v)
+ }
+ m[k] = vv2
+ }
+ f2.File = m
+ }
+ return f2
+}
+
+func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
+ if fh == nil {
+ return nil
+ }
+ fh2 := new(multipart.FileHeader)
+ *fh2 = *fh
+ fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone())
+ return fh2
+}
+
+// cloneOrMakeHeader invokes Header.Clone but if the
+// result is nil, it'll instead make and return a non-nil Header.
+func cloneOrMakeHeader(hdr Header) Header {
+ clone := hdr.Clone()
+ if clone == nil {
+ clone = make(Header)
+ }
+ return clone
+}
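
Aside (not part of the patch): the helpers above are unexported, but the nil-preserving clone pattern they implement is visible through the exported Header.Clone, which returns nil for a nil Header. A hedged sketch of the cloneOrMakeHeader idea, illustrative only:

package main

import (
	"fmt"
	"net/http"
)

// cloneOrMake mirrors cloneOrMakeHeader above: clone a Header, but never
// hand a nil map back to callers that intend to mutate the result.
func cloneOrMake(h http.Header) http.Header {
	clone := h.Clone() // Header.Clone returns nil when h is nil
	if clone == nil {
		clone = make(http.Header)
	}
	return clone
}

func main() {
	var h http.Header // nil map
	c := cloneOrMake(h)
	c.Set("Foo", "bar")         // safe: c is always non-nil
	fmt.Println(len(h), len(c)) // 0 1
}
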
diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go
new file mode 100644
index 0000000..912fde6
--- /dev/null
+++ b/src/net/http/cookie.go
@@ -0,0 +1,468 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
+// HTTP response or the Cookie header of an HTTP request.
+//
+// See https://tools.ietf.org/html/rfc6265 for details.
+type Cookie struct {
+ Name string
+ Value string
+
+ Path string // optional
+ Domain string // optional
+ Expires time.Time // optional
+ RawExpires string // for reading cookies only
+
+ // MaxAge=0 means no 'Max-Age' attribute specified.
+ // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'
+ // MaxAge>0 means Max-Age attribute present and given in seconds
+ MaxAge int
+ Secure bool
+ HttpOnly bool
+ SameSite SameSite
+ Raw string
+ Unparsed []string // Raw text of unparsed attribute-value pairs
+}
+
+// SameSite allows a server to define a cookie attribute making it impossible for
+// the browser to send this cookie along with cross-site requests. The main
+// goal is to mitigate the risk of cross-origin information leakage, and provide
+// some protection against cross-site request forgery attacks.
+//
+// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
+type SameSite int
+
+const (
+ SameSiteDefaultMode SameSite = iota + 1
+ SameSiteLaxMode
+ SameSiteStrictMode
+ SameSiteNoneMode
+)
+
+// readSetCookies parses all "Set-Cookie" values from
+// the header h and returns the successfully parsed Cookies.
+func readSetCookies(h Header) []*Cookie {
+ cookieCount := len(h["Set-Cookie"])
+ if cookieCount == 0 {
+ return []*Cookie{}
+ }
+ cookies := make([]*Cookie, 0, cookieCount)
+ for _, line := range h["Set-Cookie"] {
+ parts := strings.Split(textproto.TrimString(line), ";")
+ if len(parts) == 1 && parts[0] == "" {
+ continue
+ }
+ parts[0] = textproto.TrimString(parts[0])
+ name, value, ok := strings.Cut(parts[0], "=")
+ if !ok {
+ continue
+ }
+ name = textproto.TrimString(name)
+ if !isCookieNameValid(name) {
+ continue
+ }
+ value, ok = parseCookieValue(value, true)
+ if !ok {
+ continue
+ }
+ c := &Cookie{
+ Name: name,
+ Value: value,
+ Raw: line,
+ }
+ for i := 1; i < len(parts); i++ {
+ parts[i] = textproto.TrimString(parts[i])
+ if len(parts[i]) == 0 {
+ continue
+ }
+
+ attr, val, _ := strings.Cut(parts[i], "=")
+ lowerAttr, isASCII := ascii.ToLower(attr)
+ if !isASCII {
+ continue
+ }
+ val, ok = parseCookieValue(val, false)
+ if !ok {
+ c.Unparsed = append(c.Unparsed, parts[i])
+ continue
+ }
+
+ switch lowerAttr {
+ case "samesite":
+ lowerVal, ascii := ascii.ToLower(val)
+ if !ascii {
+ c.SameSite = SameSiteDefaultMode
+ continue
+ }
+ switch lowerVal {
+ case "lax":
+ c.SameSite = SameSiteLaxMode
+ case "strict":
+ c.SameSite = SameSiteStrictMode
+ case "none":
+ c.SameSite = SameSiteNoneMode
+ default:
+ c.SameSite = SameSiteDefaultMode
+ }
+ continue
+ case "secure":
+ c.Secure = true
+ continue
+ case "httponly":
+ c.HttpOnly = true
+ continue
+ case "domain":
+ c.Domain = val
+ continue
+ case "max-age":
+ secs, err := strconv.Atoi(val)
+ if err != nil || secs != 0 && val[0] == '0' {
+ break
+ }
+ if secs <= 0 {
+ secs = -1
+ }
+ c.MaxAge = secs
+ continue
+ case "expires":
+ c.RawExpires = val
+ exptime, err := time.Parse(time.RFC1123, val)
+ if err != nil {
+ exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", val)
+ if err != nil {
+ c.Expires = time.Time{}
+ break
+ }
+ }
+ c.Expires = exptime.UTC()
+ continue
+ case "path":
+ c.Path = val
+ continue
+ }
+ c.Unparsed = append(c.Unparsed, parts[i])
+ }
+ cookies = append(cookies, c)
+ }
+ return cookies
+}
+
+// SetCookie adds a Set-Cookie header to the provided ResponseWriter's headers.
+// The provided cookie must have a valid Name. Invalid cookies may be
+// silently dropped.
+func SetCookie(w ResponseWriter, cookie *Cookie) {
+ if v := cookie.String(); v != "" {
+ w.Header().Add("Set-Cookie", v)
+ }
+}
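
Aside (not part of this file): SetCookie is normally called from inside a handler, and every attribute set on the Cookie is serialized by the String method defined next. A hedged usage sketch; the route, cookie name, values, and listen address are invented:

package main

import (
	"log"
	"net/http"
)

func setSessionCookie(w http.ResponseWriter, r *http.Request) {
	http.SetCookie(w, &http.Cookie{
		Name:     "session",
		Value:    "opaque-token",
		Path:     "/",
		MaxAge:   3600,
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	w.Write([]byte("cookie set\n"))
}

func main() {
	http.HandleFunc("/login", setSessionCookie)
	log.Fatal(http.ListenAndServe("localhost:8080", nil)) // address chosen arbitrarily
}
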
+
+// String returns the serialization of the cookie for use in a Cookie
+// header (if only Name and Value are set) or a Set-Cookie response
+// header (if other fields are set).
+// If c is nil or c.Name is invalid, the empty string is returned.
+func (c *Cookie) String() string {
+ if c == nil || !isCookieNameValid(c.Name) {
+ return ""
+ }
+ // extraCookieLength derived from typical length of cookie attributes
+ // see RFC 6265 Sec 4.1.
+ const extraCookieLength = 110
+ var b strings.Builder
+ b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength)
+ b.WriteString(c.Name)
+ b.WriteRune('=')
+ b.WriteString(sanitizeCookieValue(c.Value))
+
+ if len(c.Path) > 0 {
+ b.WriteString("; Path=")
+ b.WriteString(sanitizeCookiePath(c.Path))
+ }
+ if len(c.Domain) > 0 {
+ if validCookieDomain(c.Domain) {
+ // A c.Domain containing illegal characters is not
+ // sanitized but simply dropped which turns the cookie
+ // into a host-only cookie. A leading dot is okay
+ // but won't be sent.
+ d := c.Domain
+ if d[0] == '.' {
+ d = d[1:]
+ }
+ b.WriteString("; Domain=")
+ b.WriteString(d)
+ } else {
+ log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain)
+ }
+ }
+ var buf [len(TimeFormat)]byte
+ if validCookieExpires(c.Expires) {
+ b.WriteString("; Expires=")
+ b.Write(c.Expires.UTC().AppendFormat(buf[:0], TimeFormat))
+ }
+ if c.MaxAge > 0 {
+ b.WriteString("; Max-Age=")
+ b.Write(strconv.AppendInt(buf[:0], int64(c.MaxAge), 10))
+ } else if c.MaxAge < 0 {
+ b.WriteString("; Max-Age=0")
+ }
+ if c.HttpOnly {
+ b.WriteString("; HttpOnly")
+ }
+ if c.Secure {
+ b.WriteString("; Secure")
+ }
+ switch c.SameSite {
+ case SameSiteDefaultMode:
+ // Skip, default mode is obtained by not emitting the attribute.
+ case SameSiteNoneMode:
+ b.WriteString("; SameSite=None")
+ case SameSiteLaxMode:
+ b.WriteString("; SameSite=Lax")
+ case SameSiteStrictMode:
+ b.WriteString("; SameSite=Strict")
+ }
+ return b.String()
+}
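
Aside (not part of this file): a quick illustration of the serialization rules above — attributes are emitted in a fixed order, a leading dot on Domain is dropped, and an invalid Domain is omitted rather than sanitized. The values are invented and the expected output is hedged:

package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	c := &http.Cookie{
		Name:     "id",
		Value:    "a3fWa",
		Path:     "/",
		Domain:   ".example.com", // leading dot accepted, but not emitted
		Expires:  time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC),
		MaxAge:   3600,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	}
	// Prints something like:
	// id=a3fWa; Path=/; Domain=example.com; Expires=...; Max-Age=3600; Secure; SameSite=Strict
	fmt.Println(c.String())
}
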
+
+// Valid reports whether the cookie is valid.
+func (c *Cookie) Valid() error {
+ if c == nil {
+ return errors.New("http: nil Cookie")
+ }
+ if !isCookieNameValid(c.Name) {
+ return errors.New("http: invalid Cookie.Name")
+ }
+ if !c.Expires.IsZero() && !validCookieExpires(c.Expires) {
+ return errors.New("http: invalid Cookie.Expires")
+ }
+ for i := 0; i < len(c.Value); i++ {
+ if !validCookieValueByte(c.Value[i]) {
+ return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i])
+ }
+ }
+ if len(c.Path) > 0 {
+ for i := 0; i < len(c.Path); i++ {
+ if !validCookiePathByte(c.Path[i]) {
+ return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i])
+ }
+ }
+ }
+ if len(c.Domain) > 0 {
+ if !validCookieDomain(c.Domain) {
+ return errors.New("http: invalid Cookie.Domain")
+ }
+ }
+ return nil
+}
+
+// readCookies parses all "Cookie" values from the header h and
+// returns the successfully parsed Cookies.
+//
+// if filter isn't empty, only cookies of that name are returned.
+func readCookies(h Header, filter string) []*Cookie {
+ lines := h["Cookie"]
+ if len(lines) == 0 {
+ return []*Cookie{}
+ }
+
+ cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";"))
+ for _, line := range lines {
+ line = textproto.TrimString(line)
+
+ var part string
+ for len(line) > 0 { // continue since we have rest
+ part, line, _ = strings.Cut(line, ";")
+ part = textproto.TrimString(part)
+ if part == "" {
+ continue
+ }
+ name, val, _ := strings.Cut(part, "=")
+ name = textproto.TrimString(name)
+ if !isCookieNameValid(name) {
+ continue
+ }
+ if filter != "" && filter != name {
+ continue
+ }
+ val, ok := parseCookieValue(val, true)
+ if !ok {
+ continue
+ }
+ cookies = append(cookies, &Cookie{Name: name, Value: val})
+ }
+ }
+ return cookies
+}
+
+// validCookieDomain reports whether v is a valid cookie domain-value.
+func validCookieDomain(v string) bool {
+ if isCookieDomainName(v) {
+ return true
+ }
+ if net.ParseIP(v) != nil && !strings.Contains(v, ":") {
+ return true
+ }
+ return false
+}
+
+// validCookieExpires reports whether v is a valid cookie expires-value.
+func validCookieExpires(t time.Time) bool {
+ // Per RFC 6265 Section 5.1.1.5, the year must not be less than 1601.
+ return t.Year() >= 1601
+}
+
+// isCookieDomainName reports whether s is a valid domain name or a valid
+// domain name with a leading dot '.'. It is almost a direct copy of
+// package net's isDomainName.
+func isCookieDomainName(s string) bool {
+ if len(s) == 0 {
+ return false
+ }
+ if len(s) > 255 {
+ return false
+ }
+
+ if s[0] == '.' {
+ // A cookie domain attribute may start with a leading dot.
+ s = s[1:]
+ }
+ last := byte('.')
+ ok := false // Ok once we've seen a letter.
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
+ // No '_' allowed here (in contrast to package net).
+ ok = true
+ partlen++
+ case '0' <= c && c <= '9':
+ // fine
+ partlen++
+ case c == '-':
+ // Byte before dash cannot be dot.
+ if last == '.' {
+ return false
+ }
+ partlen++
+ case c == '.':
+ // Byte before dot cannot be dot, dash.
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+
+ return ok
+}
+
+var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-")
+
+func sanitizeCookieName(n string) string {
+ return cookieNameSanitizer.Replace(n)
+}
+
+// sanitizeCookieValue produces a suitable cookie-value from v.
+// https://tools.ietf.org/html/rfc6265#section-4.1.1
+//
+// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE )
+// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
+// ; US-ASCII characters excluding CTLs,
+// ; whitespace DQUOTE, comma, semicolon,
+// ; and backslash
+//
+// We loosen this as spaces and commas are common in cookie values
+// but we produce a quoted cookie-value if and only if v contains
+// commas or spaces.
+// See https://golang.org/issue/7243 for the discussion.
+func sanitizeCookieValue(v string) string {
+ v = sanitizeOrWarn("Cookie.Value", validCookieValueByte, v)
+ if len(v) == 0 {
+ return v
+ }
+ if strings.ContainsAny(v, " ,") {
+ return `"` + v + `"`
+ }
+ return v
+}
+
+func validCookieValueByte(b byte) bool {
+ return 0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'
+}
+
+// path-av = "Path=" path-value
+// path-value = <any CHAR except CTLs or ";">
+func sanitizeCookiePath(v string) string {
+ return sanitizeOrWarn("Cookie.Path", validCookiePathByte, v)
+}
+
+func validCookiePathByte(b byte) bool {
+ return 0x20 <= b && b < 0x7f && b != ';'
+}
+
+func sanitizeOrWarn(fieldName string, valid func(byte) bool, v string) string {
+ ok := true
+ for i := 0; i < len(v); i++ {
+ if valid(v[i]) {
+ continue
+ }
+ log.Printf("net/http: invalid byte %q in %s; dropping invalid bytes", v[i], fieldName)
+ ok = false
+ break
+ }
+ if ok {
+ return v
+ }
+ buf := make([]byte, 0, len(v))
+ for i := 0; i < len(v); i++ {
+ if b := v[i]; valid(b) {
+ buf = append(buf, b)
+ }
+ }
+ return string(buf)
+}
+
+func parseCookieValue(raw string, allowDoubleQuote bool) (string, bool) {
+ // Strip the quotes, if present.
+ if allowDoubleQuote && len(raw) > 1 && raw[0] == '"' && raw[len(raw)-1] == '"' {
+ raw = raw[1 : len(raw)-1]
+ }
+ for i := 0; i < len(raw); i++ {
+ if !validCookieValueByte(raw[i]) {
+ return "", false
+ }
+ }
+ return raw, true
+}
+
+func isCookieNameValid(raw string) bool {
+ if raw == "" {
+ return false
+ }
+ return strings.IndexFunc(raw, isNotToken) < 0
+}
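
Aside (not part of the patch): readCookies and readSetCookies are unexported, but they back the exported Request.Cookies and Response.Cookies. A hedged round-trip sketch using only exported API; names and values are made up:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// Request side: AddCookie serializes into the Cookie header,
	// and Request.Cookies parses it back (via readCookies).
	req, _ := http.NewRequest("GET", "http://example.com/", nil)
	req.AddCookie(&http.Cookie{Name: "c1", Value: "v1"})
	req.AddCookie(&http.Cookie{Name: "c2", Value: "v2"})
	fmt.Println(req.Header.Get("Cookie")) // c1=v1; c2=v2
	for _, c := range req.Cookies() {
		fmt.Println(c.Name, c.Value)
	}

	// Response side: Set-Cookie headers are parsed by Response.Cookies
	// (via readSetCookies), attributes included.
	res := &http.Response{Header: http.Header{}}
	res.Header.Add("Set-Cookie", "session=abc; Path=/; HttpOnly")
	for _, c := range res.Cookies() {
		fmt.Println(c.Name, c.Value, c.Path, c.HttpOnly)
	}
}
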
diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go
new file mode 100644
index 0000000..e5bd46a
--- /dev/null
+++ b/src/net/http/cookie_test.go
@@ -0,0 +1,652 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+)
+
+var writeSetCookiesTests = []struct {
+ Cookie *Cookie
+ Raw string
+}{
+ {
+ &Cookie{Name: "cookie-1", Value: "v$1"},
+ "cookie-1=v$1",
+ },
+ {
+ &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600},
+ "cookie-2=two; Max-Age=3600",
+ },
+ {
+ &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"},
+ "cookie-3=three; Domain=example.com",
+ },
+ {
+ &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"},
+ "cookie-4=four; Path=/restricted/",
+ },
+ {
+ &Cookie{Name: "cookie-5", Value: "five", Domain: "wrong;bad.abc"},
+ "cookie-5=five",
+ },
+ {
+ &Cookie{Name: "cookie-6", Value: "six", Domain: "bad-.abc"},
+ "cookie-6=six",
+ },
+ {
+ &Cookie{Name: "cookie-7", Value: "seven", Domain: "127.0.0.1"},
+ "cookie-7=seven; Domain=127.0.0.1",
+ },
+ {
+ &Cookie{Name: "cookie-8", Value: "eight", Domain: "::1"},
+ "cookie-8=eight",
+ },
+ {
+ &Cookie{Name: "cookie-9", Value: "expiring", Expires: time.Unix(1257894000, 0)},
+ "cookie-9=expiring; Expires=Tue, 10 Nov 2009 23:00:00 GMT",
+ },
+ // According to RFC 6265 Section 5.1.1.5, the year cannot be less than 1601
+ {
+ &Cookie{Name: "cookie-10", Value: "expiring-1601", Expires: time.Date(1601, 1, 1, 1, 1, 1, 1, time.UTC)},
+ "cookie-10=expiring-1601; Expires=Mon, 01 Jan 1601 01:01:01 GMT",
+ },
+ {
+ &Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)},
+ "cookie-11=invalid-expiry",
+ },
+ {
+ &Cookie{Name: "cookie-12", Value: "samesite-default", SameSite: SameSiteDefaultMode},
+ "cookie-12=samesite-default",
+ },
+ {
+ &Cookie{Name: "cookie-13", Value: "samesite-lax", SameSite: SameSiteLaxMode},
+ "cookie-13=samesite-lax; SameSite=Lax",
+ },
+ {
+ &Cookie{Name: "cookie-14", Value: "samesite-strict", SameSite: SameSiteStrictMode},
+ "cookie-14=samesite-strict; SameSite=Strict",
+ },
+ {
+ &Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode},
+ "cookie-15=samesite-none; SameSite=None",
+ },
+ // The "special" cookies have values containing commas or spaces which
+ // are disallowed by RFC 6265 but are common in the wild.
+ {
+ &Cookie{Name: "special-1", Value: "a z"},
+ `special-1="a z"`,
+ },
+ {
+ &Cookie{Name: "special-2", Value: " z"},
+ `special-2=" z"`,
+ },
+ {
+ &Cookie{Name: "special-3", Value: "a "},
+ `special-3="a "`,
+ },
+ {
+ &Cookie{Name: "special-4", Value: " "},
+ `special-4=" "`,
+ },
+ {
+ &Cookie{Name: "special-5", Value: "a,z"},
+ `special-5="a,z"`,
+ },
+ {
+ &Cookie{Name: "special-6", Value: ",z"},
+ `special-6=",z"`,
+ },
+ {
+ &Cookie{Name: "special-7", Value: "a,"},
+ `special-7="a,"`,
+ },
+ {
+ &Cookie{Name: "special-8", Value: ","},
+ `special-8=","`,
+ },
+ {
+ &Cookie{Name: "empty-value", Value: ""},
+ `empty-value=`,
+ },
+ {
+ nil,
+ ``,
+ },
+ {
+ &Cookie{Name: ""},
+ ``,
+ },
+ {
+ &Cookie{Name: "\t"},
+ ``,
+ },
+ {
+ &Cookie{Name: "\r"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\nb", Value: "v"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\nb", Value: "v"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\rb", Value: "v"},
+ ``,
+ },
+}
+
+func TestWriteSetCookies(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf strings.Builder
+ log.SetOutput(&logbuf)
+
+ for i, tt := range writeSetCookiesTests {
+ if g, e := tt.Cookie.String(), tt.Raw; g != e {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, e, g)
+ continue
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping domain attribute"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
+
+type headerOnlyResponseWriter Header
+
+func (ho headerOnlyResponseWriter) Header() Header {
+ return Header(ho)
+}
+
+func (ho headerOnlyResponseWriter) Write([]byte) (int, error) {
+ panic("NOIMPL")
+}
+
+func (ho headerOnlyResponseWriter) WriteHeader(int) {
+ panic("NOIMPL")
+}
+
+func TestSetCookie(t *testing.T) {
+ m := make(Header)
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-1", Value: "one", Path: "/restricted/"})
+ SetCookie(headerOnlyResponseWriter(m), &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600})
+ if l := len(m["Set-Cookie"]); l != 2 {
+ t.Fatalf("expected %d cookies, got %d", 2, l)
+ }
+ if g, e := m["Set-Cookie"][0], "cookie-1=one; Path=/restricted/"; g != e {
+ t.Errorf("cookie #1: want %q, got %q", e, g)
+ }
+ if g, e := m["Set-Cookie"][1], "cookie-2=two; Max-Age=3600"; g != e {
+ t.Errorf("cookie #2: want %q, got %q", e, g)
+ }
+}
+
+var addCookieTests = []struct {
+ Cookies []*Cookie
+ Raw string
+}{
+ {
+ []*Cookie{},
+ "",
+ },
+ {
+ []*Cookie{{Name: "cookie-1", Value: "v$1"}},
+ "cookie-1=v$1",
+ },
+ {
+ []*Cookie{
+ {Name: "cookie-1", Value: "v$1"},
+ {Name: "cookie-2", Value: "v$2"},
+ {Name: "cookie-3", Value: "v$3"},
+ },
+ "cookie-1=v$1; cookie-2=v$2; cookie-3=v$3",
+ },
+}
+
+func TestAddCookie(t *testing.T) {
+ for i, tt := range addCookieTests {
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ for _, c := range tt.Cookies {
+ req.AddCookie(c)
+ }
+ if g := req.Header.Get("Cookie"); g != tt.Raw {
+ t.Errorf("Test %d:\nwant: %s\n got: %s\n", i, tt.Raw, g)
+ continue
+ }
+ }
+}
+
+var readSetCookiesTests = []struct {
+ Header Header
+ Cookies []*Cookie
+}{
+ {
+ Header{"Set-Cookie": {"Cookie-1=v$1"}},
+ []*Cookie{{Name: "Cookie-1", Value: "v$1", Raw: "Cookie-1=v$1"}},
+ },
+ {
+ Header{"Set-Cookie": {"NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly"}},
+ []*Cookie{{
+ Name: "NID",
+ Value: "99=YsDT5i3E-CXax-",
+ Path: "/",
+ Domain: ".google.ch",
+ HttpOnly: true,
+ Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT",
+ Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}},
+ []*Cookie{{
+ Name: ".ASPXAUTH",
+ Value: "7E3AA",
+ Path: "/",
+ Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC),
+ RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT",
+ HttpOnly: true,
+ Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly"}},
+ []*Cookie{{
+ Name: "ASP.NET_SessionId",
+ Value: "foo",
+ Path: "/",
+ HttpOnly: true,
+ Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitedefault=foo; SameSite"}},
+ []*Cookie{{
+ Name: "samesitedefault",
+ Value: "foo",
+ SameSite: SameSiteDefaultMode,
+ Raw: "samesitedefault=foo; SameSite",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesiteinvalidisdefault=foo; SameSite=invalid"}},
+ []*Cookie{{
+ Name: "samesiteinvalidisdefault",
+ Value: "foo",
+ SameSite: SameSiteDefaultMode,
+ Raw: "samesiteinvalidisdefault=foo; SameSite=invalid",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitelax=foo; SameSite=Lax"}},
+ []*Cookie{{
+ Name: "samesitelax",
+ Value: "foo",
+ SameSite: SameSiteLaxMode,
+ Raw: "samesitelax=foo; SameSite=Lax",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitestrict=foo; SameSite=Strict"}},
+ []*Cookie{{
+ Name: "samesitestrict",
+ Value: "foo",
+ SameSite: SameSiteStrictMode,
+ Raw: "samesitestrict=foo; SameSite=Strict",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitenone=foo; SameSite=None"}},
+ []*Cookie{{
+ Name: "samesitenone",
+ Value: "foo",
+ SameSite: SameSiteNoneMode,
+ Raw: "samesitenone=foo; SameSite=None",
+ }},
+ },
+ // Make sure we can properly read back the Set-Cookie headers we create
+ // for values containing spaces or commas:
+ {
+ Header{"Set-Cookie": {`special-1=a z`}},
+ []*Cookie{{Name: "special-1", Value: "a z", Raw: `special-1=a z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-2=" z"`}},
+ []*Cookie{{Name: "special-2", Value: " z", Raw: `special-2=" z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-3="a "`}},
+ []*Cookie{{Name: "special-3", Value: "a ", Raw: `special-3="a "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-4=" "`}},
+ []*Cookie{{Name: "special-4", Value: " ", Raw: `special-4=" "`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-5=a,z`}},
+ []*Cookie{{Name: "special-5", Value: "a,z", Raw: `special-5=a,z`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-6=",z"`}},
+ []*Cookie{{Name: "special-6", Value: ",z", Raw: `special-6=",z"`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-7=a,`}},
+ []*Cookie{{Name: "special-7", Value: "a,", Raw: `special-7=a,`}},
+ },
+ {
+ Header{"Set-Cookie": {`special-8=","`}},
+ []*Cookie{{Name: "special-8", Value: ",", Raw: `special-8=","`}},
+ },
+ // Make sure we can properly read back the Set-Cookie headers
+ // for names containing spaces:
+ {
+ Header{"Set-Cookie": {`special-9 =","`}},
+ []*Cookie{{Name: "special-9", Value: ",", Raw: `special-9 =","`}},
+ },
+
+ // TODO(bradfitz): users have reported seeing this in the
+ // wild, but do browsers handle it? RFC 6265 just says "don't
+ // do that" (section 3) and then never mentions header folding
+ // again.
+ // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}},
+}
+
+func toJSON(v any) string {
+ b, err := json.Marshal(v)
+ if err != nil {
+ return fmt.Sprintf("%#v", v)
+ }
+ return string(b)
+}
+
+func TestReadSetCookies(t *testing.T) {
+ for i, tt := range readSetCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readSetCookies doesn't mutate its input
+ c := readSetCookies(tt.Header)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readSetCookies: have\n%s\nwant\n%s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
+
+var readCookiesTests = []struct {
+ Header Header
+ Filter string
+ Cookies []*Cookie
+}{
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1", "c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {"Cookie-1=v$1; c2=v2"}},
+ "c2",
+ []*Cookie{
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {`Cookie-1="v$1"; c2="v2"`}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {`Cookie-1="v$1"; c2=v2;`}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {``}},
+ "",
+ []*Cookie{},
+ },
+}
+
+func TestReadCookies(t *testing.T) {
+ for i, tt := range readCookiesTests {
+ for n := 0; n < 2; n++ { // to verify readCookies doesn't mutate its input
+ c := readCookies(tt.Header, tt.Filter)
+ if !reflect.DeepEqual(c, tt.Cookies) {
+ t.Errorf("#%d readCookies:\nhave: %s\nwant: %s\n", i, toJSON(c), toJSON(tt.Cookies))
+ continue
+ }
+ }
+ }
+}
+
+func TestSetCookieDoubleQuotes(t *testing.T) {
+ res := &Response{Header: Header{}}
+ res.Header.Add("Set-Cookie", `quoted0=none; max-age=30`)
+ res.Header.Add("Set-Cookie", `quoted1="cookieValue"; max-age=31`)
+ res.Header.Add("Set-Cookie", `quoted2=cookieAV; max-age="32"`)
+ res.Header.Add("Set-Cookie", `quoted3="both"; max-age="33"`)
+ got := res.Cookies()
+ want := []*Cookie{
+ {Name: "quoted0", Value: "none", MaxAge: 30},
+ {Name: "quoted1", Value: "cookieValue", MaxAge: 31},
+ {Name: "quoted2", Value: "cookieAV"},
+ {Name: "quoted3", Value: "both"},
+ }
+ if len(got) != len(want) {
+ t.Fatalf("got %d cookies, want %d", len(got), len(want))
+ }
+ for i, w := range want {
+ g := got[i]
+ if g.Name != w.Name || g.Value != w.Value || g.MaxAge != w.MaxAge {
+ t.Errorf("cookie #%d:\ngot %v\nwant %v", i, g, w)
+ }
+ }
+}
+
+func TestCookieSanitizeValue(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf strings.Builder
+ log.SetOutput(&logbuf)
+
+ tests := []struct {
+ in, want string
+ }{
+ {"foo", "foo"},
+ {"foo;bar", "foobar"},
+ {"foo\\bar", "foobar"},
+ {"foo\"bar", "foobar"},
+ {"\x00\x7e\x7f\x80", "\x7e"},
+ {`"withquotes"`, "withquotes"},
+ {"a z", `"a z"`},
+ {" z", `" z"`},
+ {"a ", `"a "`},
+ {"a,z", `"a,z"`},
+ {",z", `",z"`},
+ {"a,", `"a,"`},
+ }
+ for _, tt := range tests {
+ if got := sanitizeCookieValue(tt.in); got != tt.want {
+ t.Errorf("sanitizeCookieValue(%q) = %q; want %q", tt.in, got, tt.want)
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
+
+func TestCookieSanitizePath(t *testing.T) {
+ defer log.SetOutput(os.Stderr)
+ var logbuf strings.Builder
+ log.SetOutput(&logbuf)
+
+ tests := []struct {
+ in, want string
+ }{
+ {"/path", "/path"},
+ {"/path with space/", "/path with space/"},
+ {"/just;no;semicolon\x00orstuff/", "/justnosemicolonorstuff/"},
+ }
+ for _, tt := range tests {
+ if got := sanitizeCookiePath(tt.in); got != tt.want {
+ t.Errorf("sanitizeCookiePath(%q) = %q; want %q", tt.in, got, tt.want)
+ }
+ }
+
+ if got, sub := logbuf.String(), "dropping invalid bytes"; !strings.Contains(got, sub) {
+ t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
+ }
+}
+
+func TestCookieValid(t *testing.T) {
+ tests := []struct {
+ cookie *Cookie
+ valid bool
+ }{
+ {nil, false},
+ {&Cookie{Name: ""}, false},
+ {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false},
+ {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false},
+ {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false},
+ {&Cookie{Name: "invalid-expiry", Value: "", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, false},
+ {&Cookie{Name: "valid-empty"}, true},
+ {&Cookie{Name: "valid-expires", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true},
+ {&Cookie{Name: "valid-max-age", Value: "foo", Path: "/bar", Domain: "example.com", MaxAge: 60}, true},
+ {&Cookie{Name: "valid-all-fields", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0), MaxAge: 0}, true},
+ }
+
+ for _, tt := range tests {
+ err := tt.cookie.Valid()
+ if err != nil && tt.valid {
+ t.Errorf("%#v.Valid() returned error %v; want nil", tt.cookie, err)
+ }
+ if err == nil && !tt.valid {
+ t.Errorf("%#v.Valid() returned nil; want error", tt.cookie)
+ }
+ }
+}
+
+func BenchmarkCookieString(b *testing.B) {
+ const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600`
+ c := &Cookie{
+ Name: "cookie-9",
+ Value: "i3e01nf61b6t23bvfmplnanol3",
+ Expires: time.Unix(1257894000, 0),
+ Path: "/restricted/",
+ Domain: ".example.com",
+ MaxAge: 3600,
+ }
+ var benchmarkCookieString string
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkCookieString = c.String()
+ }
+ if have, want := benchmarkCookieString, wantCookieString; have != want {
+ b.Fatalf("Have: %v Want: %v", have, want)
+ }
+}
+
+func BenchmarkReadSetCookies(b *testing.B) {
+ header := Header{
+ "Set-Cookie": {
+ "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ },
+ }
+ wantCookies := []*Cookie{
+ {
+ Name: "NID",
+ Value: "99=YsDT5i3E-CXax-",
+ Path: "/",
+ Domain: ".google.ch",
+ HttpOnly: true,
+ Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT",
+ Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ },
+ {
+ Name: ".ASPXAUTH",
+ Value: "7E3AA",
+ Path: "/",
+ Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC),
+ RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT",
+ HttpOnly: true,
+ Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ },
+ }
+ var c []*Cookie
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ c = readSetCookies(header)
+ }
+ if !reflect.DeepEqual(c, wantCookies) {
+ b.Fatalf("readSetCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies))
+ }
+}
+
+func BenchmarkReadCookies(b *testing.B) {
+ header := Header{
+ "Cookie": {
+ `de=; client_region=0; rpld1=0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|; rpld0=1:08|; backplane-channel=newspaper.com:1471; devicetype=0; osfam=0; rplmct=2; s_pers=%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B; s_sess=%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B`,
+ },
+ }
+ wantCookies := []*Cookie{
+ {Name: "de", Value: ""},
+ {Name: "client_region", Value: "0"},
+ {Name: "rpld1", Value: "0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|"},
+ {Name: "rpld0", Value: "1:08|"},
+ {Name: "backplane-channel", Value: "newspaper.com:1471"},
+ {Name: "devicetype", Value: "0"},
+ {Name: "osfam", Value: "0"},
+ {Name: "rplmct", Value: "2"},
+ {Name: "s_pers", Value: "%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B"},
+ {Name: "s_sess", Value: "%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B"},
+ }
+ var c []*Cookie
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ c = readCookies(header, "")
+ }
+ if !reflect.DeepEqual(c, wantCookies) {
+ b.Fatalf("readCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies))
+ }
+}
diff --git a/src/net/http/cookiejar/dummy_publicsuffix_test.go b/src/net/http/cookiejar/dummy_publicsuffix_test.go
new file mode 100644
index 0000000..9b31173
--- /dev/null
+++ b/src/net/http/cookiejar/dummy_publicsuffix_test.go
@@ -0,0 +1,21 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar_test
+
+import "net/http/cookiejar"
+
+type dummypsl struct {
+ List cookiejar.PublicSuffixList
+}
+
+func (dummypsl) PublicSuffix(domain string) string {
+ return domain
+}
+
+func (dummypsl) String() string {
+ return "dummy"
+}
+
+var publicsuffix = dummypsl{}
diff --git a/src/net/http/cookiejar/example_test.go b/src/net/http/cookiejar/example_test.go
new file mode 100644
index 0000000..91728ca
--- /dev/null
+++ b/src/net/http/cookiejar/example_test.go
@@ -0,0 +1,65 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar_test
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/url"
+)
+
+func ExampleNew() {
+ // Start a server to give us cookies.
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if cookie, err := r.Cookie("Flavor"); err != nil {
+ http.SetCookie(w, &http.Cookie{Name: "Flavor", Value: "Chocolate Chip"})
+ } else {
+ cookie.Value = "Oatmeal Raisin"
+ http.SetCookie(w, cookie)
+ }
+ }))
+ defer ts.Close()
+
+ u, err := url.Parse(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // All users of cookiejar should import "golang.org/x/net/publicsuffix"
+ jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ client := &http.Client{
+ Jar: jar,
+ }
+
+ if _, err = client.Get(u.String()); err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println("After 1st request:")
+ for _, cookie := range jar.Cookies(u) {
+ fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value)
+ }
+
+ if _, err = client.Get(u.String()); err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println("After 2nd request:")
+ for _, cookie := range jar.Cookies(u) {
+ fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value)
+ }
+ // Output:
+ // After 1st request:
+ // Flavor: Chocolate Chip
+ // After 2nd request:
+ // Flavor: Oatmeal Raisin
+}
diff --git a/src/net/http/cookiejar/jar.go b/src/net/http/cookiejar/jar.go
new file mode 100644
index 0000000..4b16266
--- /dev/null
+++ b/src/net/http/cookiejar/jar.go
@@ -0,0 +1,547 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar.
+package cookiejar
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/internal/ascii"
+ "net/url"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+// PublicSuffixList provides the public suffix of a domain. For example:
+// - the public suffix of "example.com" is "com",
+// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
+// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
+//
+// Implementations of PublicSuffixList must be safe for concurrent use by
+// multiple goroutines.
+//
+// An implementation that always returns "" is valid and may be useful for
+// testing but it is not secure: it means that the HTTP server for foo.com can
+// set a cookie for bar.com.
+//
+// A public suffix list implementation is in the package
+// golang.org/x/net/publicsuffix.
+type PublicSuffixList interface {
+ // PublicSuffix returns the public suffix of domain.
+ //
+ // TODO: specify which of the caller and callee is responsible for IP
+ // addresses, for leading and trailing dots, for case sensitivity, and
+ // for IDN/Punycode.
+ PublicSuffix(domain string) string
+
+ // String returns a description of the source of this public suffix
+ // list. The description will typically contain something like a time
+ // stamp or version number.
+ String() string
+}
+
+// Options are the options for creating a new Jar.
+type Options struct {
+ // PublicSuffixList is the public suffix list that determines whether
+ // an HTTP server can set a cookie for a domain.
+ //
+ // A nil value is valid and may be useful for testing but it is not
+ // secure: it means that the HTTP server for foo.co.uk can set a cookie
+ // for bar.co.uk.
+ PublicSuffixList PublicSuffixList
+}
+
+// Jar implements the http.CookieJar interface from the net/http package.
+type Jar struct {
+ psList PublicSuffixList
+
+ // mu locks the remaining fields.
+ mu sync.Mutex
+
+ // entries is a set of entries, keyed by their eTLD+1 and subkeyed by
+ // their name/domain/path.
+ entries map[string]map[string]entry
+
+ // nextSeqNum is the next sequence number assigned to a new cookie
+ // created SetCookies.
+ nextSeqNum uint64
+}
+
+// New returns a new cookie jar. A nil *Options is equivalent to a zero
+// Options.
+func New(o *Options) (*Jar, error) {
+ jar := &Jar{
+ entries: make(map[string]map[string]entry),
+ }
+ if o != nil {
+ jar.psList = o.PublicSuffixList
+ }
+ return jar, nil
+}
+
+// entry is the internal representation of a cookie.
+//
+// This struct type is not used outside of this package per se, but the exported
+// fields are those of RFC 6265.
+type entry struct {
+ Name string
+ Value string
+ Domain string
+ Path string
+ SameSite string
+ Secure bool
+ HttpOnly bool
+ Persistent bool
+ HostOnly bool
+ Expires time.Time
+ Creation time.Time
+ LastAccess time.Time
+
+ // seqNum is a sequence number so that Cookies returns cookies in a
+ // deterministic order, even for cookies that have equal Path length and
+ // equal Creation time. This simplifies testing.
+ seqNum uint64
+}
+
+// id returns the domain;path;name triple of e as an id.
+func (e *entry) id() string {
+ return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
+}
+
+// shouldSend determines whether e's cookie qualifies to be included in a
+// request to host/path. It is the caller's responsibility to check if the
+// cookie is expired.
+func (e *entry) shouldSend(https bool, host, path string) bool {
+ return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
+}
+
+// domainMatch checks whether e's Domain allows sending e back to host.
+// It differs from "domain-match" of RFC 6265 section 5.1.3 because we treat
+// a cookie with an IP address in the Domain always as a host cookie.
+func (e *entry) domainMatch(host string) bool {
+ if e.Domain == host {
+ return true
+ }
+ return !e.HostOnly && hasDotSuffix(host, e.Domain)
+}
+
+// pathMatch implements "path-match" according to RFC 6265 section 5.1.4.
+func (e *entry) pathMatch(requestPath string) bool {
+ if requestPath == e.Path {
+ return true
+ }
+ if strings.HasPrefix(requestPath, e.Path) {
+ if e.Path[len(e.Path)-1] == '/' {
+ return true // The "/any/" matches "/any/path" case.
+ } else if requestPath[len(e.Path)] == '/' {
+ return true // The "/any" matches "/any/path" case.
+ }
+ }
+ return false
+}
+
+// hasDotSuffix reports whether s ends in "."+suffix.
+func hasDotSuffix(s, suffix string) bool {
+ return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
+}
+
+// Cookies implements the Cookies method of the http.CookieJar interface.
+//
+// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
+ return j.cookies(u, time.Now())
+}
+
+// cookies is like Cookies but takes the current time as a parameter.
+func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return cookies
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return cookies
+ }
+ key := jarKey(host, j.psList)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+ if submap == nil {
+ return cookies
+ }
+
+ https := u.Scheme == "https"
+ path := u.Path
+ if path == "" {
+ path = "/"
+ }
+
+ modified := false
+ var selected []entry
+ for id, e := range submap {
+ if e.Persistent && !e.Expires.After(now) {
+ delete(submap, id)
+ modified = true
+ continue
+ }
+ if !e.shouldSend(https, host, path) {
+ continue
+ }
+ e.LastAccess = now
+ submap[id] = e
+ selected = append(selected, e)
+ modified = true
+ }
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+
+ // sort according to RFC 6265 section 5.4 point 2: by longest
+ // path and then by earliest creation time.
+ sort.Slice(selected, func(i, j int) bool {
+ s := selected
+ if len(s[i].Path) != len(s[j].Path) {
+ return len(s[i].Path) > len(s[j].Path)
+ }
+ if ret := s[i].Creation.Compare(s[j].Creation); ret != 0 {
+ return ret < 0
+ }
+ return s[i].seqNum < s[j].seqNum
+ })
+ for _, e := range selected {
+ cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
+ }
+
+ return cookies
+}
+
+// SetCookies implements the SetCookies method of the http.CookieJar interface.
+//
+// It does nothing if the URL's scheme is not HTTP or HTTPS.
+func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
+ j.setCookies(u, cookies, time.Now())
+}
+
+// setCookies is like SetCookies but takes the current time as parameter.
+func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
+ if len(cookies) == 0 {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ host, err := canonicalHost(u.Host)
+ if err != nil {
+ return
+ }
+ key := jarKey(host, j.psList)
+ defPath := defaultPath(u.Path)
+
+ j.mu.Lock()
+ defer j.mu.Unlock()
+
+ submap := j.entries[key]
+
+ modified := false
+ for _, cookie := range cookies {
+ e, remove, err := j.newEntry(cookie, now, defPath, host)
+ if err != nil {
+ continue
+ }
+ id := e.id()
+ if remove {
+ if submap != nil {
+ if _, ok := submap[id]; ok {
+ delete(submap, id)
+ modified = true
+ }
+ }
+ continue
+ }
+ if submap == nil {
+ submap = make(map[string]entry)
+ }
+
+ if old, ok := submap[id]; ok {
+ e.Creation = old.Creation
+ e.seqNum = old.seqNum
+ } else {
+ e.Creation = now
+ e.seqNum = j.nextSeqNum
+ j.nextSeqNum++
+ }
+ e.LastAccess = now
+ submap[id] = e
+ modified = true
+ }
+
+ if modified {
+ if len(submap) == 0 {
+ delete(j.entries, key)
+ } else {
+ j.entries[key] = submap
+ }
+ }
+}
+
+// canonicalHost strips port from host if present and returns the canonicalized
+// host name.
+func canonicalHost(host string) (string, error) {
+ var err error
+ if hasPort(host) {
+ host, _, err = net.SplitHostPort(host)
+ if err != nil {
+ return "", err
+ }
+ }
+ // Strip trailing dot from fully qualified domain names.
+ host = strings.TrimSuffix(host, ".")
+ encoded, err := toASCII(host)
+ if err != nil {
+ return "", err
+ }
+ // We know this is ascii, no need to check.
+ lower, _ := ascii.ToLower(encoded)
+ return lower, nil
+}
+
+// hasPort reports whether host contains a port number. host may be a host
+// name, an IPv4 or an IPv6 address.
+func hasPort(host string) bool {
+ colons := strings.Count(host, ":")
+ if colons == 0 {
+ return false
+ }
+ if colons == 1 {
+ return true
+ }
+ return host[0] == '[' && strings.Contains(host, "]:")
+}
+
+// jarKey returns the key to use for a jar.
+func jarKey(host string, psl PublicSuffixList) string {
+ if isIP(host) {
+ return host
+ }
+
+ var i int
+ if psl == nil {
+ i = strings.LastIndex(host, ".")
+ if i <= 0 {
+ return host
+ }
+ } else {
+ suffix := psl.PublicSuffix(host)
+ if suffix == host {
+ return host
+ }
+ i = len(host) - len(suffix)
+ if i <= 0 || host[i-1] != '.' {
+ // The provided public suffix list psl is broken.
+ // Storing cookies under host is a safe stopgap.
+ return host
+ }
+ // Only len(suffix) is used to determine the jar key from
+ // here on, so it is okay if psl.PublicSuffix("www.buggy.psl")
+ // returns "com" as the jar key is generated from host.
+ }
+ prevDot := strings.LastIndex(host[:i-1], ".")
+ return host[prevDot+1:]
+}
+
+// isIP reports whether host is an IP address.
+func isIP(host string) bool {
+ if strings.ContainsAny(host, ":%") {
+ // Probable IPv6 address.
+ // Hostnames can't contain : or %, so this is definitely not a valid host.
+ // Treating it as an IP is the more conservative option, and avoids the risk
+ // of interpreting ::1%.www.example.com as a subdomain of www.example.com.
+ return true
+ }
+ return net.ParseIP(host) != nil
+}
+
+// defaultPath returns the directory part of a URL's path according to
+// RFC 6265 section 5.1.4.
+func defaultPath(path string) string {
+ if len(path) == 0 || path[0] != '/' {
+ return "/" // Path is empty or malformed.
+ }
+
+ i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1.
+ if i == 0 {
+ return "/" // Path has the form "/abc".
+ }
+ return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/".
+}
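
Aside (not part of this file): the default-path rule above is easy to misread, so here is a hedged restatement as a standalone snippet; the function is a local copy for illustration, not the package's own:

package main

import (
	"fmt"
	"strings"
)

// defaultPathSketch restates the rule above (RFC 6265 section 5.1.4):
// keep everything up to, but not including, the last "/" of the path.
func defaultPathSketch(path string) string {
	if len(path) == 0 || path[0] != '/' {
		return "/" // empty or malformed
	}
	i := strings.LastIndex(path, "/")
	if i == 0 {
		return "/" // "/abc" style paths default to "/"
	}
	return path[:i]
}

func main() {
	for _, p := range []string{"", "abc", "/", "/abc", "/abc/xyz", "/abc/xyz/"} {
		fmt.Printf("%q -> %q\n", p, defaultPathSketch(p))
	}
	// "" -> "/", "abc" -> "/", "/" -> "/", "/abc" -> "/",
	// "/abc/xyz" -> "/abc", "/abc/xyz/" -> "/abc/xyz"
}
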
+
+// newEntry creates an entry from an http.Cookie c. now is the current time and
+// is compared to c.Expires to determine deletion of c. defPath and host are the
+// default-path and the canonical host name of the URL c was received from.
+//
+// remove records whether the jar should delete this cookie, as it has already
+// expired with respect to now. In this case, e may be incomplete, but it will
+// be valid to call e.id (which depends on e's Name, Domain and Path).
+//
+// A malformed c.Domain will result in an error.
+func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
+ e.Name = c.Name
+
+ if c.Path == "" || c.Path[0] != '/' {
+ e.Path = defPath
+ } else {
+ e.Path = c.Path
+ }
+
+ e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
+ if err != nil {
+ return e, false, err
+ }
+
+ // MaxAge takes precedence over Expires.
+ if c.MaxAge < 0 {
+ return e, true, nil
+ } else if c.MaxAge > 0 {
+ e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+ e.Persistent = true
+ } else {
+ if c.Expires.IsZero() {
+ e.Expires = endOfTime
+ e.Persistent = false
+ } else {
+ if !c.Expires.After(now) {
+ return e, true, nil
+ }
+ e.Expires = c.Expires
+ e.Persistent = true
+ }
+ }
+
+ e.Value = c.Value
+ e.Secure = c.Secure
+ e.HttpOnly = c.HttpOnly
+
+ switch c.SameSite {
+ case http.SameSiteDefaultMode:
+ e.SameSite = "SameSite"
+ case http.SameSiteStrictMode:
+ e.SameSite = "SameSite=Strict"
+ case http.SameSiteLaxMode:
+ e.SameSite = "SameSite=Lax"
+ }
+
+ return e, false, nil
+}
+
+var (
+ errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
+ errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
+ errNoHostname = errors.New("cookiejar: no host name available (IP only)")
+)
+
+// endOfTime is the time when session (non-persistent) cookies expire.
+// This instant is representable in most date/time formats (not just
+// Go's time.Time) and should be far enough in the future.
+var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
+
+// domainAndType determines the cookie's domain and hostOnly attribute.
+func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
+ if domain == "" {
+ // No domain attribute in the SetCookie header indicates a
+ // host cookie.
+ return host, true, nil
+ }
+
+ if isIP(host) {
+ // RFC 6265 is not super clear here; a sensible interpretation
+ // is that cookies with an IP address in the domain-attribute
+ // are allowed.
+
+ // RFC 6265 section 5.2.3 mandates to strip an optional leading
+ // dot in the domain-attribute before processing the cookie.
+ //
+ // Most browsers don't do that for IP addresses; only curl
+ // (version 7.54) and IE (version 11) do not reject a
+ // Set-Cookie: a=1; domain=.127.0.0.1
+ // This leading dot is optional and serves only as a hint for
+ // humans to indicate that a cookie with "domain=.bbc.co.uk"
+ // would be sent to every subdomain of bbc.co.uk.
+ // It just doesn't make sense on IP addresses.
+ // The other processing and validation steps in RFC 6265 just
+ // collapse to:
+ if host != domain {
+ return "", false, errIllegalDomain
+ }
+
+ // According to RFC 6265 such cookies should be treated as
+ // domain cookies.
+ // As there are no subdomains of an IP address the treatment
+ // according to RFC 6265 would be exactly the same as that of
+ // a host-only cookie. Contemporary browsers (and curl) do
+ // allow such cookies but treat them as host-only cookies.
+ // So do we as it just doesn't make sense to label them as
+ // domain cookies when there is no domain; the whole notion of
+ // domain cookies requires a domain name to be well defined.
+ return host, true, nil
+ }
+
+ // From here on: If the cookie is valid, it is a domain cookie (with
+ // the one exception of a public suffix below).
+ // See RFC 6265 section 5.2.3.
+ if domain[0] == '.' {
+ domain = domain[1:]
+ }
+
+ if len(domain) == 0 || domain[0] == '.' {
+ // Received either "Domain=." or "Domain=..some.thing",
+ // both are illegal.
+ return "", false, errMalformedDomain
+ }
+
+ domain, isASCII := ascii.ToLower(domain)
+ if !isASCII {
+ // Received non-ASCII domain, e.g. "perché.com" instead of "xn--perch-fsa.com"
+ return "", false, errMalformedDomain
+ }
+
+ if domain[len(domain)-1] == '.' {
+ // We received stuff like "Domain=www.example.com.".
+ // Browsers do handle such stuff (actually differently) but
+ // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in
+ // requiring a reject. 4.1.2.3 is not normative, but
+ // "Domain Matching" (5.1.3) and "Canonicalized Host Names"
+ // (5.1.2) are.
+ return "", false, errMalformedDomain
+ }
+
+ // See RFC 6265 section 5.3 #5.
+ if j.psList != nil {
+ if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
+ if host == domain {
+ // This is the one exception in which a cookie
+ // with a domain attribute is a host cookie.
+ return host, true, nil
+ }
+ return "", false, errIllegalDomain
+ }
+ }
+
+ // The domain must domain-match host: www.mycompany.com cannot
+ // set cookies for .ourcompetitors.com.
+ if host != domain && !hasDotSuffix(host, domain) {
+ return "", false, errIllegalDomain
+ }
+
+ return domain, false, nil
+}
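+
+// A few illustrative results (these cases also appear in the tests for this
+// change):
+//
+//	j.domainAndType("www.example.com", "")             -> "www.example.com", hostOnly=true
+//	j.domainAndType("www.example.com", ".example.com") -> "example.com", hostOnly=false
+//	j.domainAndType("www.example.com", "other.com")    -> errIllegalDomain
+//	j.domainAndType("127.0.0.1", "127.0.0.1")          -> "127.0.0.1", hostOnly=true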
diff --git a/src/net/http/cookiejar/jar_test.go b/src/net/http/cookiejar/jar_test.go
new file mode 100644
index 0000000..251f7c1
--- /dev/null
+++ b/src/net/http/cookiejar/jar_test.go
@@ -0,0 +1,1355 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+// tNow is the synthetic current time used as now during testing.
+var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC)
+
+// testPSL implements PublicSuffixList with just two rules: "co.uk"
+// and the default rule "*".
+// The implementation has two intentional bugs:
+//
+// PublicSuffix("www.buggy.psl") == "xy"
+// PublicSuffix("www2.buggy.psl") == "com"
+type testPSL struct{}
+
+func (testPSL) String() string {
+ return "testPSL"
+}
+func (testPSL) PublicSuffix(d string) string {
+ if d == "co.uk" || strings.HasSuffix(d, ".co.uk") {
+ return "co.uk"
+ }
+ if d == "www.buggy.psl" {
+ return "xy"
+ }
+ if d == "www2.buggy.psl" {
+ return "com"
+ }
+ return d[strings.LastIndex(d, ".")+1:]
+}
+
+// newTestJar creates an empty Jar with testPSL as the public suffix list.
+func newTestJar() *Jar {
+ jar, err := New(&Options{PublicSuffixList: testPSL{}})
+ if err != nil {
+ panic(err)
+ }
+ return jar
+}
+
+var hasDotSuffixTests = [...]struct {
+ s, suffix string
+}{
+ {"", ""},
+ {"", "."},
+ {"", "x"},
+ {".", ""},
+ {".", "."},
+ {".", ".."},
+ {".", "x"},
+ {".", "x."},
+ {".", ".x"},
+ {".", ".x."},
+ {"x", ""},
+ {"x", "."},
+ {"x", ".."},
+ {"x", "x"},
+ {"x", "x."},
+ {"x", ".x"},
+ {"x", ".x."},
+ {".x", ""},
+ {".x", "."},
+ {".x", ".."},
+ {".x", "x"},
+ {".x", "x."},
+ {".x", ".x"},
+ {".x", ".x."},
+ {"x.", ""},
+ {"x.", "."},
+ {"x.", ".."},
+ {"x.", "x"},
+ {"x.", "x."},
+ {"x.", ".x"},
+ {"x.", ".x."},
+ {"com", ""},
+ {"com", "m"},
+ {"com", "om"},
+ {"com", "com"},
+ {"com", ".com"},
+ {"com", "x.com"},
+ {"com", "xcom"},
+ {"com", "xorg"},
+ {"com", "org"},
+ {"com", "rg"},
+ {"foo.com", ""},
+ {"foo.com", "m"},
+ {"foo.com", "om"},
+ {"foo.com", "com"},
+ {"foo.com", ".com"},
+ {"foo.com", "o.com"},
+ {"foo.com", "oo.com"},
+ {"foo.com", "foo.com"},
+ {"foo.com", ".foo.com"},
+ {"foo.com", "x.foo.com"},
+ {"foo.com", "xfoo.com"},
+ {"foo.com", "xfoo.org"},
+ {"foo.com", "foo.org"},
+ {"foo.com", "oo.org"},
+ {"foo.com", "o.org"},
+ {"foo.com", ".org"},
+ {"foo.com", "org"},
+ {"foo.com", "rg"},
+}
+
+func TestHasDotSuffix(t *testing.T) {
+ for _, tc := range hasDotSuffixTests {
+ got := hasDotSuffix(tc.s, tc.suffix)
+ want := strings.HasSuffix(tc.s, "."+tc.suffix)
+ if got != want {
+ t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want)
+ }
+ }
+}
+
+var canonicalHostTests = map[string]string{
+ "www.example.com": "www.example.com",
+ "WWW.EXAMPLE.COM": "www.example.com",
+ "wWw.eXAmple.CoM": "www.example.com",
+ "www.example.com:80": "www.example.com",
+ "192.168.0.10": "192.168.0.10",
+ "192.168.0.5:8080": "192.168.0.5",
+ "2001:4860:0:2001::68": "2001:4860:0:2001::68",
+ "[2001:4860:0:::68]:8080": "2001:4860:0:::68",
+ "www.bücher.de": "www.xn--bcher-kva.de",
+ "www.example.com.": "www.example.com",
+ // TODO: Fix canonicalHost so that all of the following malformed
+ // domain names trigger an error. (This list is not exhaustive, e.g.
+ // malformed internationalized domain names are missing.)
+ ".": "",
+ "..": ".",
+ "...": "..",
+ ".net": ".net",
+ ".net.": ".net",
+ "a..": "a.",
+ "b.a..": "b.a.",
+ "weird.stuff...": "weird.stuff..",
+ "[bad.unmatched.bracket:": "error",
+}
+
+func TestCanonicalHost(t *testing.T) {
+ for h, want := range canonicalHostTests {
+ got, err := canonicalHost(h)
+ if want == "error" {
+ if err == nil {
+ t.Errorf("%q: got %q and nil error, want non-nil", h, got)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%q: %v", h, err)
+ continue
+ }
+ if got != want {
+ t.Errorf("%q: got %q, want %q", h, got, want)
+ continue
+ }
+ }
+}
+
+var hasPortTests = map[string]bool{
+ "www.example.com": false,
+ "www.example.com:80": true,
+ "127.0.0.1": false,
+ "127.0.0.1:8080": true,
+ "2001:4860:0:2001::68": false,
+ "[2001::0:::68]:80": true,
+}
+
+func TestHasPort(t *testing.T) {
+ for host, want := range hasPortTests {
+ if got := hasPort(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var jarKeyTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "bbc.co.uk",
+ "www.bbc.co.uk": "bbc.co.uk",
+ "bbc.co.uk": "bbc.co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+ "www.buggy.psl": "www.buggy.psl",
+ "www2.buggy.psl": "buggy.psl",
+ // The following are actual outputs of canonicalHost for
+ // malformed inputs to canonicalHost (see above).
+ "": "",
+ ".": ".",
+ "..": ".",
+ ".net": ".net",
+ "a.": "a.",
+ "b.a.": "a.",
+ "weird.stuff..": ".",
+}
+
+func TestJarKey(t *testing.T) {
+ for host, want := range jarKeyTests {
+ if got := jarKey(host, testPSL{}); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var jarKeyNilPSLTests = map[string]string{
+ "foo.www.example.com": "example.com",
+ "www.example.com": "example.com",
+ "example.com": "example.com",
+ "com": "com",
+ "foo.www.bbc.co.uk": "co.uk",
+ "www.bbc.co.uk": "co.uk",
+ "bbc.co.uk": "co.uk",
+ "co.uk": "co.uk",
+ "uk": "uk",
+ "192.168.0.5": "192.168.0.5",
+ // The following are actual outputs of canonicalHost for
+ // malformed inputs to canonicalHost.
+ "": "",
+ ".": ".",
+ "..": "..",
+ ".net": ".net",
+ "a.": "a.",
+ "b.a.": "a.",
+ "weird.stuff..": "stuff..",
+}
+
+func TestJarKeyNilPSL(t *testing.T) {
+ for host, want := range jarKeyNilPSLTests {
+ if got := jarKey(host, nil); got != want {
+ t.Errorf("%q: got %q, want %q", host, got, want)
+ }
+ }
+}
+
+var isIPTests = map[string]bool{
+ "127.0.0.1": true,
+ "1.2.3.4": true,
+ "2001:4860:0:2001::68": true,
+ "::1%zone": true,
+ "example.com": false,
+ "1.1.1.300": false,
+ "www.foo.bar.net": false,
+ "123.foo.bar.net": false,
+}
+
+func TestIsIP(t *testing.T) {
+ for host, want := range isIPTests {
+ if got := isIP(host); got != want {
+ t.Errorf("%q: got %t, want %t", host, got, want)
+ }
+ }
+}
+
+var defaultPathTests = map[string]string{
+ "/": "/",
+ "/abc": "/",
+ "/abc/": "/abc",
+ "/abc/xyz": "/abc",
+ "/abc/xyz/": "/abc/xyz",
+ "/a/b/c.html": "/a/b",
+ "": "/",
+ "strange": "/",
+ "//": "/",
+ "/a//b": "/a/",
+ "/a/./b": "/a/.",
+ "/a/../b": "/a/..",
+}
+
+func TestDefaultPath(t *testing.T) {
+ for path, want := range defaultPathTests {
+ if got := defaultPath(path); got != want {
+ t.Errorf("%q: got %q, want %q", path, got, want)
+ }
+ }
+}
+
+var domainAndTypeTests = [...]struct {
+ host string // host Set-Cookie header was received from
+ domain string // domain attribute in Set-Cookie header
+ wantDomain string // expected domain of cookie
+ wantHostOnly bool // expected host-cookie flag
+ wantErr error // expected error
+}{
+ {"www.example.com", "", "www.example.com", true, nil},
+ {"127.0.0.1", "", "127.0.0.1", true, nil},
+ {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil},
+ {"www.example.com", "example.com", "example.com", false, nil},
+ {"www.example.com", ".example.com", "example.com", false, nil},
+ {"www.example.com", "www.example.com", "www.example.com", false, nil},
+ {"www.example.com", ".www.example.com", "www.example.com", false, nil},
+ {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil},
+ {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil},
+ {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil},
+ {"127.0.0.1", "127.0.0.1", "127.0.0.1", true, nil},
+ {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", true, nil},
+ {"www.example.com", ".", "", false, errMalformedDomain},
+ {"www.example.com", "..", "", false, errMalformedDomain},
+ {"www.example.com", "other.com", "", false, errIllegalDomain},
+ {"www.example.com", "com", "", false, errIllegalDomain},
+ {"www.example.com", ".com", "", false, errIllegalDomain},
+ {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain},
+ {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain},
+ {"com", "", "com", true, nil},
+ {"com", "com", "com", true, nil},
+ {"com", ".com", "com", true, nil},
+ {"co.uk", "", "co.uk", true, nil},
+ {"co.uk", "co.uk", "co.uk", true, nil},
+ {"co.uk", ".co.uk", "co.uk", true, nil},
+}
+
+func TestDomainAndType(t *testing.T) {
+ jar := newTestJar()
+ for _, tc := range domainAndTypeTests {
+ domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain)
+ if err != tc.wantErr {
+ t.Errorf("%q/%q: got %q error, want %v",
+ tc.host, tc.domain, err, tc.wantErr)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+ if domain != tc.wantDomain || hostOnly != tc.wantHostOnly {
+ t.Errorf("%q/%q: got %q/%t want %q/%t",
+ tc.host, tc.domain, domain, hostOnly,
+ tc.wantDomain, tc.wantHostOnly)
+ }
+ }
+}
+
+// expiresIn creates an expires attribute delta seconds from tNow.
+func expiresIn(delta int) string {
+ t := tNow.Add(time.Duration(delta) * time.Second)
+ return "expires=" + t.Format(time.RFC1123)
+}
+
+// mustParseURL parses s to a URL and panics on error.
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ panic(fmt.Sprintf("Unable to parse URL %s.", s))
+ }
+ return u
+}
+
+// jarTest encapsulates the following actions on a jar:
+// 1. Perform SetCookies with fromURL and the cookies from setCookies.
+// (Done at time tNow + 0 ms.)
+// 2. Check that the entries in the jar match content.
+// (Done at time tNow + 1001 ms.)
+// 3. For each query in tests: Check that Cookies with toURL yields the
+// cookies in want.
+// (Query n done at tNow + (n+2)*1001 ms.)
+type jarTest struct {
+ description string // The description of what this test is supposed to test
+ fromURL string // The full URL of the request from which Set-Cookie headers were received
+ setCookies []string // All the cookies received from fromURL
+ content string // The whole (non-expired) content of the jar
+ queries []query // Queries to test the Jar.Cookies method
+}
+
+// query contains one test of the cookies returned from Jar.Cookies.
+type query struct {
+ toURL string // the URL in the Cookies call
+ want string // the expected list of cookies (order matters)
+}
+
+// run runs the jarTest.
+func (test jarTest) run(t *testing.T, jar *Jar) {
+ now := tNow
+
+ // Populate jar with cookies.
+ setCookies := make([]*http.Cookie, len(test.setCookies))
+ for i, cs := range test.setCookies {
+ cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies()
+ if len(cookies) != 1 {
+ panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies))
+ }
+ setCookies[i] = cookies[0]
+ }
+ jar.setCookies(mustParseURL(test.fromURL), setCookies, now)
+ now = now.Add(1001 * time.Millisecond)
+
+ // Serialize non-expired entries in the form "name1=val1 name2=val2".
+ var cs []string
+ for _, submap := range jar.entries {
+ for _, cookie := range submap {
+ if !cookie.Expires.After(now) {
+ continue
+ }
+ cs = append(cs, cookie.Name+"="+cookie.Value)
+ }
+ }
+ sort.Strings(cs)
+ got := strings.Join(cs, " ")
+
+ // Make sure jar content matches our expectations.
+ if got != test.content {
+ t.Errorf("Test %q Content\ngot %q\nwant %q",
+ test.description, got, test.content)
+ }
+
+ // Test different calls to Cookies.
+ for i, query := range test.queries {
+ now = now.Add(1001 * time.Millisecond)
+ var s []string
+ for _, c := range jar.cookies(mustParseURL(query.toURL), now) {
+ s = append(s, c.Name+"="+c.Value)
+ }
+ if got := strings.Join(s, " "); got != query.want {
+ t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want)
+ }
+ }
+}
+
+// basicsTests contains fundamental tests. Each jarTest has to be performed on
+// a fresh, empty Jar.
+var basicsTests = [...]jarTest{
+ {
+ "Retrieval of a plain host cookie.",
+ "http://www.host.test/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ {"ftp://www.host.test", ""},
+ {"ftp://www.host.test/", ""},
+ {"ftp://www.host.test/some/path", ""},
+ {"http://www.other.org", ""},
+ {"http://sibling.host.test", ""},
+ {"http://deep.www.host.test", ""},
+ },
+ },
+ {
+ "Secure cookies are not returned to http.",
+ "http://www.host.test/",
+ []string{"A=a; secure"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some/path", ""},
+ {"https://www.host.test", "A=a"},
+ {"https://www.host.test/", "A=a"},
+ {"https://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Explicit path.",
+ "http://www.host.test/",
+ []string{"A=a; path=/some/path"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #1: path is a directory.",
+ "http://www.host.test/some/path/",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #2: path is not a directory.",
+ "http://www.host.test/some/path/index.html",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", ""},
+ {"http://www.host.test/", ""},
+ {"http://www.host.test/some", ""},
+ {"http://www.host.test/some/", ""},
+ {"http://www.host.test/some/path", "A=a"},
+ {"http://www.host.test/some/paths", ""},
+ {"http://www.host.test/some/path/foo", "A=a"},
+ {"http://www.host.test/some/path/foo/", "A=a"},
+ },
+ },
+ {
+ "Implicit path #3: no path in URL at all.",
+ "http://www.host.test",
+ []string{"A=a"},
+ "A=a",
+ []query{
+ {"http://www.host.test", "A=a"},
+ {"http://www.host.test/", "A=a"},
+ {"http://www.host.test/some/path", "A=a"},
+ },
+ },
+ {
+ "Cookies are sorted by path length.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"},
+ {"http://www.host.test/foo/bar", "A=a D=d"},
+ },
+ },
+ {
+ "Creation time determines sorting on same length paths.",
+ "http://www.host.test/",
+ []string{
+ "A=a; path=/foo/bar",
+ "X=x; path=/foo/bar",
+ "Y=y; path=/foo/bar/baz/qux",
+ "B=b; path=/foo/bar/baz/qux",
+ "C=c; path=/foo/bar/baz",
+ "W=w; path=/foo/bar/baz",
+ "Z=z; path=/foo",
+ "D=d; path=/foo"},
+ "A=a B=b C=c D=d W=w X=x Y=y Z=z",
+ []query{
+ {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"},
+ {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"},
+ },
+ },
+ {
+ "Sorting of same-name cookies.",
+ "http://www.host.test/",
+ []string{
+ "A=1; path=/",
+ "A=2; path=/path",
+ "A=3; path=/quux",
+ "A=4; path=/path/foo",
+ "A=5; domain=.host.test; path=/path",
+ "A=6; domain=.host.test; path=/quux",
+ "A=7; domain=.host.test; path=/path/foo",
+ },
+ "A=1 A=2 A=3 A=4 A=5 A=6 A=7",
+ []query{
+ {"http://www.host.test/path", "A=2 A=5 A=1"},
+ {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"},
+ },
+ },
+ {
+ "Disallow domain cookie on public suffix.",
+ "http://www.bbc.co.uk",
+ []string{
+ "a=1",
+ "b=2; domain=co.uk",
+ },
+ "a=1",
+ []query{{"http://www.bbc.co.uk", "a=1"}},
+ },
+ {
+ "Host cookie on IP.",
+ "http://192.168.0.10",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://192.168.0.10", "a=1"}},
+ },
+ {
+ "Domain cookies on IP.",
+ "http://192.168.0.10",
+ []string{
+ "a=1; domain=192.168.0.10", // allowed
+ "b=2; domain=172.31.9.9", // rejected, can't set cookie for other IP
+ "c=3; domain=.192.168.0.10", // rejected like in most browsers
+ },
+ "a=1",
+ []query{
+ {"http://192.168.0.10", "a=1"},
+ {"http://172.31.9.9", ""},
+ {"http://www.fancy.192.168.0.10", ""},
+ },
+ },
+ {
+ "Port is ignored #1.",
+ "http://www.host.test/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ },
+ },
+ {
+ "Port is ignored #2.",
+ "http://www.host.test:8080/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://www.host.test:8080/", "a=1"},
+ {"http://www.host.test:1234/", "a=1"},
+ },
+ },
+ {
+ "IPv6 zone is not treated as a host.",
+ "https://example.com/",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"https://[::1%25.example.com]:80/", ""},
+ },
+ },
+}
+
+func TestBasics(t *testing.T) {
+ for _, test := range basicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// updateAndDeleteTests contains jarTests which must be performed on the same
+// Jar.
+var updateAndDeleteTests = [...]jarTest{
+ {
+ "Set initial cookies.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; secure",
+ "c=3; httponly",
+ "d=4; secure; httponly"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://www.host.test", "a=1 c=3"},
+ {"https://www.host.test", "a=1 b=2 c=3 d=4"},
+ },
+ },
+ {
+ "Update value via http.",
+ "http://www.host.test",
+ []string{
+ "a=w",
+ "b=x; secure",
+ "c=y; httponly",
+ "d=z; secure; httponly"},
+ "a=w b=x c=y d=z",
+ []query{
+ {"http://www.host.test", "a=w c=y"},
+ {"https://www.host.test", "a=w b=x c=y d=z"},
+ },
+ },
+ {
+ "Clear Secure flag from an http.",
+ "http://www.host.test/",
+ []string{
+ "b=xx",
+ "d=zz; httponly"},
+ "a=w b=xx c=y d=zz",
+ []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}},
+ },
+ {
+ "Delete all.",
+ "http://www.host.test/",
+ []string{
+ "a=1; max-Age=-1", // delete via MaxAge
+ "b=2; " + expiresIn(-10), // delete via Expires
+ "c=2; max-age=-1; " + expiresIn(-10), // delete via both
+ "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence
+ "",
+ []query{{"http://www.host.test", ""}},
+ },
+ {
+ "Refill #1.",
+ "http://www.host.test",
+ []string{
+ "A=1",
+ "A=2; path=/foo",
+ "A=3; domain=.host.test",
+ "A=4; path=/foo; domain=.host.test"},
+ "A=1 A=2 A=3 A=4",
+ []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}},
+ },
+ {
+ "Refill #2.",
+ "http://www.google.com",
+ []string{
+ "A=6",
+ "A=7; path=/foo",
+ "A=8; domain=.google.com",
+ "A=9; path=/foo; domain=.google.com"},
+ "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A7.",
+ "http://www.google.com",
+ []string{"A=; path=/foo; max-age=-1"},
+ "A=1 A=2 A=3 A=4 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A4.",
+ "http://www.host.test",
+ []string{"A=; path=/foo; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=3 A=6 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=6 A=8"},
+ },
+ },
+ {
+ "Delete A6.",
+ "http://www.google.com",
+ []string{"A=; max-age=-1"},
+ "A=1 A=2 A=3 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1 A=3"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A3.",
+ "http://www.host.test",
+ []string{"A=; domain=host.test; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "No cross-domain delete.",
+ "http://www.host.test",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2 A=8 A=9",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", "A=9 A=8"},
+ },
+ },
+ {
+ "Delete A8 and A9.",
+ "http://www.google.com",
+ []string{
+ "A=; domain=google.com; max-age=-1",
+ "A=; path=/foo; domain=google.com; max-age=-1"},
+ "A=1 A=2",
+ []query{
+ {"http://www.host.test/foo", "A=2 A=1"},
+ {"http://www.google.com/foo", ""},
+ },
+ },
+}
+
+func TestUpdateAndDelete(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range updateAndDeleteTests {
+ test.run(t, jar)
+ }
+}
+
+func TestExpiration(t *testing.T) {
+ jar := newTestJar()
+ jarTest{
+ "Expiration.",
+ "http://www.host.test",
+ []string{
+ "a=1",
+ "b=2; max-age=3",
+ "c=3; " + expiresIn(3),
+ "d=4; max-age=5",
+ "e=5; " + expiresIn(5),
+ "f=6; max-age=100",
+ },
+ "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms
+ []query{
+ {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms
+ {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms
+ {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms
+ },
+ }.run(t, jar)
+}
+
+//
+// Tests derived from Chromium's cookie_store_unittest.h.
+//
+
+// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain
+// Some of the original tests are in a bad condition (e.g.
+// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g.
+// TestNonDottedAndTLD #1 and #6) and have not been ported.
+
+// chromiumBasicsTests contains fundamental tests. Each jarTest has to be
+// performed on a fresh, empty Jar.
+var chromiumBasicsTests = [...]jarTest{
+ {
+ "DomainWithTrailingDotTest.",
+ "http://www.google.com/",
+ []string{
+ "a=1; domain=.www.google.com.",
+ "b=2; domain=.www.google.com.."},
+ "",
+ []query{
+ {"http://www.google.com", ""},
+ },
+ },
+ {
+ "ValidSubdomainTest #1.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com"},
+ "a=1 b=2 c=3 d=4",
+ []query{
+ {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"},
+ {"http://b.c.d.com", "b=2 c=3 d=4"},
+ {"http://c.d.com", "c=3 d=4"},
+ {"http://d.com", "d=4"},
+ },
+ },
+ {
+ "ValidSubdomainTest #2.",
+ "http://a.b.c.d.com",
+ []string{
+ "a=1; domain=.a.b.c.d.com",
+ "b=2; domain=.b.c.d.com",
+ "c=3; domain=.c.d.com",
+ "d=4; domain=.d.com",
+ "X=bcd; domain=.b.c.d.com",
+ "X=cd; domain=.c.d.com"},
+ "X=bcd X=cd a=1 b=2 c=3 d=4",
+ []query{
+ {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"},
+ {"http://c.d.com", "c=3 d=4 X=cd"},
+ },
+ },
+ {
+ "InvalidDomainTest #1.",
+ "http://foo.bar.com",
+ []string{
+ "a=1; domain=.yo.foo.bar.com",
+ "b=2; domain=.foo.com",
+ "c=3; domain=.bar.foo.com",
+ "d=4; domain=.foo.bar.com.net",
+ "e=5; domain=ar.com",
+ "f=6; domain=.",
+ "g=7; domain=/",
+ "h=8; domain=http://foo.bar.com",
+ "i=9; domain=..foo.bar.com",
+ "j=10; domain=..bar.com",
+ "k=11; domain=.foo.bar.com?blah",
+ "l=12; domain=.foo.bar.com/blah",
+ "m=12; domain=.foo.bar.com:80",
+ "n=14; domain=.foo.bar.com:",
+ "o=15; domain=.foo.bar.com#sup",
+ },
+ "", // Jar is empty.
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "InvalidDomainTest #2.",
+ "http://foo.com.com",
+ []string{"a=1; domain=.foo.com.com.com"},
+ "",
+ []query{{"http://foo.bar.com", ""}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #1.",
+ "http://manage.hosted.filefront.com",
+ []string{"a=1; domain=filefront.com"},
+ "a=1",
+ []query{{"http://www.filefront.com", "a=1"}},
+ },
+ {
+ "DomainWithoutLeadingDotTest #2.",
+ "http://www.google.com",
+ []string{"a=1; domain=www.google.com"},
+ "a=1",
+ []query{
+ {"http://www.google.com", "a=1"},
+ {"http://sub.www.google.com", "a=1"},
+ {"http://something-else.com", ""},
+ },
+ },
+ {
+ "CaseInsensitiveDomainTest.",
+ "http://www.google.com",
+ []string{
+ "a=1; domain=.GOOGLE.COM",
+ "b=2; domain=.www.gOOgLE.coM"},
+ "a=1 b=2",
+ []query{{"http://www.google.com", "a=1 b=2"}},
+ },
+ {
+ "TestIpAddress #1.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; path=/"},
+ "a=1",
+ []query{{"http://1.2.3.4/foo", "a=1"}},
+ },
+ {
+ "TestIpAddress #2.",
+ "http://1.2.3.4/foo",
+ []string{
+ "a=1; domain=.1.2.3.4",
+ "b=2; domain=.3.4"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestIpAddress #3.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; domain=1.2.3.3"},
+ "",
+ []query{{"http://1.2.3.4/foo", ""}},
+ },
+ {
+ "TestIpAddress #4.",
+ "http://1.2.3.4/foo",
+ []string{"a=1; domain=1.2.3.4"},
+ "a=1",
+ []query{{"http://1.2.3.4/foo", "a=1"}},
+ },
+ {
+ "TestNonDottedAndTLD #2.",
+ "http://com./index.html",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com./index.html", "a=1"},
+ {"http://no-cookies.com./index.html", ""},
+ },
+ },
+ {
+ "TestNonDottedAndTLD #3.",
+ "http://a.b",
+ []string{
+ "a=1; domain=.b",
+ "b=2; domain=b"},
+ "",
+ []query{{"http://bar.foo", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #4.",
+ "http://google.com",
+ []string{
+ "a=1; domain=.com",
+ "b=2; domain=com"},
+ "",
+ []query{{"http://google.com", ""}},
+ },
+ {
+ "TestNonDottedAndTLD #5.",
+ "http://google.co.uk",
+ []string{
+ "a=1; domain=.co.uk",
+ "b=2; domain=.uk"},
+ "",
+ []query{
+ {"http://google.co.uk", ""},
+ {"http://else.co.com", ""},
+ {"http://else.uk", ""},
+ },
+ },
+ {
+ "TestHostEndsWithDot.",
+ "http://www.google.com",
+ []string{
+ "a=1",
+ "b=2; domain=.www.google.com."},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "PathTest",
+ "http://www.google.izzle",
+ []string{"a=1; path=/wee"},
+ "a=1",
+ []query{
+ {"http://www.google.izzle/wee", "a=1"},
+ {"http://www.google.izzle/wee/", "a=1"},
+ {"http://www.google.izzle/wee/war", "a=1"},
+ {"http://www.google.izzle/wee/war/more/more", "a=1"},
+ {"http://www.google.izzle/weehee", ""},
+ {"http://www.google.izzle/", ""},
+ },
+ },
+}
+
+func TestChromiumBasics(t *testing.T) {
+ for _, test := range chromiumBasicsTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+// chromiumDomainTests contains jarTests which must be executed all on the
+// same Jar.
+var chromiumDomainTests = [...]jarTest{
+ {
+ "Fill #1.",
+ "http://www.google.izzle",
+ []string{"A=B"},
+ "A=B",
+ []query{{"http://www.google.izzle", "A=B"}},
+ },
+ {
+ "Fill #2.",
+ "http://www.google.izzle",
+ []string{"C=D; domain=.google.izzle"},
+ "A=B C=D",
+ []query{{"http://www.google.izzle", "A=B C=D"}},
+ },
+ {
+ "Verify A is a host cookie and not accessible from subdomain.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D",
+ []query{{"http://foo.www.google.izzle", "C=D"}},
+ },
+ {
+ "Verify domain cookies are found on proper domain.",
+ "http://www.google.izzle",
+ []string{"E=F; domain=.www.google.izzle"},
+ "A=B C=D E=F",
+ []query{{"http://www.google.izzle", "A=B C=D E=F"}},
+ },
+ {
+ "Leading dots in domain attributes are optional.",
+ "http://www.google.izzle",
+ []string{"G=H; domain=www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #1.",
+ "http://www.google.izzle",
+ []string{"K=L; domain=.bar.www.google.izzle"},
+ "A=B C=D E=F G=H",
+ []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}},
+ },
+ {
+ "Verify domain enforcement works #2.",
+ "http://unused.nil",
+ []string{},
+ "A=B C=D E=F G=H",
+ []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}},
+ },
+}
+
+func TestChromiumDomain(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDomainTests {
+ test.run(t, jar)
+ }
+}
+
+// chromiumDeletionTests must be performed all on the same Jar.
+var chromiumDeletionTests = [...]jarTest{
+ {
+ "Create session cookie a1.",
+ "http://www.google.com",
+ []string{"a=1"},
+ "a=1",
+ []query{{"http://www.google.com", "a=1"}},
+ },
+ {
+ "Delete sc a1 via MaxAge.",
+ "http://www.google.com",
+ []string{"a=1; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create session cookie b2.",
+ "http://www.google.com",
+ []string{"b=2"},
+ "b=2",
+ []query{{"http://www.google.com", "b=2"}},
+ },
+ {
+ "Delete sc b2 via Expires.",
+ "http://www.google.com",
+ []string{"b=2; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie c3.",
+ "http://www.google.com",
+ []string{"c=3; max-age=3600"},
+ "c=3",
+ []query{{"http://www.google.com", "c=3"}},
+ },
+ {
+ "Delete pc c3 via MaxAge.",
+ "http://www.google.com",
+ []string{"c=3; max-age=-1"},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+ {
+ "Create persistent cookie d4.",
+ "http://www.google.com",
+ []string{"d=4; max-age=3600"},
+ "d=4",
+ []query{{"http://www.google.com", "d=4"}},
+ },
+ {
+ "Delete pc d4 via Expires.",
+ "http://www.google.com",
+ []string{"d=4; " + expiresIn(-10)},
+ "",
+ []query{{"http://www.google.com", ""}},
+ },
+}
+
+func TestChromiumDeletion(t *testing.T) {
+ jar := newTestJar()
+ for _, test := range chromiumDeletionTests {
+ test.run(t, jar)
+ }
+}
+
+// domainHandlingTests tests and documents the rules for domain handling.
+// Each test must be performed on an empty new Jar.
+var domainHandlingTests = [...]jarTest{
+ {
+ "Host cookie",
+ "http://www.host.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", ""},
+ {"http://bar.host.test", ""},
+ {"http://foo.www.host.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #1",
+ "http://www.host.test",
+ []string{"a=1; domain=host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie #2",
+ "http://www.host.test",
+ []string{"a=1; domain=.host.test"},
+ "a=1",
+ []query{
+ {"http://www.host.test", "a=1"},
+ {"http://host.test", "a=1"},
+ {"http://bar.host.test", "a=1"},
+ {"http://foo.www.host.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", ""},
+ {"http://xn--bcher-kva.test", ""},
+ {"http://bar.bücher.test", ""},
+ {"http://bar.xn--bcher-kva.test", ""},
+ {"http://foo.www.bücher.test", ""},
+ {"http://foo.www.xn--bcher-kva.test", ""},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #1",
+ "http://www.bücher.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Domain cookie on IDNA domain #2",
+ "http://www.xn--bcher-kva.test",
+ []string{"a=1; domain=xn--bcher-kva.test"},
+ "a=1",
+ []query{
+ {"http://www.bücher.test", "a=1"},
+ {"http://www.xn--bcher-kva.test", "a=1"},
+ {"http://bücher.test", "a=1"},
+ {"http://xn--bcher-kva.test", "a=1"},
+ {"http://bar.bücher.test", "a=1"},
+ {"http://bar.xn--bcher-kva.test", "a=1"},
+ {"http://foo.www.bücher.test", "a=1"},
+ {"http://foo.www.xn--bcher-kva.test", "a=1"},
+ {"http://other.test", ""},
+ {"http://test", ""},
+ },
+ },
+ {
+ "Host cookie on TLD.",
+ "http://com",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Domain cookie on TLD becomes a host cookie.",
+ "http://com",
+ []string{"a=1; domain=com"},
+ "a=1",
+ []query{
+ {"http://com", "a=1"},
+ {"http://any.com", ""},
+ {"http://any.test", ""},
+ },
+ },
+ {
+ "Host cookie on public suffix.",
+ "http://co.uk",
+ []string{"a=1"},
+ "a=1",
+ []query{
+ {"http://co.uk", "a=1"},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+ {
+ "Domain cookie on public suffix is ignored.",
+ "http://some.co.uk",
+ []string{"a=1; domain=co.uk"},
+ "",
+ []query{
+ {"http://co.uk", ""},
+ {"http://uk", ""},
+ {"http://some.co.uk", ""},
+ {"http://foo.some.co.uk", ""},
+ {"http://any.uk", ""},
+ },
+ },
+}
+
+func TestDomainHandling(t *testing.T) {
+ for _, test := range domainHandlingTests {
+ jar := newTestJar()
+ test.run(t, jar)
+ }
+}
+
+func TestIssue19384(t *testing.T) {
+ cookies := []*http.Cookie{{Name: "name", Value: "value"}}
+ for _, host := range []string{"", ".", "..", "..."} {
+ jar, _ := New(nil)
+ u := &url.URL{Scheme: "http", Host: host, Path: "/"}
+ if got := jar.Cookies(u); len(got) != 0 {
+ t.Errorf("host %q, got %v", host, got)
+ }
+ jar.SetCookies(u, cookies)
+ if got := jar.Cookies(u); len(got) != 1 || got[0].Value != "value" {
+ t.Errorf("host %q, got %v", host, got)
+ }
+ }
+}
diff --git a/src/net/http/cookiejar/punycode.go b/src/net/http/cookiejar/punycode.go
new file mode 100644
index 0000000..c7f438d
--- /dev/null
+++ b/src/net/http/cookiejar/punycode.go
@@ -0,0 +1,151 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+// This file implements the Punycode algorithm from RFC 3492.
+
+import (
+ "fmt"
+ "net/http/internal/ascii"
+ "strings"
+ "unicode/utf8"
+)
+
+// These parameter values are specified in section 5.
+//
+// All computation is done with int32s, so that overflow behavior is identical
+// regardless of whether int is 32-bit or 64-bit.
+const (
+ base int32 = 36
+ damp int32 = 700
+ initialBias int32 = 72
+ initialN int32 = 128
+ skew int32 = 38
+ tmax int32 = 26
+ tmin int32 = 1
+)
+
+// encode encodes a string as specified in section 6.3 and prepends prefix to
+// the result.
+//
+// The "while h < length(input)" line in the specification becomes "for
+// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
+func encode(prefix, s string) (string, error) {
+ output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
+ copy(output, prefix)
+ delta, n, bias := int32(0), initialN, initialBias
+ b, remaining := int32(0), int32(0)
+ for _, r := range s {
+ if r < utf8.RuneSelf {
+ b++
+ output = append(output, byte(r))
+ } else {
+ remaining++
+ }
+ }
+ h := b
+ if b > 0 {
+ output = append(output, '-')
+ }
+ for remaining != 0 {
+ m := int32(0x7fffffff)
+ for _, r := range s {
+ if m > r && r >= n {
+ m = r
+ }
+ }
+ delta += (m - n) * (h + 1)
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ n = m
+ for _, r := range s {
+ if r < n {
+ delta++
+ if delta < 0 {
+ return "", fmt.Errorf("cookiejar: invalid label %q", s)
+ }
+ continue
+ }
+ if r > n {
+ continue
+ }
+ q := delta
+ for k := base; ; k += base {
+ t := k - bias
+ if t < tmin {
+ t = tmin
+ } else if t > tmax {
+ t = tmax
+ }
+ if q < t {
+ break
+ }
+ output = append(output, encodeDigit(t+(q-t)%(base-t)))
+ q = (q - t) / (base - t)
+ }
+ output = append(output, encodeDigit(q))
+ bias = adapt(delta, h+1, h == b)
+ delta = 0
+ h++
+ remaining--
+ }
+ delta++
+ n++
+ }
+ return string(output), nil
+}
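+
+// For example (matching the cases in punycode_test.go), encode("", "bücher")
+// returns "bcher-kva", and encode("xn--", "bücher") returns "xn--bcher-kva",
+// the ACE form used by toASCII below.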
+
+func encodeDigit(digit int32) byte {
+ switch {
+ case 0 <= digit && digit < 26:
+ return byte(digit + 'a')
+ case 26 <= digit && digit < 36:
+ return byte(digit + ('0' - 26))
+ }
+ panic("cookiejar: internal error in punycode encoding")
+}
+
+// adapt is the bias adaptation function specified in section 6.1.
+func adapt(delta, numPoints int32, firstTime bool) int32 {
+ if firstTime {
+ delta /= damp
+ } else {
+ delta /= 2
+ }
+ delta += delta / numPoints
+ k := int32(0)
+ for delta > ((base-tmin)*tmax)/2 {
+ delta /= base - tmin
+ k += base
+ }
+ return k + (base-tmin+1)*delta/(delta+skew)
+}
+
+// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and
+// friends) and not Punycode (RFC 3492) per se.
+
+// acePrefix is the ASCII Compatible Encoding prefix.
+const acePrefix = "xn--"
+
+// toASCII converts a domain or domain label to its ASCII form. For example,
+// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
+// toASCII("golang") is "golang".
+func toASCII(s string) (string, error) {
+ if ascii.Is(s) {
+ return s, nil
+ }
+ labels := strings.Split(s, ".")
+ for i, label := range labels {
+ if !ascii.Is(label) {
+ a, err := encode(acePrefix, label)
+ if err != nil {
+ return "", err
+ }
+ labels[i] = a
+ }
+ }
+ return strings.Join(labels, "."), nil
+}
diff --git a/src/net/http/cookiejar/punycode_test.go b/src/net/http/cookiejar/punycode_test.go
new file mode 100644
index 0000000..0301de1
--- /dev/null
+++ b/src/net/http/cookiejar/punycode_test.go
@@ -0,0 +1,161 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cookiejar
+
+import (
+ "testing"
+)
+
+var punycodeTestCases = [...]struct {
+ s, encoded string
+}{
+ {"", ""},
+ {"-", "--"},
+ {"-a", "-a-"},
+ {"-a-", "-a--"},
+ {"a", "a-"},
+ {"a-", "a--"},
+ {"a-b", "a-b-"},
+ {"books", "books-"},
+ {"bücher", "bcher-kva"},
+ {"Hello世界", "Hello-ck1hg65u"},
+ {"ü", "tda"},
+ {"üý", "tdac"},
+
+ // The test cases below come from RFC 3492 section 7.1 with Errata 3026.
+ {
+ // (A) Arabic (Egyptian).
+ "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" +
+ "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F",
+ "egbpdaj6bu4bxfgehfvwxn",
+ },
+ {
+ // (B) Chinese (simplified).
+ "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587",
+ "ihqwcrb4cv8a8dqg056pqjye",
+ },
+ {
+ // (C) Chinese (traditional).
+ "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587",
+ "ihqwctvzc91f659drss3x8bo0yb",
+ },
+ {
+ // (D) Czech.
+ "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" +
+ "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" +
+ "\u0065\u0073\u006B\u0079",
+ "Proprostnemluvesky-uyb24dma41a",
+ },
+ {
+ // (E) Hebrew.
+ "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" +
+ "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" +
+ "\u05D1\u05E8\u05D9\u05EA",
+ "4dbcagdahymbxekheh6e0a7fei0b",
+ },
+ {
+ // (F) Hindi (Devanagari).
+ "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" +
+ "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" +
+ "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" +
+ "\u0939\u0948\u0902",
+ "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd",
+ },
+ {
+ // (G) Japanese (kanji and hiragana).
+ "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" +
+ "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B",
+ "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa",
+ },
+ {
+ // (H) Korean (Hangul syllables).
+ "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" +
+ "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" +
+ "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C",
+ "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" +
+ "psd879ccm6fea98c",
+ },
+ {
+ // (I) Russian (Cyrillic).
+ "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" +
+ "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" +
+ "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" +
+ "\u0438",
+ "b1abfaaepdrnnbgefbadotcwatmq2g4l",
+ },
+ {
+ // (J) Spanish.
+ "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" +
+ "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" +
+ "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" +
+ "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" +
+ "\u0061\u00F1\u006F\u006C",
+ "PorqunopuedensimplementehablarenEspaol-fmd56a",
+ },
+ {
+ // (K) Vietnamese.
+ "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" +
+ "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" +
+ "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" +
+ "\u0056\u0069\u1EC7\u0074",
+ "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g",
+ },
+ {
+ // (L) 3<nen>B<gumi><kinpachi><sensei>.
+ "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F",
+ "3B-ww4c5e180e575a65lsy2b",
+ },
+ {
+ // (M) <amuro><namie>-with-SUPER-MONKEYS.
+ "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" +
+ "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" +
+ "\u004F\u004E\u004B\u0045\u0059\u0053",
+ "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n",
+ },
+ {
+ // (N) Hello-Another-Way-<sorezore><no><basho>.
+ "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" +
+ "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" +
+ "\u305D\u308C\u305E\u308C\u306E\u5834\u6240",
+ "Hello-Another-Way--fc4qua05auwb3674vfr0b",
+ },
+ {
+ // (O) <hitotsu><yane><no><shita>2.
+ "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032",
+ "2-u9tlzr9756bt3uc0v",
+ },
+ {
+ // (P) Maji<de>Koi<suru>5<byou><mae>
+ "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" +
+ "\u308B\u0035\u79D2\u524D",
+ "MajiKoi5-783gue6qz075azm5e",
+ },
+ {
+ // (Q) <pafii>de<runba>
+ "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0",
+ "de-jg4avhby1noc0d",
+ },
+ {
+ // (R) <sono><supiido><de>
+ "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067",
+ "d9juau41awczczp",
+ },
+ {
+ // (S) -> $1.00 <-
+ "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" +
+ "\u003C\u002D",
+ "-> $1.00 <--",
+ },
+}
+
+func TestPunycode(t *testing.T) {
+ for _, tc := range punycodeTestCases {
+ if got, err := encode("", tc.s); err != nil {
+ t.Errorf(`encode("", %q): %v`, tc.s, err)
+ } else if got != tc.encoded {
+ t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded)
+ }
+ }
+}
diff --git a/src/net/http/doc.go b/src/net/http/doc.go
new file mode 100644
index 0000000..d9e6aaf
--- /dev/null
+++ b/src/net/http/doc.go
@@ -0,0 +1,110 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package http provides HTTP client and server implementations.
+
+Get, Head, Post, and PostForm make HTTP (or HTTPS) requests:
+
+ resp, err := http.Get("http://example.com/")
+ ...
+ resp, err := http.Post("http://example.com/upload", "image/jpeg", &buf)
+ ...
+ resp, err := http.PostForm("http://example.com/form",
+ url.Values{"key": {"Value"}, "id": {"123"}})
+
+The caller must close the response body when finished with it:
+
+ resp, err := http.Get("http://example.com/")
+ if err != nil {
+ // handle error
+ }
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ // ...
+
+# Clients and Transports
+
+For control over HTTP client headers, redirect policy, and other
+settings, create a Client:
+
+ client := &http.Client{
+ CheckRedirect: redirectPolicyFunc,
+ }
+
+ resp, err := client.Get("http://example.com")
+ // ...
+
+ req, err := http.NewRequest("GET", "http://example.com", nil)
+ // ...
+ req.Header.Add("If-None-Match", `W/"wyzzy"`)
+ resp, err := client.Do(req)
+ // ...
+
+For control over proxies, TLS configuration, keep-alives,
+compression, and other settings, create a Transport:
+
+ tr := &http.Transport{
+ MaxIdleConns: 10,
+ IdleConnTimeout: 30 * time.Second,
+ DisableCompression: true,
+ }
+ client := &http.Client{Transport: tr}
+ resp, err := client.Get("https://example.com")
+
+Clients and Transports are safe for concurrent use by multiple
+goroutines and for efficiency should only be created once and re-used.
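+
+For example, a single Client, created once and reused, can serve all of a
+program's requests:
+
+ var client = &http.Client{Timeout: 10 * time.Second}
+ ...
+ resp, err := client.Get("https://example.com")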
+
+# Servers
+
+ListenAndServe starts an HTTP server with a given address and handler.
+The handler is usually nil, which means to use DefaultServeMux.
+Handle and HandleFunc add handlers to DefaultServeMux:
+
+ http.Handle("/foo", fooHandler)
+
+ http.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path))
+ })
+
+ log.Fatal(http.ListenAndServe(":8080", nil))
+
+More control over the server's behavior is available by creating a
+custom Server:
+
+ s := &http.Server{
+ Addr: ":8080",
+ Handler: myHandler,
+ ReadTimeout: 10 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ MaxHeaderBytes: 1 << 20,
+ }
+ log.Fatal(s.ListenAndServe())
+
+# HTTP/2
+
+Starting with Go 1.6, the http package has transparent support for the
+HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2
+can do so by setting Transport.TLSNextProto (for clients) or
+Server.TLSNextProto (for servers) to a non-nil, empty
+map. Alternatively, the following GODEBUG settings are
+currently supported:
+
+ GODEBUG=http2client=0 # disable HTTP/2 client support
+ GODEBUG=http2server=0 # disable HTTP/2 server support
+ GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs
+ GODEBUG=http2debug=2 # ... even more verbose, with frame dumps
+
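+For example, a client that needs to force HTTP/1.1 could use a Transport
+with a non-nil, empty TLSNextProto map (tls is package crypto/tls):
+
+ tr := &http.Transport{
+ TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
+ }
+ client := &http.Client{Transport: tr}
+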
+Please report any issues before disabling HTTP/2 support: https://golang.org/s/http2bug
+
+The http package's Transport and Server both automatically enable
+HTTP/2 support for simple configurations. To enable HTTP/2 for more
+complex configurations, to use lower-level HTTP/2 features, or to use
+a newer version of Go's http2 package, import "golang.org/x/net/http2"
+directly and use its ConfigureTransport and/or ConfigureServer
+functions. Manually configuring HTTP/2 via the golang.org/x/net/http2
+package takes precedence over the net/http package's built-in HTTP/2
+support.
+*/
+package http
diff --git a/src/net/http/example_filesystem_test.go b/src/net/http/example_filesystem_test.go
new file mode 100644
index 0000000..0e81458
--- /dev/null
+++ b/src/net/http/example_filesystem_test.go
@@ -0,0 +1,71 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "io/fs"
+ "log"
+ "net/http"
+ "strings"
+)
+
+// containsDotFile reports whether name contains a path element starting with a period.
+// The name is assumed to be delimited by forward slashes, as guaranteed
+// by the http.FileSystem interface.
+func containsDotFile(name string) bool {
+ parts := strings.Split(name, "/")
+ for _, part := range parts {
+ if strings.HasPrefix(part, ".") {
+ return true
+ }
+ }
+ return false
+}
+
+// dotFileHidingFile is the http.File used in dotFileHidingFileSystem.
+// It is used to wrap the Readdir method of http.File so that we can
+// remove files and directories that start with a period from its output.
+type dotFileHidingFile struct {
+ http.File
+}
+
+// Readdir is a wrapper around the Readdir method of the embedded File
+// that filters out all files that start with a period in their name.
+func (f dotFileHidingFile) Readdir(n int) (fis []fs.FileInfo, err error) {
+ files, err := f.File.Readdir(n)
+ for _, file := range files { // Filters out the dot files
+ if !strings.HasPrefix(file.Name(), ".") {
+ fis = append(fis, file)
+ }
+ }
+ return
+}
+
+// dotFileHidingFileSystem is an http.FileSystem that hides
+// hidden "dot files" from being served.
+type dotFileHidingFileSystem struct {
+ http.FileSystem
+}
+
+// Open is a wrapper around the Open method of the embedded FileSystem
+// that serves a 403 permission error when name has a file or directory
+// whose name starts with a period in its path.
+func (fsys dotFileHidingFileSystem) Open(name string) (http.File, error) {
+ if containsDotFile(name) { // If dot file, return 403 response
+ return nil, fs.ErrPermission
+ }
+
+ file, err := fsys.FileSystem.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ return dotFileHidingFile{file}, err
+}
+
+func ExampleFileServer_dotFileHiding() {
+ fsys := dotFileHidingFileSystem{http.Dir(".")}
+ http.Handle("/", http.FileServer(fsys))
+ log.Fatal(http.ListenAndServe(":8080", nil))
+}
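+
+// With this wrapper in place, a request such as GET /.git/config is answered
+// with "403 Forbidden", while files whose path contains no dot-prefixed
+// element are served as usual.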
diff --git a/src/net/http/example_handle_test.go b/src/net/http/example_handle_test.go
new file mode 100644
index 0000000..10a62f6
--- /dev/null
+++ b/src/net/http/example_handle_test.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "sync"
+)
+
+type countHandler struct {
+ mu sync.Mutex // guards n
+ n int
+}
+
+func (h *countHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.n++
+ fmt.Fprintf(w, "count is %d\n", h.n)
+}
+
+func ExampleHandle() {
+ http.Handle("/count", new(countHandler))
+ log.Fatal(http.ListenAndServe(":8080", nil))
+}
diff --git a/src/net/http/example_test.go b/src/net/http/example_test.go
new file mode 100644
index 0000000..2f411d1
--- /dev/null
+++ b/src/net/http/example_test.go
@@ -0,0 +1,195 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+)
+
+func ExampleHijacker() {
+ http.HandleFunc("/hijack", func(w http.ResponseWriter, r *http.Request) {
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
+ return
+ }
+ conn, bufrw, err := hj.Hijack()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ // Don't forget to close the connection:
+ defer conn.Close()
+ bufrw.WriteString("Now we're speaking raw TCP. Say hi: ")
+ bufrw.Flush()
+ s, err := bufrw.ReadString('\n')
+ if err != nil {
+ log.Printf("error reading string: %v", err)
+ return
+ }
+ fmt.Fprintf(bufrw, "You said: %q\nBye.\n", s)
+ bufrw.Flush()
+ })
+}
+
+func ExampleGet() {
+ res, err := http.Get("http://www.google.com/robots.txt")
+ if err != nil {
+ log.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if res.StatusCode > 299 {
+ log.Fatalf("Response failed with status code: %d and\nbody: %s\n", res.StatusCode, body)
+ }
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s", body)
+}
+
+func ExampleFileServer() {
+ // Simple static webserver:
+ log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc"))))
+}
+
+func ExampleFileServer_stripPrefix() {
+ // To serve a directory on disk (/tmp) under an alternate URL
+ // path (/tmpfiles/), use StripPrefix to modify the request
+ // URL's path before the FileServer sees it:
+ http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp"))))
+}
+
+func ExampleStripPrefix() {
+ // To serve a directory on disk (/tmp) under an alternate URL
+ // path (/tmpfiles/), use StripPrefix to modify the request
+ // URL's path before the FileServer sees it:
+ http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp"))))
+}
+
+type apiHandler struct{}
+
+func (apiHandler) ServeHTTP(http.ResponseWriter, *http.Request) {}
+
+func ExampleServeMux_Handle() {
+ mux := http.NewServeMux()
+ mux.Handle("/api/", apiHandler{})
+ mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+ // The "/" pattern matches everything, so we need to check
+ // that we're at the root here.
+ if req.URL.Path != "/" {
+ http.NotFound(w, req)
+ return
+ }
+ fmt.Fprintf(w, "Welcome to the home page!")
+ })
+}
+
+// HTTP Trailers are a set of key/value pairs like headers that come
+// after the HTTP response body, instead of before it.
+func ExampleResponseWriter_trailers() {
+ mux := http.NewServeMux()
+ mux.HandleFunc("/sendstrailers", func(w http.ResponseWriter, req *http.Request) {
+ // Before any call to WriteHeader or Write, declare
+ // the trailers you will set during the HTTP
+ // response. These three headers are actually sent in
+ // the trailer.
+ w.Header().Set("Trailer", "AtEnd1, AtEnd2")
+ w.Header().Add("Trailer", "AtEnd3")
+
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8") // normal header
+ w.WriteHeader(http.StatusOK)
+
+ w.Header().Set("AtEnd1", "value 1")
+ io.WriteString(w, "This HTTP response has both headers before this text and trailers at the end.\n")
+ w.Header().Set("AtEnd2", "value 2")
+ w.Header().Set("AtEnd3", "value 3") // These will appear as trailers.
+ })
+}
+
+func ExampleServer_Shutdown() {
+ var srv http.Server
+
+ idleConnsClosed := make(chan struct{})
+ go func() {
+ sigint := make(chan os.Signal, 1)
+ signal.Notify(sigint, os.Interrupt)
+ <-sigint
+
+ // We received an interrupt signal, shut down.
+ if err := srv.Shutdown(context.Background()); err != nil {
+ // Error from closing listeners, or context timeout:
+ log.Printf("HTTP server Shutdown: %v", err)
+ }
+ close(idleConnsClosed)
+ }()
+
+ if err := srv.ListenAndServe(); err != http.ErrServerClosed {
+ // Error starting or closing listener:
+ log.Fatalf("HTTP server ListenAndServe: %v", err)
+ }
+
+ <-idleConnsClosed
+}
+
+func ExampleListenAndServeTLS() {
+ http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, "Hello, TLS!\n")
+ })
+
+ // One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
+ log.Printf("About to listen on 8443. Go to https://127.0.0.1:8443/")
+ err := http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil)
+ log.Fatal(err)
+}
+
+func ExampleListenAndServe() {
+ // Hello world, the web server
+
+ helloHandler := func(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, "Hello, world!\n")
+ }
+
+ http.HandleFunc("/hello", helloHandler)
+ log.Fatal(http.ListenAndServe(":8080", nil))
+}
+
+func ExampleHandleFunc() {
+ h1 := func(w http.ResponseWriter, _ *http.Request) {
+ io.WriteString(w, "Hello from a HandleFunc #1!\n")
+ }
+ h2 := func(w http.ResponseWriter, _ *http.Request) {
+ io.WriteString(w, "Hello from a HandleFunc #2!\n")
+ }
+
+ http.HandleFunc("/", h1)
+ http.HandleFunc("/endpoint", h2)
+
+ log.Fatal(http.ListenAndServe(":8080", nil))
+}
+
+func newPeopleHandler() http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "This is the people handler.")
+ })
+}
+
+func ExampleNotFoundHandler() {
+ mux := http.NewServeMux()
+
+ // Create a sample handler that returns 404
+ mux.Handle("/resources", http.NotFoundHandler())
+
+ // Create a sample handler that returns 200
+ mux.Handle("/resources/people/", newPeopleHandler())
+
+ log.Fatal(http.ListenAndServe(":8080", mux))
+}
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
new file mode 100644
index 0000000..5d198f3
--- /dev/null
+++ b/src/net/http/export_test.go
@@ -0,0 +1,317 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Bridge package to expose http internals to tests in the http_test
+// package.
+
+package http
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/url"
+ "sort"
+ "sync"
+ "testing"
+ "time"
+)
+
+var (
+ DefaultUserAgent = defaultUserAgent
+ NewLoggingConn = newLoggingConn
+ ExportAppendTime = appendTime
+ ExportRefererForURL = refererForURL
+ ExportServerNewConn = (*Server).newConn
+ ExportCloseWriteAndWait = (*conn).closeWriteAndWait
+ ExportErrRequestCanceled = errRequestCanceled
+ ExportErrRequestCanceledConn = errRequestCanceledConn
+ ExportErrServerClosedIdle = errServerClosedIdle
+ ExportServeFile = serveFile
+ ExportScanETag = scanETag
+ ExportHttp2ConfigureServer = http2ConfigureServer
+ Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
+ Export_writeStatusLine = writeStatusLine
+ Export_is408Message = is408Message
+)
+
+var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse
+
+func init() {
+ // We only want to pay for this cost during testing.
+ // When not under test, these values are always nil
+ // and never assigned to.
+ testHookMu = new(sync.Mutex)
+
+ testHookClientDoResult = func(res *Response, err error) {
+ if err != nil {
+ if _, ok := err.(*url.Error); !ok {
+ panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err))
+ }
+ } else {
+ if res == nil {
+ panic("Client.Do returned nil, nil")
+ }
+ if res.Body == nil {
+ panic("Client.Do returned nil res.Body and no error")
+ }
+ }
+ }
+}
+
+func CondSkipHTTP2(t testing.TB) {
+ if omitBundledHTTP2 {
+ t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use")
+ }
+}
+
+var (
+ SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip)
+ SetRoundTripRetried = hookSetter(&testHookRoundTripRetried)
+)
+
+func SetReadLoopBeforeNextReadHook(f func()) {
+ unnilTestHook(&f)
+ testHookReadLoopBeforeNextRead = f
+}
+
+// SetPendingDialHooks sets the hooks that run before and after handling
+// pending dials.
+func SetPendingDialHooks(before, after func()) {
+ unnilTestHook(&before)
+ unnilTestHook(&after)
+ testHookPrePendingDial, testHookPostPendingDial = before, after
+}
+
+func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
+
+func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
+ return &timeoutHandler{
+ handler: handler,
+ testContext: ctx,
+ // (no body)
+ }
+}
+
+func ResetCachedEnvironment() {
+ resetProxyConfig()
+}
+
+func (t *Transport) NumPendingRequestsForTesting() int {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ return len(t.reqCanceler)
+}
+
+func (t *Transport) IdleConnKeysForTesting() (keys []string) {
+ keys = make([]string, 0)
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ for key := range t.idleConn {
+ keys = append(keys, key.String())
+ }
+ sort.Strings(keys)
+ return
+}
+
+func (t *Transport) IdleConnKeyCountForTesting() int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return len(t.idleConn)
+}
+
+func (t *Transport) IdleConnStrsForTesting() []string {
+ var ret []string
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ for _, conns := range t.idleConn {
+ for _, pc := range conns {
+ ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String())
+ }
+ }
+ sort.Strings(ret)
+ return ret
+}
+
+func (t *Transport) IdleConnStrsForTesting_h2() []string {
+ var ret []string
+ noDialPool := t.h2transport.(*http2Transport).ConnPool.(http2noDialClientConnPool)
+ pool := noDialPool.http2clientConnPool
+
+ pool.mu.Lock()
+ defer pool.mu.Unlock()
+
+ for k, ccs := range pool.conns {
+ for _, cc := range ccs {
+ if cc.idleState().canTakeNewRequest {
+ ret = append(ret, k)
+ }
+ }
+ }
+
+ sort.Strings(ret)
+ return ret
+}
+
+func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ key := connectMethodKey{"", scheme, addr, false}
+ cacheKey := key.String()
+ for k, conns := range t.idleConn {
+ if k.String() == cacheKey {
+ return len(conns)
+ }
+ }
+ return 0
+}
+
+func (t *Transport) IdleConnWaitMapSizeForTesting() int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return len(t.idleConnWait)
+}
+
+func (t *Transport) IsIdleForTesting() bool {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return t.closeIdle
+}
+
+func (t *Transport) QueueForIdleConnForTesting() {
+ t.queueForIdleConn(nil)
+}
+
+// PutIdleTestConn reports whether it was able to insert a fresh
+// persistConn for scheme, addr into the idle connection pool.
+func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
+ c, _ := net.Pipe()
+ key := connectMethodKey{"", scheme, addr, false}
+
+ if t.MaxConnsPerHost > 0 {
+ // Transport is tracking conns-per-host.
+ // Increment connection count to account
+ // for new persistConn created below.
+ t.connsPerHostMu.Lock()
+ if t.connsPerHost == nil {
+ t.connsPerHost = make(map[connectMethodKey]int)
+ }
+ t.connsPerHost[key]++
+ t.connsPerHostMu.Unlock()
+ }
+
+ return t.tryPutIdleConn(&persistConn{
+ t: t,
+ conn: c, // dummy
+ closech: make(chan struct{}), // so it can be closed
+ cacheKey: key,
+ }) == nil
+}
+
+// PutIdleTestConnH2 reports whether it was able to insert a fresh
+// HTTP/2 persistConn for scheme, addr into the idle connection pool.
+func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt RoundTripper) bool {
+ key := connectMethodKey{"", scheme, addr, false}
+
+ if t.MaxConnsPerHost > 0 {
+ // Transport is tracking conns-per-host.
+ // Increment connection count to account
+ // for new persistConn created below.
+ t.connsPerHostMu.Lock()
+ if t.connsPerHost == nil {
+ t.connsPerHost = make(map[connectMethodKey]int)
+ }
+ t.connsPerHost[key]++
+ t.connsPerHostMu.Unlock()
+ }
+
+ return t.tryPutIdleConn(&persistConn{
+ t: t,
+ alt: alt,
+ cacheKey: key,
+ }) == nil
+}
+
+// All test hooks must be non-nil so they can be called directly,
+// but the tests use nil to mean hook disabled.
+func unnilTestHook(f *func()) {
+ if *f == nil {
+ *f = nop
+ }
+}
+
+func hookSetter(dst *func()) func(func()) {
+ return func(fn func()) {
+ unnilTestHook(&fn)
+ *dst = fn
+ }
+}
+
+func ExportHttp2ConfigureTransport(t *Transport) error {
+ t2, err := http2configureTransports(t)
+ if err != nil {
+ return err
+ }
+ t.h2transport = t2
+ return nil
+}
+
+func (s *Server) ExportAllConnsIdle() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for c := range s.activeConn {
+ st, unixSec := c.getState()
+ if unixSec == 0 || st != StateIdle {
+ return false
+ }
+ }
+ return true
+}
+
+func (s *Server) ExportAllConnsByState() map[ConnState]int {
+ states := map[ConnState]int{}
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for c := range s.activeConn {
+ st, _ := c.getState()
+ states[st] += 1
+ }
+ return states
+}
+
+func (r *Request) WithT(t *testing.T) *Request {
+ return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
+}
+
+func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) {
+ old := http2goAwayTimeout
+ http2goAwayTimeout = d
+ return func() { http2goAwayTimeout = old }
+}
+
+func (r *Request) ExportIsReplayable() bool { return r.isReplayable() }
+
+// ExportCloseTransportConnsAbruptly closes all idle connections from
+// tr in an abrupt way, just reaching into the underlying Conns and
+// closing them, without telling the Transport or its persistConns
+// that it's doing so. This is to simulate the server closing connections
+// on the Transport.
+func ExportCloseTransportConnsAbruptly(tr *Transport) {
+ tr.idleMu.Lock()
+ for _, pcs := range tr.idleConn {
+ for _, pc := range pcs {
+ pc.conn.Close()
+ }
+ }
+ tr.idleMu.Unlock()
+}
+
+// ResponseWriterConnForTesting returns w's underlying connection, if w
+// is a regular *response ResponseWriter.
+func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
+ if r, ok := w.(*response); ok {
+ return r.conn.rwc, true
+ }
+ return nil, false
+}
diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go
new file mode 100644
index 0000000..dc82bf7
--- /dev/null
+++ b/src/net/http/fcgi/child.go
@@ -0,0 +1,395 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+// This file implements FastCGI from the perspective of a child process.
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/http/cgi"
+ "os"
+ "strings"
+ "time"
+)
+
+// request holds the state for an in-progress request. As soon as it's complete,
+// it's converted to an http.Request.
+type request struct {
+ pw *io.PipeWriter
+ reqId uint16
+ params map[string]string
+ buf [1024]byte
+ rawParams []byte
+ keepConn bool
+}
+
+// envVarsContextKey uniquely identifies a mapping of CGI
+// environment variables to their values in a request context
+type envVarsContextKey struct{}
+
+func newRequest(reqId uint16, flags uint8) *request {
+ r := &request{
+ reqId: reqId,
+ params: map[string]string{},
+ keepConn: flags&flagKeepConn != 0,
+ }
+ r.rawParams = r.buf[:0]
+ return r
+}
+
+// parseParams reads an encoded []byte into Params.
+func (r *request) parseParams() {
+ text := r.rawParams
+ r.rawParams = nil
+ for len(text) > 0 {
+ keyLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ valLen, n := readSize(text)
+ if n == 0 {
+ return
+ }
+ text = text[n:]
+ if int(keyLen)+int(valLen) > len(text) {
+ return
+ }
+ key := readString(text, keyLen)
+ text = text[keyLen:]
+ val := readString(text, valLen)
+ text = text[valLen:]
+ r.params[key] = val
+ }
+}
+
+// response implements http.ResponseWriter.
+type response struct {
+ req *request
+ header http.Header
+ code int
+ wroteHeader bool
+ wroteCGIHeader bool
+ w *bufWriter
+}
+
+func newResponse(c *child, req *request) *response {
+ return &response{
+ req: req,
+ header: http.Header{},
+ w: newWriter(c.conn, typeStdout, req.reqId),
+ }
+}
+
+func (r *response) Header() http.Header {
+ return r.header
+}
+
+func (r *response) Write(p []byte) (n int, err error) {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ if !r.wroteCGIHeader {
+ r.writeCGIHeader(p)
+ }
+ return r.w.Write(p)
+}
+
+func (r *response) WriteHeader(code int) {
+ if r.wroteHeader {
+ return
+ }
+ r.wroteHeader = true
+ r.code = code
+ if code == http.StatusNotModified {
+ // Must not have body.
+ r.header.Del("Content-Type")
+ r.header.Del("Content-Length")
+ r.header.Del("Transfer-Encoding")
+ }
+ if r.header.Get("Date") == "" {
+ r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat))
+ }
+}
+
+// writeCGIHeader finalizes the header sent to the client and writes it to the output.
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly.
+func (r *response) writeCGIHeader(p []byte) {
+ if r.wroteCGIHeader {
+ return
+ }
+ r.wroteCGIHeader = true
+ fmt.Fprintf(r.w, "Status: %d %s\r\n", r.code, http.StatusText(r.code))
+ if _, hasType := r.header["Content-Type"]; r.code != http.StatusNotModified && !hasType {
+ r.header.Set("Content-Type", http.DetectContentType(p))
+ }
+ r.header.Write(r.w)
+ r.w.WriteString("\r\n")
+ r.w.Flush()
+}
+
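+// For a typical 200 response with no explicit Content-Type, the header block
+// written above looks roughly like the following (values are illustrative):
+//
+//	Status: 200 OK\r\n
+//	Content-Type: text/plain; charset=utf-8\r\n
+//	Date: Mon, 02 Jan 2006 15:04:05 GMT\r\n
+//	\r\n
+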
+func (r *response) Flush() {
+ if !r.wroteHeader {
+ r.WriteHeader(http.StatusOK)
+ }
+ r.w.Flush()
+}
+
+func (r *response) Close() error {
+ r.Flush()
+ return r.w.Close()
+}
+
+type child struct {
+ conn *conn
+ handler http.Handler
+
+ requests map[uint16]*request // keyed by request ID
+}
+
+func newChild(rwc io.ReadWriteCloser, handler http.Handler) *child {
+ return &child{
+ conn: newConn(rwc),
+ handler: handler,
+ requests: make(map[uint16]*request),
+ }
+}
+
+func (c *child) serve() {
+ defer c.conn.Close()
+ defer c.cleanUp()
+ var rec record
+ for {
+ if err := rec.read(c.conn.rwc); err != nil {
+ return
+ }
+ if err := c.handleRecord(&rec); err != nil {
+ return
+ }
+ }
+}
+
+var errCloseConn = errors.New("fcgi: connection should be closed")
+
+var emptyBody = io.NopCloser(strings.NewReader(""))
+
+// ErrRequestAborted is returned by Read when a handler attempts to read the
+// body of a request that has been aborted by the web server.
+var ErrRequestAborted = errors.New("fcgi: request aborted by web server")
+
+// ErrConnClosed is returned by Read when a handler attempts to read the body of
+// a request after the connection to the web server has been closed.
+var ErrConnClosed = errors.New("fcgi: connection to web server closed")
+
+func (c *child) handleRecord(rec *record) error {
+ req, ok := c.requests[rec.h.Id]
+ if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues {
+ // The spec says to ignore unknown request IDs.
+ return nil
+ }
+
+ switch rec.h.Type {
+ case typeBeginRequest:
+ if req != nil {
+ // The server is trying to begin a request with the same ID
+ // as an in-progress request. This is an error.
+ return errors.New("fcgi: received ID that is already in-flight")
+ }
+
+ var br beginRequest
+ if err := br.read(rec.content()); err != nil {
+ return err
+ }
+ if br.role != roleResponder {
+ c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole)
+ return nil
+ }
+ req = newRequest(rec.h.Id, br.flags)
+ c.requests[rec.h.Id] = req
+ return nil
+ case typeParams:
+ // NOTE(eds): Technically a key-value pair can straddle the boundary
+ // between two packets. We buffer until we've received all parameters.
+ if len(rec.content()) > 0 {
+ req.rawParams = append(req.rawParams, rec.content()...)
+ return nil
+ }
+ req.parseParams()
+ return nil
+ case typeStdin:
+ content := rec.content()
+ if req.pw == nil {
+ var body io.ReadCloser
+ if len(content) > 0 {
+ // body could be an io.LimitReader, but it shouldn't matter
+ // as long as both sides are behaving.
+ body, req.pw = io.Pipe()
+ } else {
+ body = emptyBody
+ }
+ go c.serveRequest(req, body)
+ }
+ if len(content) > 0 {
+ // TODO(eds): This blocks until the handler reads from the pipe.
+ // If the handler takes a long time, it might be a problem.
+ req.pw.Write(content)
+ } else {
+ delete(c.requests, req.reqId)
+ if req.pw != nil {
+ req.pw.Close()
+ }
+ }
+ return nil
+ case typeGetValues:
+ values := map[string]string{"FCGI_MPXS_CONNS": "1"}
+ c.conn.writePairs(typeGetValuesResult, 0, values)
+ return nil
+ case typeData:
+ // If the filter role is implemented, read the data stream here.
+ return nil
+ case typeAbortRequest:
+ delete(c.requests, rec.h.Id)
+ c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete)
+ if req.pw != nil {
+ req.pw.CloseWithError(ErrRequestAborted)
+ }
+ if !req.keepConn {
+ // connection will close upon return
+ return errCloseConn
+ }
+ return nil
+ default:
+ b := make([]byte, 8)
+ b[0] = byte(rec.h.Type)
+ c.conn.writeRecord(typeUnknownType, 0, b)
+ return nil
+ }
+}
+
+// filterOutUsedEnvVars returns a new map of env vars from the given envVars
+// map, with the variables that are consumed when creating each http.Request
+// filtered out.
+func filterOutUsedEnvVars(envVars map[string]string) map[string]string {
+ withoutUsedEnvVars := make(map[string]string)
+ for k, v := range envVars {
+ if addFastCGIEnvToContext(k) {
+ withoutUsedEnvVars[k] = v
+ }
+ }
+ return withoutUsedEnvVars
+}
+
+func (c *child) serveRequest(req *request, body io.ReadCloser) {
+ r := newResponse(c, req)
+ httpReq, err := cgi.RequestFromMap(req.params)
+ if err != nil {
+ // there was an error reading the request
+ r.WriteHeader(http.StatusInternalServerError)
+ c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error()))
+ } else {
+ httpReq.Body = body
+ withoutUsedEnvVars := filterOutUsedEnvVars(req.params)
+ envVarCtx := context.WithValue(httpReq.Context(), envVarsContextKey{}, withoutUsedEnvVars)
+ httpReq = httpReq.WithContext(envVarCtx)
+ c.handler.ServeHTTP(r, httpReq)
+ }
+ // Make sure we serve something even if nothing was written to r
+ r.Write(nil)
+ r.Close()
+ c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete)
+
+ // Consume the entire body, so the host isn't still writing to
+ // us when we close the socket below in the !keepConn case,
+ // otherwise we'd send a RST. (golang.org/issue/4183)
+ // TODO(bradfitz): also bound this copy in time. Or send
+ // some sort of abort request to the host, so the host
+ // can properly cut off the client sending all the data.
+ // For now, just bound the amount copied.
+ io.CopyN(io.Discard, body, 100<<20)
+ body.Close()
+
+ if !req.keepConn {
+ c.conn.Close()
+ }
+}
+
+func (c *child) cleanUp() {
+ for _, req := range c.requests {
+ if req.pw != nil {
+ // race with call to Close in c.serveRequest doesn't matter because
+ // Pipe(Reader|Writer).Close are idempotent
+ req.pw.CloseWithError(ErrConnClosed)
+ }
+ }
+}
+
+// Serve accepts incoming FastCGI connections on the listener l, creating a new
+// goroutine for each. The goroutine reads requests and then calls handler
+// to reply to them.
+// If l is nil, Serve accepts connections from os.Stdin.
+// If handler is nil, http.DefaultServeMux is used.
+func Serve(l net.Listener, handler http.Handler) error {
+ if l == nil {
+ var err error
+ l, err = net.FileListener(os.Stdin)
+ if err != nil {
+ return err
+ }
+ defer l.Close()
+ }
+ if handler == nil {
+ handler = http.DefaultServeMux
+ }
+ for {
+ rw, err := l.Accept()
+ if err != nil {
+ return err
+ }
+ c := newChild(rw, handler)
+ go c.serve()
+ }
+}
+
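+// A minimal usage sketch for a responder process; the listen address
+// "127.0.0.1:9000" and the handler body are placeholders chosen for
+// illustration:
+//
+//	l, err := net.Listen("tcp", "127.0.0.1:9000")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+//		fmt.Fprintln(w, "Hello from FastCGI")
+//	})
+//	log.Fatal(fcgi.Serve(l, handler))
+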
+// ProcessEnv returns FastCGI environment variables associated with the request r
+// that are not otherwise represented in the request itself; the data is carried
+// in the request's context. As an example, if REMOTE_USER is set for a
+// request, it will not be found anywhere in r, but it will be included in
+// ProcessEnv's response (via r's context).
+func ProcessEnv(r *http.Request) map[string]string {
+ env, _ := r.Context().Value(envVarsContextKey{}).(map[string]string)
+ return env
+}
+
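+// A short caller-side sketch of reading such a variable from a handler;
+// REMOTE_USER is only present when the web server actually sets it:
+//
+//	func handler(w http.ResponseWriter, r *http.Request) {
+//		env := fcgi.ProcessEnv(r)
+//		if user, ok := env["REMOTE_USER"]; ok {
+//			fmt.Fprintf(w, "authenticated as %s\n", user)
+//		}
+//	}
+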
+// addFastCGIEnvToContext reports whether to include the FastCGI environment variable s
+// in the http.Request.Context, accessible via ProcessEnv.
+func addFastCGIEnvToContext(s string) bool {
+ // Exclude things supported by net/http natively:
+ switch s {
+ case "CONTENT_LENGTH", "CONTENT_TYPE", "HTTPS",
+ "PATH_INFO", "QUERY_STRING", "REMOTE_ADDR",
+ "REMOTE_HOST", "REMOTE_PORT", "REQUEST_METHOD",
+ "REQUEST_URI", "SCRIPT_NAME", "SERVER_PROTOCOL":
+ return false
+ }
+ if strings.HasPrefix(s, "HTTP_") {
+ return false
+ }
+ // Explicitly include FastCGI-specific things.
+ // This list is redundant with the default "return true" below.
+ // Consider this documentation of the sorts of things we expect
+ // to maybe see.
+ switch s {
+ case "REMOTE_USER":
+ return true
+ }
+ // Unknown, so include it to be safe.
+ return true
+}
diff --git a/src/net/http/fcgi/fcgi.go b/src/net/http/fcgi/fcgi.go
new file mode 100644
index 0000000..56f7d40
--- /dev/null
+++ b/src/net/http/fcgi/fcgi.go
@@ -0,0 +1,277 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fcgi implements the FastCGI protocol.
+//
+// See https://fast-cgi.github.io/ for an unofficial mirror of the
+// original documentation.
+//
+// Currently only the responder role is supported.
+package fcgi
+
+// This file defines the raw protocol and some utilities used by the child and
+// the host.
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "sync"
+)
+
+// recType is a record type, as defined by
+// https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22#S8
+type recType uint8
+
+const (
+ typeBeginRequest recType = 1
+ typeAbortRequest recType = 2
+ typeEndRequest recType = 3
+ typeParams recType = 4
+ typeStdin recType = 5
+ typeStdout recType = 6
+ typeStderr recType = 7
+ typeData recType = 8
+ typeGetValues recType = 9
+ typeGetValuesResult recType = 10
+ typeUnknownType recType = 11
+)
+
+// keep the connection between web-server and responder open after request
+const flagKeepConn = 1
+
+const (
+ maxWrite = 65535 // maximum record body
+ maxPad = 255
+)
+
+const (
+ roleResponder = iota + 1 // only Responders are implemented.
+ roleAuthorizer
+ roleFilter
+)
+
+const (
+ statusRequestComplete = iota
+ statusCantMultiplex
+ statusOverloaded
+ statusUnknownRole
+)
+
+type header struct {
+ Version uint8
+ Type recType
+ Id uint16
+ ContentLength uint16
+ PaddingLength uint8
+ Reserved uint8
+}
+
+type beginRequest struct {
+ role uint16
+ flags uint8
+ reserved [5]uint8
+}
+
+func (br *beginRequest) read(content []byte) error {
+ if len(content) != 8 {
+ return errors.New("fcgi: invalid begin request record")
+ }
+ br.role = binary.BigEndian.Uint16(content)
+ br.flags = content[2]
+ return nil
+}
+
+// for padding so we don't have to allocate all the time
+// not synchronized because we don't care what the contents are
+var pad [maxPad]byte
+
+func (h *header) init(recType recType, reqId uint16, contentLength int) {
+ h.Version = 1
+ h.Type = recType
+ h.Id = reqId
+ h.ContentLength = uint16(contentLength)
+ h.PaddingLength = uint8(-contentLength & 7)
+}
+
+// conn sends records over rwc
+type conn struct {
+ mutex sync.Mutex
+ rwc io.ReadWriteCloser
+ closeErr error
+ closed bool
+
+ // to avoid allocations
+ buf bytes.Buffer
+ h header
+}
+
+func newConn(rwc io.ReadWriteCloser) *conn {
+ return &conn{rwc: rwc}
+}
+
+// Close closes the conn if it is not already closed.
+func (c *conn) Close() error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ if !c.closed {
+ c.closeErr = c.rwc.Close()
+ c.closed = true
+ }
+ return c.closeErr
+}
+
+type record struct {
+ h header
+ buf [maxWrite + maxPad]byte
+}
+
+func (rec *record) read(r io.Reader) (err error) {
+ if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil {
+ return err
+ }
+ if rec.h.Version != 1 {
+ return errors.New("fcgi: invalid header version")
+ }
+ n := int(rec.h.ContentLength) + int(rec.h.PaddingLength)
+ if _, err = io.ReadFull(r, rec.buf[:n]); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *record) content() []byte {
+ return r.buf[:r.h.ContentLength]
+}
+
+// writeRecord writes and sends a single record.
+func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.buf.Reset()
+ c.h.init(recType, reqId, len(b))
+ if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(b); err != nil {
+ return err
+ }
+ if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil {
+ return err
+ }
+ _, err := c.rwc.Write(c.buf.Bytes())
+ return err
+}
+
+func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(appStatus))
+ b[4] = protocolStatus
+ return c.writeRecord(typeEndRequest, reqId, b)
+}
+
+func (c *conn) writePairs(recType recType, reqId uint16, pairs map[string]string) error {
+ w := newWriter(c, recType, reqId)
+ b := make([]byte, 8)
+ for k, v := range pairs {
+ n := encodeSize(b, uint32(len(k)))
+ n += encodeSize(b[n:], uint32(len(v)))
+ if _, err := w.Write(b[:n]); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(k); err != nil {
+ return err
+ }
+ if _, err := w.WriteString(v); err != nil {
+ return err
+ }
+ }
+ w.Close()
+ return nil
+}
+
+func readSize(s []byte) (uint32, int) {
+ if len(s) == 0 {
+ return 0, 0
+ }
+ size, n := uint32(s[0]), 1
+ if size&(1<<7) != 0 {
+ if len(s) < 4 {
+ return 0, 0
+ }
+ n = 4
+ size = binary.BigEndian.Uint32(s)
+ size &^= 1 << 31
+ }
+ return size, n
+}
+
+func readString(s []byte, size uint32) string {
+ if size > uint32(len(s)) {
+ return ""
+ }
+ return string(s[:size])
+}
+
+func encodeSize(b []byte, size uint32) int {
+ if size > 127 {
+ size |= 1 << 31
+ binary.BigEndian.PutUint32(b, size)
+ return 4
+ }
+ b[0] = byte(size)
+ return 1
+}
+
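+// Worked examples of the name-value length encoding above: a length of 3 fits
+// in seven bits and encodes as the single byte 0x03, while a length of 128
+// sets the high bit and encodes as the four bytes 0x80 0x00 0x00 0x80;
+// readSize reverses both forms.
+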
+// bufWriter encapsulates bufio.Writer but also closes the underlying stream when
+// Closed.
+type bufWriter struct {
+ closer io.Closer
+ *bufio.Writer
+}
+
+func (w *bufWriter) Close() error {
+ if err := w.Writer.Flush(); err != nil {
+ w.closer.Close()
+ return err
+ }
+ return w.closer.Close()
+}
+
+func newWriter(c *conn, recType recType, reqId uint16) *bufWriter {
+ s := &streamWriter{c: c, recType: recType, reqId: reqId}
+ w := bufio.NewWriterSize(s, maxWrite)
+ return &bufWriter{s, w}
+}
+
+// streamWriter abstracts out the separation of a stream into discrete records.
+// It only writes maxWrite bytes at a time.
+type streamWriter struct {
+ c *conn
+ recType recType
+ reqId uint16
+}
+
+func (w *streamWriter) Write(p []byte) (int, error) {
+ nn := 0
+ for len(p) > 0 {
+ n := len(p)
+ if n > maxWrite {
+ n = maxWrite
+ }
+ if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil {
+ return nn, err
+ }
+ nn += n
+ p = p[n:]
+ }
+ return nn, nil
+}
+
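+// For example, writing a 66000-byte body through a streamWriter produces a
+// full 65535-byte record followed by a 465-byte record; Close then sends the
+// zero-length record that terminates the stream (see streamTests in
+// fcgi_test.go).
+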
+func (w *streamWriter) Close() error {
+ // send empty record to close the stream
+ return w.c.writeRecord(w.recType, w.reqId, nil)
+}
diff --git a/src/net/http/fcgi/fcgi_test.go b/src/net/http/fcgi/fcgi_test.go
new file mode 100644
index 0000000..03c4224
--- /dev/null
+++ b/src/net/http/fcgi/fcgi_test.go
@@ -0,0 +1,453 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fcgi
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
+)
+
+var sizeTests = []struct {
+ size uint32
+ bytes []byte
+}{
+ {0, []byte{0x00}},
+ {127, []byte{0x7F}},
+ {128, []byte{0x80, 0x00, 0x00, 0x80}},
+ {1000, []byte{0x80, 0x00, 0x03, 0xE8}},
+ {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}},
+}
+
+func TestSize(t *testing.T) {
+ b := make([]byte, 4)
+ for i, test := range sizeTests {
+ n := encodeSize(b, test.size)
+ if !bytes.Equal(b[:n], test.bytes) {
+ t.Errorf("%d expected %x, encoded %x", i, test.bytes, b)
+ }
+ size, n := readSize(test.bytes)
+ if size != test.size {
+ t.Errorf("%d expected %d, read %d", i, test.size, size)
+ }
+ if len(test.bytes) != n {
+ t.Errorf("%d did not consume all the bytes", i)
+ }
+ }
+}
+
+var streamTests = []struct {
+ desc string
+ recType recType
+ reqId uint16
+ content []byte
+ raw []byte
+}{
+ {"single record", typeStdout, 1, nil,
+ []byte{1, byte(typeStdout), 0, 1, 0, 0, 0, 0},
+ },
+ // this data will have to be split into two records
+ {"two records", typeStdin, 300, make([]byte, 66000),
+ bytes.Join([][]byte{
+ // header for the first record
+ {1, byte(typeStdin), 0x01, 0x2C, 0xFF, 0xFF, 1, 0},
+ make([]byte, 65536),
+ // header for the second
+ {1, byte(typeStdin), 0x01, 0x2C, 0x01, 0xD1, 7, 0},
+ make([]byte, 472),
+ // header for the empty record
+ {1, byte(typeStdin), 0x01, 0x2C, 0, 0, 0, 0},
+ },
+ nil),
+ },
+}
+
+type nilCloser struct {
+ io.ReadWriter
+}
+
+func (c *nilCloser) Close() error { return nil }
+
+func TestStreams(t *testing.T) {
+ var rec record
+outer:
+ for _, test := range streamTests {
+ buf := bytes.NewBuffer(test.raw)
+ var content []byte
+ for buf.Len() > 0 {
+ if err := rec.read(buf); err != nil {
+ t.Errorf("%s: error reading record: %v", test.desc, err)
+ continue outer
+ }
+ content = append(content, rec.content()...)
+ }
+ if rec.h.Type != test.recType {
+ t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType)
+ continue
+ }
+ if rec.h.Id != test.reqId {
+ t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId)
+ continue
+ }
+ if !bytes.Equal(content, test.content) {
+ t.Errorf("%s: read wrong content", test.desc)
+ continue
+ }
+ buf.Reset()
+ c := newConn(&nilCloser{buf})
+ w := newWriter(c, test.recType, test.reqId)
+ if _, err := w.Write(test.content); err != nil {
+ t.Errorf("%s: error writing record: %v", test.desc, err)
+ continue
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("%s: error closing stream: %v", test.desc, err)
+ continue
+ }
+ if !bytes.Equal(buf.Bytes(), test.raw) {
+ t.Errorf("%s: wrote wrong content", test.desc)
+ }
+ }
+}
+
+type writeOnlyConn struct {
+ buf []byte
+}
+
+func (c *writeOnlyConn) Write(p []byte) (int, error) {
+ c.buf = append(c.buf, p...)
+ return len(p), nil
+}
+
+func (c *writeOnlyConn) Read(p []byte) (int, error) {
+ return 0, errors.New("conn is write-only")
+}
+
+func (c *writeOnlyConn) Close() error {
+ return nil
+}
+
+func TestGetValues(t *testing.T) {
+ var rec record
+ rec.h.Type = typeGetValues
+
+ wc := new(writeOnlyConn)
+ c := newChild(wc, nil)
+ err := c.handleRecord(&rec)
+ if err != nil {
+ t.Fatalf("handleRecord: %v", err)
+ }
+
+ const want = "\x01\n\x00\x00\x00\x12\x06\x00" +
+ "\x0f\x01FCGI_MPXS_CONNS1" +
+ "\x00\x00\x00\x00\x00\x00\x01\n\x00\x00\x00\x00\x00\x00"
+ if got := string(wc.buf); got != want {
+ t.Errorf(" got: %q\nwant: %q\n", got, want)
+ }
+}
+
+func nameValuePair11(nameData, valueData string) []byte {
+ return bytes.Join(
+ [][]byte{
+ {byte(len(nameData)), byte(len(valueData))},
+ []byte(nameData),
+ []byte(valueData),
+ },
+ nil,
+ )
+}
+
+func makeRecord(
+ recordType recType,
+ requestId uint16,
+ contentData []byte,
+) []byte {
+ requestIdB1 := byte(requestId >> 8)
+ requestIdB0 := byte(requestId)
+
+ contentLength := len(contentData)
+ contentLengthB1 := byte(contentLength >> 8)
+ contentLengthB0 := byte(contentLength)
+ return bytes.Join([][]byte{
+ {1, byte(recordType), requestIdB1, requestIdB0, contentLengthB1,
+ contentLengthB0, 0, 0},
+ contentData,
+ },
+ nil)
+}
+
+// a series of FastCGI records that start a request and begin sending the
+// request body
+var streamBeginTypeStdin = bytes.Join([][]byte{
+ // set up request 1
+ makeRecord(typeBeginRequest, 1,
+ []byte{0, byte(roleResponder), 0, 0, 0, 0, 0, 0}),
+ // add required parameters to request 1
+ makeRecord(typeParams, 1, nameValuePair11("REQUEST_METHOD", "GET")),
+ makeRecord(typeParams, 1, nameValuePair11("SERVER_PROTOCOL", "HTTP/1.1")),
+ makeRecord(typeParams, 1, nil),
+ // begin sending body of request 1
+ makeRecord(typeStdin, 1, []byte("0123456789abcdef")),
+},
+ nil)
+
+var cleanUpTests = []struct {
+ input []byte
+ err error
+}{
+ // confirm that child.handleRecord closes req.pw after aborting req
+ {
+ bytes.Join([][]byte{
+ streamBeginTypeStdin,
+ makeRecord(typeAbortRequest, 1, nil),
+ },
+ nil),
+ ErrRequestAborted,
+ },
+ // confirm that child.serve closes all pipes after error reading record
+ {
+ bytes.Join([][]byte{
+ streamBeginTypeStdin,
+ nil,
+ },
+ nil),
+ ErrConnClosed,
+ },
+}
+
+type nopWriteCloser struct {
+ io.Reader
+}
+
+func (nopWriteCloser) Write(buf []byte) (int, error) {
+ return len(buf), nil
+}
+
+func (nopWriteCloser) Close() error {
+ return nil
+}
+
+// Test that child.serve closes the bodies of aborted requests and closes the
+// bodies of all requests before returning. Causes deadlock if either condition
+// isn't met. See issue 6934.
+func TestChildServeCleansUp(t *testing.T) {
+ for _, tt := range cleanUpTests {
+ input := make([]byte, len(tt.input))
+ copy(input, tt.input)
+ rc := nopWriteCloser{bytes.NewReader(input)}
+ done := make(chan struct{})
+ c := newChild(rc, http.HandlerFunc(func(
+ w http.ResponseWriter,
+ r *http.Request,
+ ) {
+ // block on reading body of request
+ _, err := io.Copy(io.Discard, r.Body)
+ if err != tt.err {
+ t.Errorf("Expected %#v, got %#v", tt.err, err)
+ }
+ // not reached if body of request isn't closed
+ close(done)
+ }))
+ c.serve()
+ // wait for body of request to be closed or all goroutines to block
+ <-done
+ }
+}
+
+type rwNopCloser struct {
+ io.Reader
+ io.Writer
+}
+
+func (rwNopCloser) Close() error {
+ return nil
+}
+
+// Verifies that malformed parameter records don't crash the child. Issue 11824.
+func TestMalformedParams(t *testing.T) {
+ input := []byte{
+ // beginRequest, requestId=1, contentLength=8, role=1, keepConn=1
+ 1, 1, 0, 1, 0, 8, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
+ // params, requestId=1, contentLength=10, k1Len=50, v1Len=50 (malformed, wrong length)
+ 1, 4, 0, 1, 0, 10, 0, 0, 50, 50, 3, 4, 5, 6, 7, 8, 9, 10,
+ // end of params
+ 1, 4, 0, 1, 0, 0, 0, 0,
+ }
+ rw := rwNopCloser{bytes.NewReader(input), io.Discard}
+ c := newChild(rw, http.DefaultServeMux)
+ c.serve()
+}
+
+// a series of FastCGI records that start and end a request
+var streamFullRequestStdin = bytes.Join([][]byte{
+ // set up request
+ makeRecord(typeBeginRequest, 1,
+ []byte{0, byte(roleResponder), 0, 0, 0, 0, 0, 0}),
+ // add required parameters
+ makeRecord(typeParams, 1, nameValuePair11("REQUEST_METHOD", "GET")),
+ makeRecord(typeParams, 1, nameValuePair11("SERVER_PROTOCOL", "HTTP/1.1")),
+ // set optional parameters
+ makeRecord(typeParams, 1, nameValuePair11("REMOTE_USER", "jane.doe")),
+ makeRecord(typeParams, 1, nameValuePair11("QUERY_STRING", "/foo/bar")),
+ makeRecord(typeParams, 1, nil),
+ // begin sending body of request
+ makeRecord(typeStdin, 1, []byte("0123456789abcdef")),
+ // end request
+ makeRecord(typeEndRequest, 1, nil),
+},
+ nil)
+
+var envVarTests = []struct {
+ input []byte
+ envVar string
+ expectedVal string
+ expectedFilteredOut bool
+}{
+ {
+ streamFullRequestStdin,
+ "REMOTE_USER",
+ "jane.doe",
+ false,
+ },
+ {
+ streamFullRequestStdin,
+ "QUERY_STRING",
+ "",
+ true,
+ },
+}
+
+// Test that environment variables set for a request can be
+// read by a handler. Ensures that variables consumed while building
+// the request (like QUERY_STRING) are not exposed to a handler.
+func TestChildServeReadsEnvVars(t *testing.T) {
+ for _, tt := range envVarTests {
+ input := make([]byte, len(tt.input))
+ copy(input, tt.input)
+ rc := nopWriteCloser{bytes.NewReader(input)}
+ done := make(chan struct{})
+ c := newChild(rc, http.HandlerFunc(func(
+ w http.ResponseWriter,
+ r *http.Request,
+ ) {
+ env := ProcessEnv(r)
+ if _, ok := env[tt.envVar]; ok && tt.expectedFilteredOut {
+ t.Errorf("Expected environment variable %s to not be set, but set to %s",
+ tt.envVar, env[tt.envVar])
+ } else if env[tt.envVar] != tt.expectedVal {
+ t.Errorf("Expected %s, got %s", tt.expectedVal, env[tt.envVar])
+ }
+ close(done)
+ }))
+ c.serve()
+ <-done
+ }
+}
+
+func TestResponseWriterSniffsContentType(t *testing.T) {
+ var tests = []struct {
+ name string
+ body string
+ wantCT string
+ }{
+ {
+ name: "no body",
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "html",
+ body: "<html><head><title>test page</title></head><body>This is a body</body></html>",
+ wantCT: "text/html; charset=utf-8",
+ },
+ {
+ name: "text",
+ body: strings.Repeat("gopher", 86),
+ wantCT: "text/plain; charset=utf-8",
+ },
+ {
+ name: "jpg",
+ body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024),
+ wantCT: "image/jpeg",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ input := make([]byte, len(streamFullRequestStdin))
+ copy(input, streamFullRequestStdin)
+ rc := nopWriteCloser{bytes.NewReader(input)}
+ done := make(chan struct{})
+ var resp *response
+ c := newChild(rc, http.HandlerFunc(func(
+ w http.ResponseWriter,
+ r *http.Request,
+ ) {
+ io.WriteString(w, tt.body)
+ resp = w.(*response)
+ close(done)
+ }))
+ c.serve()
+ <-done
+ if got := resp.Header().Get("Content-Type"); got != tt.wantCT {
+ t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT)
+ }
+ })
+ }
+}
+
+type signalingNopWriteCloser struct {
+ io.ReadCloser
+ closed chan bool
+}
+
+func (*signalingNopWriteCloser) Write(buf []byte) (int, error) {
+ return len(buf), nil
+}
+
+func (rc *signalingNopWriteCloser) Close() error {
+ close(rc.closed)
+ return rc.ReadCloser.Close()
+}
+
+// Test whether the server properly closes the connection when processing
+// slow requests.
+func TestSlowRequest(t *testing.T) {
+ pr, pw := io.Pipe()
+
+ writerDone := make(chan struct{})
+ go func() {
+ for _, buf := range [][]byte{
+ streamBeginTypeStdin,
+ makeRecord(typeStdin, 1, nil),
+ } {
+ pw.Write(buf)
+ time.Sleep(100 * time.Millisecond)
+ }
+ close(writerDone)
+ }()
+ defer func() {
+ <-writerDone
+ pw.Close()
+ }()
+
+ rc := &signalingNopWriteCloser{pr, make(chan bool)}
+ handlerDone := make(chan bool)
+
+ c := newChild(rc, http.HandlerFunc(func(
+ w http.ResponseWriter,
+ r *http.Request,
+ ) {
+ w.WriteHeader(200)
+ close(handlerDone)
+ }))
+ c.serve()
+
+ <-handlerDone
+ <-rc.closed
+ t.Log("FastCGI child closed connection")
+}
diff --git a/src/net/http/filetransport.go b/src/net/http/filetransport.go
new file mode 100644
index 0000000..94684b0
--- /dev/null
+++ b/src/net/http/filetransport.go
@@ -0,0 +1,123 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "fmt"
+ "io"
+)
+
+// fileTransport implements RoundTripper for the 'file' protocol.
+type fileTransport struct {
+ fh fileHandler
+}
+
+// NewFileTransport returns a new RoundTripper, serving the provided
+// FileSystem. The returned RoundTripper ignores the URL host in its
+// incoming requests, as well as most other properties of the
+// request.
+//
+// The typical use case for NewFileTransport is to register the "file"
+// protocol with a Transport, as in:
+//
+// t := &http.Transport{}
+// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/")))
+// c := &http.Client{Transport: t}
+// res, err := c.Get("file:///etc/passwd")
+// ...
+func NewFileTransport(fs FileSystem) RoundTripper {
+ return fileTransport{fileHandler{fs}}
+}
+
+func (t fileTransport) RoundTrip(req *Request) (resp *Response, err error) {
+ // We start ServeHTTP in a goroutine, which may take a long
+ // time if the file is large. The newPopulateResponseWriter
+ // call returns a channel which either ServeHTTP or finish()
+ // sends our *Response on, once the *Response itself has been
+ // populated (even if the body itself is still being
+ // written to the res.Body, a pipe)
+ rw, resc := newPopulateResponseWriter()
+ go func() {
+ t.fh.ServeHTTP(rw, req)
+ rw.finish()
+ }()
+ return <-resc, nil
+}
+
+func newPopulateResponseWriter() (*populateResponse, <-chan *Response) {
+ pr, pw := io.Pipe()
+ rw := &populateResponse{
+ ch: make(chan *Response),
+ pw: pw,
+ res: &Response{
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ Header: make(Header),
+ Close: true,
+ Body: pr,
+ },
+ }
+ return rw, rw.ch
+}
+
+// populateResponse is a ResponseWriter that populates the *Response
+// in res, and writes its body to a pipe connected to the response
+// body. Once writes begin or finish() is called, the response is sent
+// on ch.
+type populateResponse struct {
+ res *Response
+ ch chan *Response
+ wroteHeader bool
+ hasContent bool
+ sentResponse bool
+ pw *io.PipeWriter
+}
+
+func (pr *populateResponse) finish() {
+ if !pr.wroteHeader {
+ pr.WriteHeader(500)
+ }
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ pr.pw.Close()
+}
+
+func (pr *populateResponse) sendResponse() {
+ if pr.sentResponse {
+ return
+ }
+ pr.sentResponse = true
+
+ if pr.hasContent {
+ pr.res.ContentLength = -1
+ }
+ pr.ch <- pr.res
+}
+
+func (pr *populateResponse) Header() Header {
+ return pr.res.Header
+}
+
+func (pr *populateResponse) WriteHeader(code int) {
+ if pr.wroteHeader {
+ return
+ }
+ pr.wroteHeader = true
+
+ pr.res.StatusCode = code
+ pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code))
+}
+
+func (pr *populateResponse) Write(p []byte) (n int, err error) {
+ if !pr.wroteHeader {
+ pr.WriteHeader(StatusOK)
+ }
+ pr.hasContent = true
+ if !pr.sentResponse {
+ pr.sendResponse()
+ }
+ return pr.pw.Write(p)
+}
diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go
new file mode 100644
index 0000000..77fc8ee
--- /dev/null
+++ b/src/net/http/filetransport_test.go
@@ -0,0 +1,64 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "io"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func checker(t *testing.T) func(string, error) {
+ return func(call string, err error) {
+ if err == nil {
+ return
+ }
+ t.Fatalf("%s: %v", call, err)
+ }
+}
+
+func TestFileTransport(t *testing.T) {
+ check := checker(t)
+
+ dname := t.TempDir()
+ fname := filepath.Join(dname, "foo.txt")
+ err := os.WriteFile(fname, []byte("Bar"), 0644)
+ check("WriteFile", err)
+ defer os.Remove(fname)
+
+ tr := &Transport{}
+ tr.RegisterProtocol("file", NewFileTransport(Dir(dname)))
+ c := &Client{Transport: tr}
+
+ fooURLs := []string{"file:///foo.txt", "file://../foo.txt"}
+ for _, urlstr := range fooURLs {
+ res, err := c.Get(urlstr)
+ check("Get "+urlstr, err)
+ if res.StatusCode != 200 {
+ t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode)
+ }
+ if res.ContentLength != -1 {
+ t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength)
+ }
+ if res.Body == nil {
+ t.Fatalf("for %s, nil Body", urlstr)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ check("ReadAll "+urlstr, err)
+ if string(slurp) != "Bar" {
+ t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar")
+ }
+ }
+
+ const badURL = "file://../no-exist.txt"
+ res, err := c.Get(badURL)
+ check("Get "+badURL, err)
+ if res.StatusCode != 404 {
+ t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode)
+ }
+ res.Body.Close()
+}
diff --git a/src/net/http/fs.go b/src/net/http/fs.go
new file mode 100644
index 0000000..41e0b43
--- /dev/null
+++ b/src/net/http/fs.go
@@ -0,0 +1,988 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP file system request handler
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "internal/safefilepath"
+ "io"
+ "io/fs"
+ "mime"
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+ "os"
+ "path"
+ "path/filepath"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+)
+
+// A Dir implements FileSystem using the native file system restricted to a
+// specific directory tree.
+//
+// While the FileSystem.Open method takes '/'-separated paths, a Dir's string
+// value is a filename on the native file system, not a URL, so it is separated
+// by filepath.Separator, which isn't necessarily '/'.
+//
+// Note that Dir could expose sensitive files and directories. Dir will follow
+// symlinks pointing out of the directory tree, which can be especially dangerous
+// if serving from a directory in which users are able to create arbitrary symlinks.
+// Dir will also allow access to files and directories starting with a period,
+// which could expose sensitive directories like .git or sensitive files like
+// .htpasswd. To exclude files with a leading period, remove the files/directories
+// from the server or create a custom FileSystem implementation.
+//
+// An empty Dir is treated as ".".
+type Dir string
+
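+// A brief usage sketch; the directory "/srv/static" and the address ":8080"
+// are placeholders:
+//
+//	http.Handle("/", http.FileServer(http.Dir("/srv/static")))
+//	log.Fatal(http.ListenAndServe(":8080", nil))
+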
+// mapOpenError maps the provided non-nil error from opening name
+// to a possibly better non-nil error. In particular, it turns OS-specific errors
+// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552.
+func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error {
+ if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) {
+ return originalErr
+ }
+
+ parts := strings.Split(name, string(sep))
+ for i := range parts {
+ if parts[i] == "" {
+ continue
+ }
+ fi, err := stat(strings.Join(parts[:i+1], string(sep)))
+ if err != nil {
+ return originalErr
+ }
+ if !fi.IsDir() {
+ return fs.ErrNotExist
+ }
+ }
+ return originalErr
+}
+
+// Open implements FileSystem using os.Open, opening files for reading rooted
+// and relative to the directory d.
+func (d Dir) Open(name string) (File, error) {
+ path, err := safefilepath.FromFS(path.Clean("/" + name))
+ if err != nil {
+ return nil, errors.New("http: invalid or unsafe file path")
+ }
+ dir := string(d)
+ if dir == "" {
+ dir = "."
+ }
+ fullName := filepath.Join(dir, path)
+ f, err := os.Open(fullName)
+ if err != nil {
+ return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat)
+ }
+ return f, nil
+}
+
+// A FileSystem implements access to a collection of named files.
+// The elements in a file path are separated by slash ('/', U+002F)
+// characters, regardless of host operating system convention.
+// See the FileServer function to convert a FileSystem to a Handler.
+//
+// This interface predates the fs.FS interface, which can be used instead:
+// the FS adapter function converts an fs.FS to a FileSystem.
+type FileSystem interface {
+ Open(name string) (File, error)
+}
+
+// A File is returned by a FileSystem's Open method and can be
+// served by the FileServer implementation.
+//
+// The methods should behave the same as those on an *os.File.
+type File interface {
+ io.Closer
+ io.Reader
+ io.Seeker
+ Readdir(count int) ([]fs.FileInfo, error)
+ Stat() (fs.FileInfo, error)
+}
+
+type anyDirs interface {
+ len() int
+ name(i int) string
+ isDir(i int) bool
+}
+
+type fileInfoDirs []fs.FileInfo
+
+func (d fileInfoDirs) len() int { return len(d) }
+func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() }
+func (d fileInfoDirs) name(i int) string { return d[i].Name() }
+
+type dirEntryDirs []fs.DirEntry
+
+func (d dirEntryDirs) len() int { return len(d) }
+func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() }
+func (d dirEntryDirs) name(i int) string { return d[i].Name() }
+
+func dirList(w ResponseWriter, r *Request, f File) {
+ // Prefer to use ReadDir instead of Readdir,
+ // because the former doesn't require calling
+ // Stat on every entry of a directory on Unix.
+ var dirs anyDirs
+ var err error
+ if d, ok := f.(fs.ReadDirFile); ok {
+ var list dirEntryDirs
+ list, err = d.ReadDir(-1)
+ dirs = list
+ } else {
+ var list fileInfoDirs
+ list, err = f.Readdir(-1)
+ dirs = list
+ }
+
+ if err != nil {
+ logf(r, "http: error reading directory: %v", err)
+ Error(w, "Error reading directory", StatusInternalServerError)
+ return
+ }
+ sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) })
+
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ fmt.Fprintf(w, "<pre>\n")
+ for i, n := 0, dirs.len(); i < n; i++ {
+ name := dirs.name(i)
+ if dirs.isDir(i) {
+ name += "/"
+ }
+ // name may contain '?' or '#', which must be escaped to remain
+ // part of the URL path, and not indicate the start of a query
+ // string or fragment.
+ url := url.URL{Path: name}
+ fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
+ }
+ fmt.Fprintf(w, "</pre>\n")
+}
+
+// ServeContent replies to the request using the content in the
+// provided ReadSeeker. The main benefit of ServeContent over io.Copy
+// is that it handles Range requests properly, sets the MIME type, and
+// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since,
+// and If-Range requests.
+//
+// If the response's Content-Type header is not set, ServeContent
+// first tries to deduce the type from name's file extension and,
+// if that fails, falls back to reading the first block of the content
+// and passing it to DetectContentType.
+// The name is otherwise unused; in particular it can be empty and is
+// never sent in the response.
+//
+// If modtime is not the zero time or Unix epoch, ServeContent
+// includes it in a Last-Modified header in the response. If the
+// request includes an If-Modified-Since header, ServeContent uses
+// modtime to decide whether the content needs to be sent at all.
+//
+// The content's Seek method must work: ServeContent uses
+// a seek to the end of the content to determine its size.
+//
+// If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
+// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range.
+//
+// Note that *os.File implements the io.ReadSeeker interface.
+func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
+ sizeFunc := func() (int64, error) {
+ size, err := content.Seek(0, io.SeekEnd)
+ if err != nil {
+ return 0, errSeeker
+ }
+ _, err = content.Seek(0, io.SeekStart)
+ if err != nil {
+ return 0, errSeeker
+ }
+ return size, nil
+ }
+ serveContent(w, req, name, modtime, sizeFunc, content)
+}
+
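+// A brief sketch of serving in-memory content with ServeContent; the name
+// "report.txt" and the payload are placeholders:
+//
+//	func handler(w http.ResponseWriter, r *http.Request) {
+//		content := strings.NewReader("hello, range requests")
+//		http.ServeContent(w, r, "report.txt", time.Now(), content)
+//	}
+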
+// errSeeker is returned by ServeContent's sizeFunc when the content
+// doesn't seek properly. The underlying Seeker's error text isn't
+// included in the sizeFunc reply so it's not sent over HTTP to end
+// users.
+var errSeeker = errors.New("seeker can't seek")
+
+// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of
+// all of the byte-range-spec values is greater than the content size.
+var errNoOverlap = errors.New("invalid range: failed to overlap")
+
+// if name is empty, filename is unknown. (used for mime type, before sniffing)
+// if modtime.IsZero(), modtime is unknown.
+// content must be seeked to the beginning of the file.
+// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response.
+func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) {
+ setLastModified(w, modtime)
+ done, rangeReq := checkPreconditions(w, r, modtime)
+ if done {
+ return
+ }
+
+ code := StatusOK
+
+ // If Content-Type isn't set, use the file's extension to find it, but
+ // if the Content-Type is unset explicitly, do not sniff the type.
+ ctypes, haveType := w.Header()["Content-Type"]
+ var ctype string
+ if !haveType {
+ ctype = mime.TypeByExtension(filepath.Ext(name))
+ if ctype == "" {
+ // read a chunk to decide between utf-8 text and binary
+ var buf [sniffLen]byte
+ n, _ := io.ReadFull(content, buf[:])
+ ctype = DetectContentType(buf[:n])
+ _, err := content.Seek(0, io.SeekStart) // rewind to output whole file
+ if err != nil {
+ Error(w, "seeker can't seek", StatusInternalServerError)
+ return
+ }
+ }
+ w.Header().Set("Content-Type", ctype)
+ } else if len(ctypes) > 0 {
+ ctype = ctypes[0]
+ }
+
+ size, err := sizeFunc()
+ if err != nil {
+ Error(w, err.Error(), StatusInternalServerError)
+ return
+ }
+ if size < 0 {
+ // Should never happen but just to be sure
+ Error(w, "negative content size computed", StatusInternalServerError)
+ return
+ }
+
+ // handle Content-Range header.
+ sendSize := size
+ var sendContent io.Reader = content
+ ranges, err := parseRange(rangeReq, size)
+ switch err {
+ case nil:
+ case errNoOverlap:
+ if size == 0 {
+ // Some clients add a Range header to all requests to
+ // limit the size of the response. If the file is empty,
+ // ignore the range header and respond with a 200 rather
+ // than a 416.
+ ranges = nil
+ break
+ }
+ w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
+ fallthrough
+ default:
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+
+ if sumRangesSize(ranges) > size {
+ // The total number of bytes in all the ranges
+ // is larger than the size of the file by
+ // itself, so this is probably an attack, or a
+ // dumb client. Ignore the range request.
+ ranges = nil
+ }
+ switch {
+ case len(ranges) == 1:
+ // RFC 7233, Section 4.1:
+ // "If a single part is being transferred, the server
+ // generating the 206 response MUST generate a
+ // Content-Range header field, describing what range
+ // of the selected representation is enclosed, and a
+ // payload consisting of the range.
+ // ...
+ // A server MUST NOT generate a multipart response to
+ // a request for a single range, since a client that
+ // does not request multiple parts might not support
+ // multipart responses."
+ ra := ranges[0]
+ if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
+ Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
+ return
+ }
+ sendSize = ra.length
+ code = StatusPartialContent
+ w.Header().Set("Content-Range", ra.contentRange(size))
+ case len(ranges) > 1:
+ sendSize = rangesMIMESize(ranges, ctype, size)
+ code = StatusPartialContent
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+ w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+ sendContent = pr
+ defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+ go func() {
+ for _, ra := range ranges {
+ part, err := mw.CreatePart(ra.mimeHeader(ctype, size))
+ if err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ if _, err := io.CopyN(part, content, ra.length); err != nil {
+ pw.CloseWithError(err)
+ return
+ }
+ }
+ mw.Close()
+ pw.Close()
+ }()
+ }
+
+ w.Header().Set("Accept-Ranges", "bytes")
+ if w.Header().Get("Content-Encoding") == "" {
+ w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
+ }
+
+ w.WriteHeader(code)
+
+ if r.Method != "HEAD" {
+ io.CopyN(w, sendContent, sendSize)
+ }
+}
+
+// scanETag determines if a syntactically valid ETag is present at s. If so,
+// the ETag and remaining text after consuming ETag is returned. Otherwise,
+// it returns "", "".
+func scanETag(s string) (etag string, remain string) {
+ s = textproto.TrimString(s)
+ start := 0
+ if strings.HasPrefix(s, "W/") {
+ start = 2
+ }
+ if len(s[start:]) < 2 || s[start] != '"' {
+ return "", ""
+ }
+ // ETag is either W/"text" or "text".
+ // See RFC 7232 2.3.
+ for i := start + 1; i < len(s); i++ {
+ c := s[i]
+ switch {
+ // Character values allowed in ETags.
+ case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80:
+ case c == '"':
+ return s[:i+1], s[i+1:]
+ default:
+ return "", ""
+ }
+ }
+ return "", ""
+}
+
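+// For example, scanETag(`W/"abc", "def"`) returns `W/"abc"` and `, "def"`,
+// while scanETag("abc") returns "", "" because the value is not quoted.
+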
+// etagStrongMatch reports whether a and b match using strong ETag comparison.
+// Assumes a and b are valid ETags.
+func etagStrongMatch(a, b string) bool {
+ return a == b && a != "" && a[0] == '"'
+}
+
+// etagWeakMatch reports whether a and b match using weak ETag comparison.
+// Assumes a and b are valid ETags.
+func etagWeakMatch(a, b string) bool {
+ return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/")
+}
+
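+// For example, etagWeakMatch(`W/"a"`, `"a"`) is true because the W/ prefix is
+// ignored, while etagStrongMatch(`W/"a"`, `"a"`) is false: strong comparison
+// requires two identical, non-weak ETags.
+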
+// condResult is the result of an HTTP request precondition check.
+// See https://tools.ietf.org/html/rfc7232 section 3.
+type condResult int
+
+const (
+ condNone condResult = iota
+ condTrue
+ condFalse
+)
+
+func checkIfMatch(w ResponseWriter, r *Request) condResult {
+ im := r.Header.Get("If-Match")
+ if im == "" {
+ return condNone
+ }
+ for {
+ im = textproto.TrimString(im)
+ if len(im) == 0 {
+ break
+ }
+ if im[0] == ',' {
+ im = im[1:]
+ continue
+ }
+ if im[0] == '*' {
+ return condTrue
+ }
+ etag, remain := scanETag(im)
+ if etag == "" {
+ break
+ }
+ if etagStrongMatch(etag, w.Header().get("Etag")) {
+ return condTrue
+ }
+ im = remain
+ }
+
+ return condFalse
+}
+
+func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult {
+ ius := r.Header.Get("If-Unmodified-Since")
+ if ius == "" || isZeroTime(modtime) {
+ return condNone
+ }
+ t, err := ParseTime(ius)
+ if err != nil {
+ return condNone
+ }
+
+ // The Last-Modified header truncates sub-second precision so
+ // the modtime needs to be truncated too.
+ modtime = modtime.Truncate(time.Second)
+ if ret := modtime.Compare(t); ret <= 0 {
+ return condTrue
+ }
+ return condFalse
+}
+
+func checkIfNoneMatch(w ResponseWriter, r *Request) condResult {
+ inm := r.Header.get("If-None-Match")
+ if inm == "" {
+ return condNone
+ }
+ buf := inm
+ for {
+ buf = textproto.TrimString(buf)
+ if len(buf) == 0 {
+ break
+ }
+ if buf[0] == ',' {
+ buf = buf[1:]
+ continue
+ }
+ if buf[0] == '*' {
+ return condFalse
+ }
+ etag, remain := scanETag(buf)
+ if etag == "" {
+ break
+ }
+ if etagWeakMatch(etag, w.Header().get("Etag")) {
+ return condFalse
+ }
+ buf = remain
+ }
+ return condTrue
+}
+
+func checkIfModifiedSince(r *Request, modtime time.Time) condResult {
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return condNone
+ }
+ ims := r.Header.Get("If-Modified-Since")
+ if ims == "" || isZeroTime(modtime) {
+ return condNone
+ }
+ t, err := ParseTime(ims)
+ if err != nil {
+ return condNone
+ }
+ // The Last-Modified header truncates sub-second precision so
+ // the modtime needs to be truncated too.
+ modtime = modtime.Truncate(time.Second)
+ if ret := modtime.Compare(t); ret <= 0 {
+ return condFalse
+ }
+ return condTrue
+}
+
+func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult {
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return condNone
+ }
+ ir := r.Header.get("If-Range")
+ if ir == "" {
+ return condNone
+ }
+ etag, _ := scanETag(ir)
+ if etag != "" {
+ if etagStrongMatch(etag, w.Header().Get("Etag")) {
+ return condTrue
+ } else {
+ return condFalse
+ }
+ }
+ // The If-Range value is typically the ETag value, but it may also be
+ // the modtime date. See golang.org/issue/8367.
+ if modtime.IsZero() {
+ return condFalse
+ }
+ t, err := ParseTime(ir)
+ if err != nil {
+ return condFalse
+ }
+ if t.Unix() == modtime.Unix() {
+ return condTrue
+ }
+ return condFalse
+}
+
+var unixEpochTime = time.Unix(0, 0)
+
+// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0).
+func isZeroTime(t time.Time) bool {
+ return t.IsZero() || t.Equal(unixEpochTime)
+}
+
+func setLastModified(w ResponseWriter, modtime time.Time) {
+ if !isZeroTime(modtime) {
+ w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
+ }
+}
+
+func writeNotModified(w ResponseWriter) {
+ // RFC 7232 section 4.1:
+ // a sender SHOULD NOT generate representation metadata other than the
+ // above listed fields unless said metadata exists for the purpose of
+ // guiding cache updates (e.g., Last-Modified might be useful if the
+ // response does not have an ETag field).
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ delete(h, "Content-Encoding")
+ if h.Get("Etag") != "" {
+ delete(h, "Last-Modified")
+ }
+ w.WriteHeader(StatusNotModified)
+}
+
+// checkPreconditions evaluates request preconditions and reports whether a precondition
+// resulted in sending StatusNotModified or StatusPreconditionFailed.
+func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) {
+ // This function carefully follows RFC 7232 section 6.
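+ //
+ // Order of evaluation, summarizing the code below: If-Match (falling back
+ // to If-Unmodified-Since when absent), then If-None-Match, then
+ // If-Modified-Since for GET/HEAD, and finally If-Range, which can only
+ // clear the Range header so the response falls back to the full content.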
+ ch := checkIfMatch(w, r)
+ if ch == condNone {
+ ch = checkIfUnmodifiedSince(r, modtime)
+ }
+ if ch == condFalse {
+ w.WriteHeader(StatusPreconditionFailed)
+ return true, ""
+ }
+ switch checkIfNoneMatch(w, r) {
+ case condFalse:
+ if r.Method == "GET" || r.Method == "HEAD" {
+ writeNotModified(w)
+ return true, ""
+ } else {
+ w.WriteHeader(StatusPreconditionFailed)
+ return true, ""
+ }
+ case condNone:
+ if checkIfModifiedSince(r, modtime) == condFalse {
+ writeNotModified(w)
+ return true, ""
+ }
+ }
+
+ rangeHeader = r.Header.get("Range")
+ if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse {
+ rangeHeader = ""
+ }
+ return false, rangeHeader
+}
+
+// name is '/'-separated, not filepath.Separator.
+func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirect bool) {
+ const indexPage = "/index.html"
+
+ // redirect .../index.html to .../
+ // can't use Redirect() because that would make the path absolute,
+ // which would be a problem running under StripPrefix
+ if strings.HasSuffix(r.URL.Path, indexPage) {
+ localRedirect(w, r, "./")
+ return
+ }
+
+ f, err := fs.Open(name)
+ if err != nil {
+ msg, code := toHTTPError(err)
+ Error(w, msg, code)
+ return
+ }
+ defer f.Close()
+
+ d, err := f.Stat()
+ if err != nil {
+ msg, code := toHTTPError(err)
+ Error(w, msg, code)
+ return
+ }
+
+ if redirect {
+ // redirect to canonical path: / at end of directory url
+ // r.URL.Path always begins with /
+ url := r.URL.Path
+ if d.IsDir() {
+ if url[len(url)-1] != '/' {
+ localRedirect(w, r, path.Base(url)+"/")
+ return
+ }
+ } else {
+ if url[len(url)-1] == '/' {
+ localRedirect(w, r, "../"+path.Base(url))
+ return
+ }
+ }
+ }
+
+ if d.IsDir() {
+ url := r.URL.Path
+ // redirect if the directory name doesn't end in a slash
+ if url == "" || url[len(url)-1] != '/' {
+ localRedirect(w, r, path.Base(url)+"/")
+ return
+ }
+
+ // use contents of index.html for directory, if present
+ index := strings.TrimSuffix(name, "/") + indexPage
+ ff, err := fs.Open(index)
+ if err == nil {
+ defer ff.Close()
+ dd, err := ff.Stat()
+ if err == nil {
+ d = dd
+ f = ff
+ }
+ }
+ }
+
+ // Still a directory? (we didn't find an index.html file)
+ if d.IsDir() {
+ if checkIfModifiedSince(r, d.ModTime()) == condFalse {
+ writeNotModified(w)
+ return
+ }
+ setLastModified(w, d.ModTime())
+ dirList(w, r, f)
+ return
+ }
+
+ // serveContent will check modification time
+ sizeFunc := func() (int64, error) { return d.Size(), nil }
+ serveContent(w, r, d.Name(), d.ModTime(), sizeFunc, f)
+}
+
+// toHTTPError returns a non-specific HTTP error message and status code
+// for a given non-nil error value. It's important that toHTTPError does not
+// actually return err.Error(), since msg and httpStatus are returned to users,
+// and historically Go's ServeContent always returned just "404 Not Found" for
+// all errors. We don't want to start leaking information in error messages.
+func toHTTPError(err error) (msg string, httpStatus int) {
+ if errors.Is(err, fs.ErrNotExist) {
+ return "404 page not found", StatusNotFound
+ }
+ if errors.Is(err, fs.ErrPermission) {
+ return "403 Forbidden", StatusForbidden
+ }
+ // Default:
+ return "500 Internal Server Error", StatusInternalServerError
+}
+
+// localRedirect gives a Moved Permanently response.
+// It does not convert relative paths to absolute paths like Redirect does.
+func localRedirect(w ResponseWriter, r *Request, newPath string) {
+ if q := r.URL.RawQuery; q != "" {
+ newPath += "?" + q
+ }
+ w.Header().Set("Location", newPath)
+ w.WriteHeader(StatusMovedPermanently)
+}
+
+// ServeFile replies to the request with the contents of the named
+// file or directory.
+//
+// If the provided file or directory name is a relative path, it is
+// interpreted relative to the current directory and may ascend to
+// parent directories. If the provided name is constructed from user
+// input, it should be sanitized before calling ServeFile.
+//
+// As a precaution, ServeFile will reject requests where r.URL.Path
+// contains a ".." path element; this protects against callers who
+// might unsafely use filepath.Join on r.URL.Path without sanitizing
+// it and then use that filepath.Join result as the name argument.
+//
+// As another special case, ServeFile redirects any request where r.URL.Path
+// ends in "/index.html" to the same path, without the final
+// "index.html". To avoid such redirects either modify the path or
+// use ServeContent.
+//
+// Outside of those two special cases, ServeFile does not use
+// r.URL.Path for selecting the file or directory to serve; only the
+// file or directory provided in the name argument is used.
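+//
+// A minimal illustrative use (the route and file name here are hypothetical):
+//
+// http.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) {
+//     http.ServeFile(w, r, "assets/report.pdf")
+// })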
+func ServeFile(w ResponseWriter, r *Request, name string) {
+ if containsDotDot(r.URL.Path) {
+ // Too many programs use r.URL.Path to construct the argument to
+ // serveFile. Reject the request under the assumption that happened
+ // here and ".." may not be wanted.
+ // Note that name might not contain "..", for example if code (still
+ // incorrectly) used filepath.Join(myDir, r.URL.Path).
+ Error(w, "invalid URL path", StatusBadRequest)
+ return
+ }
+ dir, file := filepath.Split(name)
+ serveFile(w, r, Dir(dir), file, false)
+}
+
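+// containsDotDot reports whether v contains a ".." path element.
+// For example (illustrative), "/a/../b" and `\..\x` contain one, while
+// "/a..b" does not, since ".." must appear as a complete element delimited
+// by slashes or backslashes.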
+func containsDotDot(v string) bool {
+ if !strings.Contains(v, "..") {
+ return false
+ }
+ for _, ent := range strings.FieldsFunc(v, isSlashRune) {
+ if ent == ".." {
+ return true
+ }
+ }
+ return false
+}
+
+func isSlashRune(r rune) bool { return r == '/' || r == '\\' }
+
+type fileHandler struct {
+ root FileSystem
+}
+
+type ioFS struct {
+ fsys fs.FS
+}
+
+type ioFile struct {
+ file fs.File
+}
+
+func (f ioFS) Open(name string) (File, error) {
+ if name == "/" {
+ name = "."
+ } else {
+ name = strings.TrimPrefix(name, "/")
+ }
+ file, err := f.fsys.Open(name)
+ if err != nil {
+ return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) {
+ return fs.Stat(f.fsys, path)
+ })
+ }
+ return ioFile{file}, nil
+}
+
+func (f ioFile) Close() error { return f.file.Close() }
+func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) }
+func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() }
+
+var errMissingSeek = errors.New("io.File missing Seek method")
+var errMissingReadDir = errors.New("io.File directory missing ReadDir method")
+
+func (f ioFile) Seek(offset int64, whence int) (int64, error) {
+ s, ok := f.file.(io.Seeker)
+ if !ok {
+ return 0, errMissingSeek
+ }
+ return s.Seek(offset, whence)
+}
+
+func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) {
+ d, ok := f.file.(fs.ReadDirFile)
+ if !ok {
+ return nil, errMissingReadDir
+ }
+ return d.ReadDir(count)
+}
+
+func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) {
+ d, ok := f.file.(fs.ReadDirFile)
+ if !ok {
+ return nil, errMissingReadDir
+ }
+ var list []fs.FileInfo
+ for {
+ dirs, err := d.ReadDir(count - len(list))
+ for _, dir := range dirs {
+ info, err := dir.Info()
+ if err != nil {
+ // Pretend it doesn't exist, like (*os.File).Readdir does.
+ continue
+ }
+ list = append(list, info)
+ }
+ if err != nil {
+ return list, err
+ }
+ if count < 0 || len(list) >= count {
+ break
+ }
+ }
+ return list, nil
+}
+
+// FS converts fsys to a FileSystem implementation,
+// for use with FileServer and NewFileTransport.
+// The files provided by fsys must implement io.Seeker.
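+//
+// For example, to serve files embedded at build time (illustrative; the
+// embed.FS variable is hypothetical and requires importing "embed"):
+//
+// //go:embed static
+// var content embed.FS
+//
+// http.Handle("/", http.FileServer(http.FS(content)))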
+func FS(fsys fs.FS) FileSystem {
+ return ioFS{fsys}
+}
+
+// FileServer returns a handler that serves HTTP requests
+// with the contents of the file system rooted at root.
+//
+// As a special case, the returned file server redirects any request
+// ending in "/index.html" to the same path, without the final
+// "index.html".
+//
+// To use the operating system's file system implementation,
+// use http.Dir:
+//
+// http.Handle("/", http.FileServer(http.Dir("/tmp")))
+//
+// To use an fs.FS implementation, use http.FS to convert it:
+//
+// http.Handle("/", http.FileServer(http.FS(fsys)))
+func FileServer(root FileSystem) Handler {
+ return &fileHandler{root}
+}
+
+func (f *fileHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ upath := r.URL.Path
+ if !strings.HasPrefix(upath, "/") {
+ upath = "/" + upath
+ r.URL.Path = upath
+ }
+ serveFile(w, r, f.root, path.Clean(upath), true)
+}
+
+// httpRange specifies the byte range to be sent to the client.
+type httpRange struct {
+ start, length int64
+}
+
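+// contentRange formats r as a Content-Range header value; for example
+// (illustrative), start 0 and length 5 over a size of 8 yields "bytes 0-4/8".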
+func (r httpRange) contentRange(size int64) string {
+ return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size)
+}
+
+func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader {
+ return textproto.MIMEHeader{
+ "Content-Range": {r.contentRange(size)},
+ "Content-Type": {contentType},
+ }
+}
+
+// parseRange parses a Range header string as per RFC 7233.
+// errNoOverlap is returned if none of the ranges overlap.
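+//
+// Illustrative results for a size of 10 bytes:
+//
+// "bytes=0-4" => one range: start 0, length 5
+// "bytes=-3"  => one range: start 7, length 3 (a 3-byte suffix)
+// "bytes=20-" => errNoOverlap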
+func parseRange(s string, size int64) ([]httpRange, error) {
+ if s == "" {
+ return nil, nil // header not present
+ }
+ const b = "bytes="
+ if !strings.HasPrefix(s, b) {
+ return nil, errors.New("invalid range")
+ }
+ var ranges []httpRange
+ noOverlap := false
+ for _, ra := range strings.Split(s[len(b):], ",") {
+ ra = textproto.TrimString(ra)
+ if ra == "" {
+ continue
+ }
+ start, end, ok := strings.Cut(ra, "-")
+ if !ok {
+ return nil, errors.New("invalid range")
+ }
+ start, end = textproto.TrimString(start), textproto.TrimString(end)
+ var r httpRange
+ if start == "" {
+ // If no start is specified, end specifies the
+ // range start relative to the end of the file,
+ // and we are dealing with <suffix-length>
+ // which has to be a non-negative integer as per
+ // RFC 7233 Section 2.1 "Byte-Ranges".
+ if end == "" || end[0] == '-' {
+ return nil, errors.New("invalid range")
+ }
+ i, err := strconv.ParseInt(end, 10, 64)
+ if i < 0 || err != nil {
+ return nil, errors.New("invalid range")
+ }
+ if i > size {
+ i = size
+ }
+ r.start = size - i
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(start, 10, 64)
+ if err != nil || i < 0 {
+ return nil, errors.New("invalid range")
+ }
+ if i >= size {
+ // If the range begins after the size of the content,
+ // then it does not overlap.
+ noOverlap = true
+ continue
+ }
+ r.start = i
+ if end == "" {
+ // If no end is specified, range extends to end of the file.
+ r.length = size - r.start
+ } else {
+ i, err := strconv.ParseInt(end, 10, 64)
+ if err != nil || r.start > i {
+ return nil, errors.New("invalid range")
+ }
+ if i >= size {
+ i = size - 1
+ }
+ r.length = i - r.start + 1
+ }
+ }
+ ranges = append(ranges, r)
+ }
+ if noOverlap && len(ranges) == 0 {
+ // The specified ranges did not overlap with the content.
+ return nil, errNoOverlap
+ }
+ return ranges, nil
+}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+ *w += countingWriter(len(p))
+ return len(p), nil
+}
+
+// rangesMIMESize returns the number of bytes it takes to encode the
+// provided ranges as a multipart response.
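+// It does so by writing only the MIME part headers and boundaries to a
+// countingWriter and adding the length of each range for the part bodies.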
+func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) {
+ var w countingWriter
+ mw := multipart.NewWriter(&w)
+ for _, ra := range ranges {
+ mw.CreatePart(ra.mimeHeader(contentType, contentSize))
+ encSize += ra.length
+ }
+ mw.Close()
+ encSize += int64(w)
+ return
+}
+
+func sumRangesSize(ranges []httpRange) (size int64) {
+ for _, ra := range ranges {
+ size += ra.length
+ }
+ return
+}
diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go
new file mode 100644
index 0000000..3fb9e01
--- /dev/null
+++ b/src/net/http/fs_test.go
@@ -0,0 +1,1561 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "mime"
+ "mime/multipart"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "reflect"
+ "regexp"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+const (
+ testFile = "testdata/file"
+ testFileLen = 11
+)
+
+type wantRange struct {
+ start, end int64 // range [start,end)
+}
+
+var ServeFileRangeTests = []struct {
+ r string
+ code int
+ ranges []wantRange
+}{
+ {r: "", code: StatusOK},
+ {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}},
+ {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}},
+ {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}},
+ {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}},
+ {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}},
+ {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}},
+ {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}},
+ {r: "bytes=5-1000", code: StatusPartialContent, ranges: []wantRange{{5, testFileLen}}},
+ {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request
+ {r: "bytes=0-9", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen - 1}}},
+ {r: "bytes=0-10", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}},
+ {r: "bytes=0-11", code: StatusPartialContent, ranges: []wantRange{{0, testFileLen}}},
+ {r: "bytes=10-11", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 1, testFileLen}}},
+ {r: "bytes=10-", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 1, testFileLen}}},
+ {r: "bytes=11-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=11-12", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=12-12", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=11-100", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=12-100", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=100-", code: StatusRequestedRangeNotSatisfiable},
+ {r: "bytes=100-1000", code: StatusRequestedRangeNotSatisfiable},
+}
+
+func TestServeFile(t *testing.T) { run(t, testServeFile) }
+func testServeFile(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ })).ts
+ c := ts.Client()
+
+ var err error
+
+ file, err := os.ReadFile(testFile)
+ if err != nil {
+ t.Fatal("reading file:", err)
+ }
+
+ // set up the Request (re-used for all tests)
+ var req Request
+ req.Header = make(Header)
+ if req.URL, err = url.Parse(ts.URL); err != nil {
+ t.Fatal("ParseURL:", err)
+ }
+
+ // Get contents via various methods.
+ //
+ // See https://go.dev/issue/59471 for a proposal to limit the set of methods handled.
+ // For now, test the historical behavior.
+ for _, method := range []string{
+ MethodGet,
+ MethodPost,
+ MethodPut,
+ MethodPatch,
+ MethodDelete,
+ MethodOptions,
+ MethodTrace,
+ } {
+ req.Method = method
+ _, body := getBody(t, method, req, c)
+ if !bytes.Equal(body, file) {
+ t.Fatalf("body mismatch for %v request: got %q, want %q", method, body, file)
+ }
+ }
+
+ // HEAD request.
+ req.Method = MethodHead
+ resp, body := getBody(t, "HEAD", req, c)
+ if len(body) != 0 {
+ t.Fatalf("body mismatch for HEAD request: got %q, want empty", body)
+ }
+ if got, want := resp.Header.Get("Content-Length"), fmt.Sprint(len(file)); got != want {
+ t.Fatalf("Content-Length mismatch for HEAD request: got %v, want %v", got, want)
+ }
+
+ // Range tests
+ req.Method = MethodGet
+Cases:
+ for _, rt := range ServeFileRangeTests {
+ if rt.r != "" {
+ req.Header.Set("Range", rt.r)
+ }
+ resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c)
+ if resp.StatusCode != rt.code {
+ t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code)
+ }
+ if rt.code == StatusRequestedRangeNotSatisfiable {
+ continue
+ }
+ wantContentRange := ""
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ }
+ cr := resp.Header.Get("Content-Range")
+ if cr != wantContentRange {
+ t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange)
+ }
+ ct := resp.Header.Get("Content-Type")
+ if len(rt.ranges) == 1 {
+ rng := rt.ranges[0]
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ if strings.HasPrefix(ct, "multipart/byteranges") {
+ t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r, ct)
+ }
+ }
+ if len(rt.ranges) > 1 {
+ typ, params, err := mime.ParseMediaType(ct)
+ if err != nil {
+ t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err)
+ continue
+ }
+ if typ != "multipart/byteranges" {
+ t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r, typ)
+ continue
+ }
+ if params["boundary"] == "" {
+ t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct)
+ continue
+ }
+ if g, w := resp.ContentLength, int64(len(body)); g != w {
+ t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w)
+ continue
+ }
+ mr := multipart.NewReader(bytes.NewReader(body), params["boundary"])
+ for ri, rng := range rt.ranges {
+ part, err := mr.NextPart()
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen)
+ if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w {
+ t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w)
+ }
+ body, err := io.ReadAll(part)
+ if err != nil {
+ t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err)
+ continue Cases
+ }
+ wantBody := file[rng.start:rng.end]
+ if !bytes.Equal(body, wantBody) {
+ t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody)
+ }
+ }
+ _, err = mr.NextPart()
+ if err != io.EOF {
+ t.Errorf("range=%q; expected final error io.EOF; got %v", rt.r, err)
+ }
+ }
+ }
+}
+
+func TestServeFile_DotDot(t *testing.T) {
+ tests := []struct {
+ req string
+ wantStatus int
+ }{
+ {"/testdata/file", 200},
+ {"/../file", 400},
+ {"/..", 400},
+ {"/../", 400},
+ {"/../foo", 400},
+ {"/..\\foo", 400},
+ {"/file/a", 200},
+ {"/file/a..", 200},
+ {"/file/a/..", 400},
+ {"/file/a\\..", 400},
+ }
+ for _, tt := range tests {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + tt.req + " HTTP/1.1\r\nHost: foo\r\n\r\n")))
+ if err != nil {
+ t.Errorf("bad request %q: %v", tt.req, err)
+ continue
+ }
+ rec := httptest.NewRecorder()
+ ServeFile(rec, req, "testdata/file")
+ if rec.Code != tt.wantStatus {
+ t.Errorf("for request %q, status = %d; want %d", tt.req, rec.Code, tt.wantStatus)
+ }
+ }
+}
+
+// Tests that this doesn't panic. (Issue 30165)
+func TestServeFileDirPanicEmptyPath(t *testing.T) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/", nil)
+ req.URL.Path = ""
+ ServeFile(rec, req, "testdata")
+ res := rec.Result()
+ if res.StatusCode != 301 {
+ t.Errorf("code = %v; want 301", res.Status)
+ }
+}
+
+// Tests that ranges are ignored when serving empty content. (Issue 54794)
+func TestServeContentWithEmptyContentIgnoreRanges(t *testing.T) {
+ for _, r := range []string{
+ "bytes=0-128",
+ "bytes=1-",
+ } {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "/", nil)
+ req.Header.Set("Range", r)
+ ServeContent(rec, req, "nothing", time.Now(), bytes.NewReader(nil))
+ res := rec.Result()
+ if res.StatusCode != 200 {
+ t.Errorf("code = %v; want 200", res.Status)
+ }
+ bodyLen := rec.Body.Len()
+ if bodyLen != 0 {
+ t.Errorf("body.Len() = %v; want 0", res.Status)
+ }
+ }
+}
+
+var fsRedirectTestData = []struct {
+ original, redirect string
+}{
+ {"/test/index.html", "/test/"},
+ {"/test/testdata", "/test/testdata/"},
+ {"/test/testdata/file/", "/test/testdata/file"},
+}
+
+func TestFSRedirect(t *testing.T) { run(t, testFSRedirect) }
+func testFSRedirect(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, StripPrefix("/test", FileServer(Dir(".")))).ts
+
+ for _, data := range fsRedirectTestData {
+ res, err := ts.Client().Get(ts.URL + data.original)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, e := res.Request.URL.Path, data.redirect; g != e {
+ t.Errorf("redirect from %s: got %s, want %s", data.original, g, e)
+ }
+ }
+}
+
+type testFileSystem struct {
+ open func(name string) (File, error)
+}
+
+func (fs *testFileSystem) Open(name string) (File, error) {
+ return fs.open(name)
+}
+
+func TestFileServerCleans(t *testing.T) {
+ defer afterTest(t)
+ ch := make(chan string, 1)
+ fs := FileServer(&testFileSystem{func(name string) (File, error) {
+ ch <- name
+ return nil, errors.New("file does not exist")
+ }})
+ tests := []struct {
+ reqPath, openArg string
+ }{
+ {"/foo.txt", "/foo.txt"},
+ {"//foo.txt", "/foo.txt"},
+ {"/../foo.txt", "/foo.txt"},
+ }
+ req, _ := NewRequest("GET", "http://example.com", nil)
+ for n, test := range tests {
+ rec := httptest.NewRecorder()
+ req.URL.Path = test.reqPath
+ fs.ServeHTTP(rec, req)
+ if got := <-ch; got != test.openArg {
+ t.Errorf("test %d: got %q, want %q", n, got, test.openArg)
+ }
+ }
+}
+
+func TestFileServerEscapesNames(t *testing.T) { run(t, testFileServerEscapesNames) }
+func testFileServerEscapesNames(t *testing.T, mode testMode) {
+ const dirListPrefix = "<pre>\n"
+ const dirListSuffix = "\n</pre>\n"
+ tests := []struct {
+ name, escaped string
+ }{
+ {`simple_name`, `<a href="simple_name">simple_name</a>`},
+ {`"'<>&`, `<a href="%22%27%3C%3E&">&#34;&#39;&lt;&gt;&amp;</a>`},
+ {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`},
+ {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo">&lt;combo&gt;?foo</a>`},
+ {`foo:bar`, `<a href="./foo:bar">foo:bar</a>`},
+ }
+
+ // We put each test file in its own directory in the fakeFS so we can look at it in isolation.
+ fs := make(fakeFS)
+ for i, test := range tests {
+ testFile := &fakeFileInfo{basename: test.name}
+ fs[fmt.Sprintf("/%d", i)] = &fakeFileInfo{
+ dir: true,
+ modtime: time.Unix(1000000000, 0).UTC(),
+ ents: []*fakeFileInfo{testFile},
+ }
+ fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile
+ }
+
+ ts := newClientServerTest(t, mode, FileServer(&fs)).ts
+ for i, test := range tests {
+ url := fmt.Sprintf("%s/%d", ts.URL, i)
+ res, err := ts.Client().Get(url)
+ if err != nil {
+ t.Fatalf("test %q: Get: %v", test.name, err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("test %q: read Body: %v", test.name, err)
+ }
+ s := string(b)
+ if !strings.HasPrefix(s, dirListPrefix) || !strings.HasSuffix(s, dirListSuffix) {
+ t.Errorf("test %q: listing dir, full output is %q, want prefix %q and suffix %q", test.name, s, dirListPrefix, dirListSuffix)
+ }
+ if trimmed := strings.TrimSuffix(strings.TrimPrefix(s, dirListPrefix), dirListSuffix); trimmed != test.escaped {
+ t.Errorf("test %q: listing dir, filename escaped to %q, want %q", test.name, trimmed, test.escaped)
+ }
+ res.Body.Close()
+ }
+}
+
+func TestFileServerSortsNames(t *testing.T) { run(t, testFileServerSortsNames) }
+func testFileServerSortsNames(t *testing.T, mode testMode) {
+ const contents = "I am a fake file"
+ dirMod := time.Unix(123, 0).UTC()
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{
+ {
+ basename: "b",
+ modtime: fileMod,
+ contents: contents,
+ },
+ {
+ basename: "a",
+ modtime: fileMod,
+ contents: contents,
+ },
+ },
+ },
+ }
+
+ ts := newClientServerTest(t, mode, FileServer(&fs)).ts
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("read Body: %v", err)
+ }
+ s := string(b)
+ if !strings.Contains(s, "<a href=\"a\">a</a>\n<a href=\"b\">b</a>") {
+ t.Errorf("output appears to be unsorted:\n%s", s)
+ }
+}
+
+func mustRemoveAll(dir string) {
+ err := os.RemoveAll(dir)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func TestFileServerImplicitLeadingSlash(t *testing.T) { run(t, testFileServerImplicitLeadingSlash) }
+func testFileServerImplicitLeadingSlash(t *testing.T, mode testMode) {
+ tempDir := t.TempDir()
+ if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil {
+ t.Fatalf("WriteFile: %v", err)
+ }
+ ts := newClientServerTest(t, mode, StripPrefix("/bar/", FileServer(Dir(tempDir)))).ts
+ get := func(suffix string) string {
+ res, err := ts.Client().Get(ts.URL + suffix)
+ if err != nil {
+ t.Fatalf("Get %s: %v", suffix, err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll %s: %v", suffix, err)
+ }
+ res.Body.Close()
+ return string(b)
+ }
+ if s := get("/bar/"); !strings.Contains(s, ">foo.txt<") {
+ t.Logf("expected a directory listing with foo.txt, got %q", s)
+ }
+ if s := get("/bar/foo.txt"); s != "Hello world" {
+ t.Logf("expected %q, got %q", "Hello world", s)
+ }
+}
+
+func TestDirJoin(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("skipping test on windows")
+ }
+ wfi, err := os.Stat("/etc/hosts")
+ if err != nil {
+ t.Skip("skipping test; no /etc/hosts file")
+ }
+ test := func(d Dir, name string) {
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ gfi, err := f.Stat()
+ if err != nil {
+ t.Fatalf("stat of %s: %v", name, err)
+ }
+ if !os.SameFile(gfi, wfi) {
+ t.Errorf("%s got different file", name)
+ }
+ }
+ test(Dir("/etc/"), "/hosts")
+ test(Dir("/etc/"), "hosts")
+ test(Dir("/etc/"), "../../../../hosts")
+ test(Dir("/etc"), "/hosts")
+ test(Dir("/etc"), "hosts")
+ test(Dir("/etc"), "../../../../hosts")
+
+ // Not really directories, but since we use this trick in
+ // ServeFile, test it:
+ test(Dir("/etc/hosts"), "")
+ test(Dir("/etc/hosts"), "/")
+ test(Dir("/etc/hosts"), "../")
+}
+
+func TestEmptyDirOpenCWD(t *testing.T) {
+ test := func(d Dir) {
+ name := "fs_test.go"
+ f, err := d.Open(name)
+ if err != nil {
+ t.Fatalf("open of %s: %v", name, err)
+ }
+ defer f.Close()
+ }
+ test(Dir(""))
+ test(Dir("."))
+ test(Dir("./"))
+}
+
+func TestServeFileContentType(t *testing.T) { run(t, testServeFileContentType) }
+func testServeFileContentType(t *testing.T, mode testMode) {
+ const ctype = "icecream/chocolate"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.FormValue("override") {
+ case "1":
+ w.Header().Set("Content-Type", ctype)
+ case "2":
+ // Explicitly inhibit sniffing.
+ w.Header()["Content-Type"] = []string{}
+ }
+ ServeFile(w, r, "testdata/file")
+ })).ts
+ get := func(override string, want []string) {
+ resp, err := ts.Client().Get(ts.URL + "?override=" + override)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if h := resp.Header["Content-Type"]; !reflect.DeepEqual(h, want) {
+ t.Errorf("Content-Type mismatch: got %v, want %v", h, want)
+ }
+ resp.Body.Close()
+ }
+ get("0", []string{"text/plain; charset=utf-8"})
+ get("1", []string{ctype})
+ get("2", nil)
+}
+
+func TestServeFileMimeType(t *testing.T) { run(t, testServeFileMimeType) }
+func testServeFileMimeType(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/style.css")
+ })).ts
+ resp, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ want := "text/css; charset=utf-8"
+ if h := resp.Header.Get("Content-Type"); h != want {
+ t.Errorf("Content-Type mismatch: got %q, want %q", h, want)
+ }
+}
+
+func TestServeFileFromCWD(t *testing.T) { run(t, testServeFileFromCWD) }
+func testServeFileFromCWD(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "fs_test.go")
+ })).ts
+ r, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ if r.StatusCode != 200 {
+ t.Fatalf("expected 200 OK, got %s", r.Status)
+ }
+}
+
+// Issue 13996
+func TestServeDirWithoutTrailingSlash(t *testing.T) { run(t, testServeDirWithoutTrailingSlash) }
+func testServeDirWithoutTrailingSlash(t *testing.T, mode testMode) {
+ e := "/testdata/"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, ".")
+ })).ts
+ r, err := ts.Client().Get(ts.URL + "/testdata")
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ if g := r.Request.URL.Path; g != e {
+ t.Errorf("got %s, want %s", g, e)
+ }
+}
+
+// Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is
+// specified.
+func TestServeFileWithContentEncoding(t *testing.T) { run(t, testServeFileWithContentEncoding) }
+func testServeFileWithContentEncoding(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "foo")
+ ServeFile(w, r, "testdata/file")
+
+ // Because the testdata is so small, it would fit in
+ // both the h1 and h2 Server's write buffers. For h1,
+ // sendfile is used, though, forcing a header flush at
+ // the io.Copy. http2 doesn't do a header flush so
+ // buffers all 11 bytes and then adds its own
+ // Content-Length. To prevent the Server's
+ // Content-Length and test ServeFile only, flush here.
+ w.(Flusher).Flush()
+ }))
+ resp, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+ if g, e := resp.ContentLength, int64(-1); g != e {
+ t.Errorf("Content-Length mismatch: got %d, want %d", g, e)
+ }
+}
+
+// Tests that ServeFile does not generate representation metadata when
+// the file has not been modified, as per RFC 7232 section 4.1.
+func TestServeFileNotModified(t *testing.T) { run(t, testServeFileNotModified) }
+func testServeFileNotModified(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.Header().Set("Content-Encoding", "foo")
+ w.Header().Set("Etag", `"123"`)
+ ServeFile(w, r, "testdata/file")
+
+ // Because the testdata is so small, it would fit in
+ // both the h1 and h2 Server's write buffers. For h1,
+ // sendfile is used, though, forcing a header flush at
+ // the io.Copy. http2 doesn't do a header flush so
+ // buffers all 11 bytes and then adds its own
+ // Content-Length. To prevent the Server's
+ // Content-Length and test ServeFile only, flush here.
+ w.(Flusher).Flush()
+ }))
+ req, err := NewRequest("GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("If-None-Match", `"123"`)
+ resp, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if len(b) != 0 {
+ t.Errorf("non-empty body")
+ }
+ if g, e := resp.StatusCode, StatusNotModified; g != e {
+ t.Errorf("status mismatch: got %d, want %d", g, e)
+ }
+ // HTTP1 transport sets ContentLength to 0.
+ if g, e1, e2 := resp.ContentLength, int64(-1), int64(0); g != e1 && g != e2 {
+ t.Errorf("Content-Length mismatch: got %d, want %d or %d", g, e1, e2)
+ }
+ if resp.Header.Get("Content-Type") != "" {
+ t.Errorf("Content-Type present, but it should not be")
+ }
+ if resp.Header.Get("Content-Encoding") != "" {
+ t.Errorf("Content-Encoding present, but it should not be")
+ }
+}
+
+func TestServeIndexHtml(t *testing.T) { run(t, testServeIndexHtml) }
+func testServeIndexHtml(t *testing.T, mode testMode) {
+ for i := 0; i < 2; i++ {
+ var h Handler
+ var name string
+ switch i {
+ case 0:
+ h = FileServer(Dir("."))
+ name = "Dir"
+ case 1:
+ h = FileServer(FS(os.DirFS(".")))
+ name = "DirFS"
+ }
+ t.Run(name, func(t *testing.T) {
+ const want = "index.html says hello\n"
+ ts := newClientServerTest(t, mode, h).ts
+
+ for _, path := range []string{"/testdata/", "/testdata/index.html"} {
+ res, err := ts.Client().Get(ts.URL + path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != want {
+ t.Errorf("for path %q got %q, want %q", path, s, want)
+ }
+ res.Body.Close()
+ }
+ })
+ }
+}
+
+func TestServeIndexHtmlFS(t *testing.T) { run(t, testServeIndexHtmlFS) }
+func testServeIndexHtmlFS(t *testing.T, mode testMode) {
+ const want = "index.html says hello\n"
+ ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts
+ defer ts.Close()
+
+ for _, path := range []string{"/testdata/", "/testdata/index.html"} {
+ res, err := ts.Client().Get(ts.URL + path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal("reading Body:", err)
+ }
+ if s := string(b); s != want {
+ t.Errorf("for path %q got %q, want %q", path, s, want)
+ }
+ res.Body.Close()
+ }
+}
+
+func TestFileServerZeroByte(t *testing.T) { run(t, testFileServerZeroByte) }
+func testFileServerZeroByte(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts
+
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ _, err = fmt.Fprintf(c, "GET /..\x00 HTTP/1.0\r\n\r\n")
+ if err != nil {
+ t.Fatal(err)
+ }
+ var got bytes.Buffer
+ bufr := bufio.NewReader(io.TeeReader(c, &got))
+ res, err := ReadResponse(bufr, nil)
+ if err != nil {
+ t.Fatal("ReadResponse: ", err)
+ }
+ if res.StatusCode == 200 {
+ t.Errorf("got status 200; want an error. Body is:\n%s", got.Bytes())
+ }
+}
+
+func TestFileServerNamesEscape(t *testing.T) { run(t, testFileServerNamesEscape) }
+func testFileServerNamesEscape(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts
+ for _, path := range []string{
+ "/../testdata/file",
+ "/NUL", // don't read from device files on Windows
+ } {
+ res, err := ts.Client().Get(ts.URL + path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode < 400 || res.StatusCode > 599 {
+ t.Errorf("Get(%q): got status %v, want 4xx or 5xx", path, res.StatusCode)
+ }
+
+ }
+}
+
+type fakeFileInfo struct {
+ dir bool
+ basename string
+ modtime time.Time
+ ents []*fakeFileInfo
+ contents string
+ err error
+}
+
+func (f *fakeFileInfo) Name() string { return f.basename }
+func (f *fakeFileInfo) Sys() any { return nil }
+func (f *fakeFileInfo) ModTime() time.Time { return f.modtime }
+func (f *fakeFileInfo) IsDir() bool { return f.dir }
+func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) }
+func (f *fakeFileInfo) Mode() fs.FileMode {
+ if f.dir {
+ return 0755 | fs.ModeDir
+ }
+ return 0644
+}
+
+func (f *fakeFileInfo) String() string {
+ return fs.FormatFileInfo(f)
+}
+
+type fakeFile struct {
+ io.ReadSeeker
+ fi *fakeFileInfo
+ path string // as opened
+ entpos int
+}
+
+func (f *fakeFile) Close() error { return nil }
+func (f *fakeFile) Stat() (fs.FileInfo, error) { return f.fi, nil }
+func (f *fakeFile) Readdir(count int) ([]fs.FileInfo, error) {
+ if !f.fi.dir {
+ return nil, fs.ErrInvalid
+ }
+ var fis []fs.FileInfo
+
+ limit := f.entpos + count
+ if count <= 0 || limit > len(f.fi.ents) {
+ limit = len(f.fi.ents)
+ }
+ for ; f.entpos < limit; f.entpos++ {
+ fis = append(fis, f.fi.ents[f.entpos])
+ }
+
+ if len(fis) == 0 && count > 0 {
+ return fis, io.EOF
+ } else {
+ return fis, nil
+ }
+}
+
+type fakeFS map[string]*fakeFileInfo
+
+func (fsys fakeFS) Open(name string) (File, error) {
+ name = path.Clean(name)
+ f, ok := fsys[name]
+ if !ok {
+ return nil, fs.ErrNotExist
+ }
+ if f.err != nil {
+ return nil, f.err
+ }
+ return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil
+}
+
+func TestDirectoryIfNotModified(t *testing.T) { run(t, testDirectoryIfNotModified) }
+func testDirectoryIfNotModified(t *testing.T, mode testMode) {
+ const indexContents = "I am a fake index.html file"
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fileModStr := fileMod.Format(TimeFormat)
+ dirMod := time.Unix(123, 0).UTC()
+ indexFile := &fakeFileInfo{
+ basename: "index.html",
+ modtime: fileMod,
+ contents: indexContents,
+ }
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{indexFile},
+ },
+ "/index.html": indexFile,
+ }
+
+ ts := newClientServerTest(t, mode, FileServer(fs)).ts
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(b) != indexContents {
+ t.Fatalf("Got body %q; want %q", b, indexContents)
+ }
+ res.Body.Close()
+
+ lastMod := res.Header.Get("Last-Modified")
+ if lastMod != fileModStr {
+ t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr)
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("If-Modified-Since", lastMod)
+
+ c := ts.Client()
+ res, err = c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 304 {
+ t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode)
+ }
+ res.Body.Close()
+
+ // Advance the index.html file's modtime, but not the directory's.
+ indexFile.modtime = indexFile.modtime.Add(1 * time.Hour)
+
+ res, err = c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res)
+ }
+ res.Body.Close()
+}
+
+func mustStat(t *testing.T, fileName string) fs.FileInfo {
+ fi, err := os.Stat(fileName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return fi
+}
+
+func TestServeContent(t *testing.T) { run(t, testServeContent) }
+func testServeContent(t *testing.T, mode testMode) {
+ type serveParam struct {
+ name string
+ modtime time.Time
+ content io.ReadSeeker
+ contentType string
+ etag string
+ }
+ servec := make(chan serveParam, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ p := <-servec
+ if p.etag != "" {
+ w.Header().Set("ETag", p.etag)
+ }
+ if p.contentType != "" {
+ w.Header().Set("Content-Type", p.contentType)
+ }
+ ServeContent(w, r, p.name, p.modtime, p.content)
+ })).ts
+
+ type testCase struct {
+ // One of file or content must be set:
+ file string
+ content io.ReadSeeker
+
+ modtime time.Time
+ serveETag string // optional
+ serveContentType string // optional
+ reqHeader map[string]string
+ wantLastMod string
+ wantContentType string
+ wantContentRange string
+ wantStatus int
+ }
+ htmlModTime := mustStat(t, "testdata/index.html").ModTime()
+ tests := map[string]testCase{
+ "no_last_modified": {
+ file: "testdata/style.css",
+ wantContentType: "text/css; charset=utf-8",
+ wantStatus: 200,
+ },
+ "with_last_modified": {
+ file: "testdata/index.html",
+ wantContentType: "text/html; charset=utf-8",
+ modtime: htmlModTime,
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ wantStatus: 200,
+ },
+ "not_modified_modtime": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`, // Last-Modified sent only when no ETag
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_modtime_with_contenttype": {
+ file: "testdata/style.css",
+ serveContentType: "text/css", // explicit content type
+ serveETag: `"foo"`, // Last-Modified sent only when no ETag
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "not_modified_etag_no_seek": {
+ content: panicOnSeek{nil}, // should never be called
+ serveETag: `W/"foo"`, // If-None-Match uses weak ETag comparison
+ reqHeader: map[string]string{
+ "If-None-Match": `"baz", W/"foo"`,
+ },
+ wantStatus: 304,
+ },
+ "if_none_match_mismatch": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"Foo"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "if_none_match_malformed": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `,`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "range_good": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ },
+ "range_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"A"`,
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ },
+ "range_match_weak_etag": {
+ file: "testdata/style.css",
+ serveETag: `W/"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `W/"A"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "range_no_overlap": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=10-20",
+ },
+ wantStatus: StatusRequestedRangeNotSatisfiable,
+ wantContentType: "text/plain; charset=utf-8",
+ wantContentRange: "bytes */8",
+ },
+ // An If-Range resource for entity "A", but entity "B" is now current.
+ // The Range request should be ignored.
+ "range_no_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"B"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "range_with_modtime": {
+ file: "testdata/style.css",
+ modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC),
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ "range_with_modtime_mismatch": {
+ file: "testdata/style.css",
+ modtime: time.Date(2014, 6, 25, 17, 12, 18, 0 /* nanos */, time.UTC),
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": "Wed, 25 Jun 2014 17:12:19 GMT",
+ },
+ wantStatus: StatusOK,
+ wantContentType: "text/css; charset=utf-8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ "range_with_modtime_nanos": {
+ file: "testdata/style.css",
+ modtime: time.Date(2014, 6, 25, 17, 12, 18, 123 /* nanos */, time.UTC),
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ },
+ "unix_zero_modtime": {
+ content: strings.NewReader("<html>foo"),
+ modtime: time.Unix(0, 0),
+ wantStatus: StatusOK,
+ wantContentType: "text/html; charset=utf-8",
+ },
+ "ifmatch_matches": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `"Z", "A"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "ifmatch_star": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `*`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "ifmatch_failed": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `"B"`,
+ },
+ wantStatus: 412,
+ },
+ "ifmatch_fails_on_weak_etag": {
+ file: "testdata/style.css",
+ serveETag: `W/"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `W/"A"`,
+ },
+ wantStatus: 412,
+ },
+ "if_unmodified_since_true": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Unmodified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ },
+ "if_unmodified_since_false": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat),
+ },
+ wantStatus: 412,
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ },
+ }
+ for testName, tt := range tests {
+ var content io.ReadSeeker
+ if tt.file != "" {
+ f, err := os.Open(tt.file)
+ if err != nil {
+ t.Fatalf("test %q: %v", testName, err)
+ }
+ defer f.Close()
+ content = f
+ } else {
+ content = tt.content
+ }
+ for _, method := range []string{"GET", "HEAD"} {
+ // Restore content in case it was consumed by the previous method.
+ if content, ok := content.(*strings.Reader); ok {
+ content.Seek(0, io.SeekStart)
+ }
+
+ servec <- serveParam{
+ name: filepath.Base(tt.file),
+ content: content,
+ modtime: tt.modtime,
+ etag: tt.serveETag,
+ contentType: tt.serveContentType,
+ }
+ req, err := NewRequest(method, ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for k, v := range tt.reqHeader {
+ req.Header.Set(k, v)
+ }
+
+ c := ts.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(io.Discard, res.Body)
+ res.Body.Close()
+ if res.StatusCode != tt.wantStatus {
+ t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus)
+ }
+ if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
+ t.Errorf("test %q using %q: got content-type = %q, want %q", testName, method, g, e)
+ }
+ if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e {
+ t.Errorf("test %q using %q: got content-range = %q, want %q", testName, method, g, e)
+ }
+ if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
+ t.Errorf("test %q using %q: got last-modified = %q, want %q", testName, method, g, e)
+ }
+ }
+ }
+}
+
+// Issue 12991
+func TestServerFileStatError(t *testing.T) {
+ rec := httptest.NewRecorder()
+ r, _ := NewRequest("GET", "http://foo/", nil)
+ redirect := false
+ name := "file.txt"
+ fs := issue12991FS{}
+ ExportServeFile(rec, r, fs, name, redirect)
+ if body := rec.Body.String(); !strings.Contains(body, "403") || !strings.Contains(body, "Forbidden") {
+ t.Errorf("wanted 403 forbidden message; got: %s", body)
+ }
+}
+
+type issue12991FS struct{}
+
+func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil }
+
+type issue12991File struct{ File }
+
+func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission }
+func (issue12991File) Close() error { return nil }
+
+func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) }
+func testServeContentErrorMessages(t *testing.T, mode testMode) {
+ fs := fakeFS{
+ "/500": &fakeFileInfo{
+ err: errors.New("random error"),
+ },
+ "/403": &fakeFileInfo{
+ err: &fs.PathError{Err: fs.ErrPermission},
+ },
+ }
+ ts := newClientServerTest(t, mode, FileServer(fs)).ts
+ c := ts.Client()
+ for _, code := range []int{403, 404, 500} {
+ res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code))
+ if err != nil {
+ t.Errorf("Error fetching /%d: %v", code, err)
+ continue
+ }
+ if res.StatusCode != code {
+ t.Errorf("For /%d, status code = %d; want %d", code, res.StatusCode, code)
+ }
+ res.Body.Close()
+ }
+}
+
+// verifies that sendfile is being used on Linux
+func TestLinuxSendfile(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ if runtime.GOOS != "linux" {
+ t.Skip("skipping; linux-only test")
+ }
+ if _, err := exec.LookPath("strace"); err != nil {
+ t.Skip("skipping; strace not found in path")
+ }
+
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ lnf, err := ln.(*net.TCPListener).File()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ // Attempt to run strace, and skip on failure - this test requires SYS_PTRACE.
+ if err := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=^$").Run(); err != nil {
+ t.Skipf("skipping; failed to run strace: %v", err)
+ }
+
+ filename := fmt.Sprintf("1kb-%d", os.Getpid())
+ filepath := path.Join(os.TempDir(), filename)
+
+ if err := os.WriteFile(filepath, bytes.Repeat([]byte{'a'}, 1<<10), 0755); err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(filepath)
+
+ var buf strings.Builder
+ child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild")
+ child.ExtraFiles = append(child.ExtraFiles, lnf)
+ child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
+ child.Stdout = &buf
+ child.Stderr = &buf
+ if err := child.Start(); err != nil {
+ t.Skipf("skipping; failed to start straced child: %v", err)
+ }
+
+ res, err := Get(fmt.Sprintf("http://%s/%s", ln.Addr(), filename))
+ if err != nil {
+ t.Fatalf("http client error: %v", err)
+ }
+ _, err = io.Copy(io.Discard, res.Body)
+ if err != nil {
+ t.Fatalf("client body read error: %v", err)
+ }
+ res.Body.Close()
+
+ // Force child to exit cleanly.
+ Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil)
+ child.Wait()
+
+ rx := regexp.MustCompile(`\b(n64:)?sendfile(64)?\(`)
+ out := buf.String()
+ if !rx.MatchString(out) {
+ t.Errorf("no sendfile system call found in:\n%s", out)
+ }
+}
+
+func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) {
+ r, err := client.Do(&req)
+ if err != nil {
+ t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err)
+ }
+ b, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err)
+ }
+ return r, b
+}
+
+// TestLinuxSendfileChild isn't a real test. It's used as a helper process
+// for TestLinuxSendfile.
+func TestLinuxSendfileChild(*testing.T) {
+ if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
+ return
+ }
+ defer os.Exit(0)
+ fd3 := os.NewFile(3, "ephemeral-port-listener")
+ ln, err := net.FileListener(fd3)
+ if err != nil {
+ panic(err)
+ }
+ mux := NewServeMux()
+ mux.Handle("/", FileServer(Dir(os.TempDir())))
+ mux.HandleFunc("/quit", func(ResponseWriter, *Request) {
+ os.Exit(0)
+ })
+ s := &Server{Handler: mux}
+ err = s.Serve(ln)
+ if err != nil {
+ panic(err)
+ }
+}
+
+// Issues 18984, 49552: tests that requests for paths beyond files return not-found errors
+func TestFileServerNotDirError(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("Dir", func(t *testing.T) {
+ testFileServerNotDirError(t, mode, func(path string) FileSystem { return Dir(path) })
+ })
+ t.Run("FS", func(t *testing.T) {
+ testFileServerNotDirError(t, mode, func(path string) FileSystem { return FS(os.DirFS(path)) })
+ })
+ })
+}
+
+func testFileServerNotDirError(t *testing.T, mode testMode, newfs func(string) FileSystem) {
+ ts := newClientServerTest(t, mode, FileServer(newfs("testdata"))).ts
+
+ res, err := ts.Client().Get(ts.URL + "/index.html/not-a-file")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != 404 {
+ t.Errorf("StatusCode = %v; want 404", res.StatusCode)
+ }
+
+ test := func(name string, fsys FileSystem) {
+ t.Run(name, func(t *testing.T) {
+ _, err = fsys.Open("/index.html/not-a-file")
+ if err == nil {
+ t.Fatal("err == nil; want != nil")
+ }
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Errorf("err = %v; errors.Is(err, fs.ErrNotExist) = %v; want true", err,
+ errors.Is(err, fs.ErrNotExist))
+ }
+
+ _, err = fsys.Open("/index.html/not-a-dir/not-a-file")
+ if err == nil {
+ t.Fatal("err == nil; want != nil")
+ }
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Errorf("err = %v; errors.Is(err, fs.ErrNotExist) = %v; want true", err,
+ errors.Is(err, fs.ErrNotExist))
+ }
+ })
+ }
+
+ absPath, err := filepath.Abs("testdata")
+ if err != nil {
+ t.Fatal("get abs path:", err)
+ }
+
+ test("RelativePath", newfs("testdata"))
+ test("AbsolutePath", newfs(absPath))
+}
+
+func TestFileServerCleanPath(t *testing.T) {
+ tests := []struct {
+ path string
+ wantCode int
+ wantOpen []string
+ }{
+ {"/", 200, []string{"/", "/index.html"}},
+ {"/dir", 301, []string{"/dir"}},
+ {"/dir/", 200, []string{"/dir", "/dir/index.html"}},
+ }
+ for _, tt := range tests {
+ var log []string
+ rr := httptest.NewRecorder()
+ req, _ := NewRequest("GET", "http://foo.localhost"+tt.path, nil)
+ FileServer(fileServerCleanPathDir{&log}).ServeHTTP(rr, req)
+ if !reflect.DeepEqual(log, tt.wantOpen) {
+ t.Logf("For %s: Opens = %q; want %q", tt.path, log, tt.wantOpen)
+ }
+ if rr.Code != tt.wantCode {
+ t.Logf("For %s: Response code = %d; want %d", tt.path, rr.Code, tt.wantCode)
+ }
+ }
+}
+
+type fileServerCleanPathDir struct {
+ log *[]string
+}
+
+func (d fileServerCleanPathDir) Open(path string) (File, error) {
+ *(d.log) = append(*(d.log), path)
+ if path == "/" || path == "/dir" || path == "/dir/" {
+ // Just return something that's a directory.
+ return Dir(".").Open(".")
+ }
+ return nil, fs.ErrNotExist
+}
+
+type panicOnSeek struct{ io.ReadSeeker }
+
+func Test_scanETag(t *testing.T) {
+ tests := []struct {
+ in string
+ wantETag string
+ wantRemain string
+ }{
+ {`W/"etag-1"`, `W/"etag-1"`, ""},
+ {`"etag-2"`, `"etag-2"`, ""},
+ {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`},
+ {"", "", ""},
+ {"W/", "", ""},
+ {`W/"truc`, "", ""},
+ {`w/"case-sensitive"`, "", ""},
+ {`"spaced etag"`, "", ""},
+ }
+ for _, test := range tests {
+ etag, remain := ExportScanETag(test.in)
+ if etag != test.wantETag || remain != test.wantRemain {
+ t.Errorf("scanETag(%q)=%q %q, want %q %q", test.in, etag, remain, test.wantETag, test.wantRemain)
+ }
+ }
+}
+
+// Issue 40940: Ensure that we only accept non-negative suffix-lengths
+// in "Range": "bytes=-N", and should reject "bytes=--2".
+func TestServeFileRejectsInvalidSuffixLengths(t *testing.T) {
+ run(t, testServeFileRejectsInvalidSuffixLengths, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testServeFileRejectsInvalidSuffixLengths(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts
+
+ tests := []struct {
+ r string
+ wantCode int
+ wantBody string
+ }{
+ {"bytes=--6", 416, "invalid range\n"},
+ {"bytes=--0", 416, "invalid range\n"},
+ {"bytes=---0", 416, "invalid range\n"},
+ {"bytes=-6", 206, "hello\n"},
+ {"bytes=6-", 206, "html says hello\n"},
+ {"bytes=-6-", 416, "invalid range\n"},
+ {"bytes=-0", 206, ""},
+ {"bytes=", 200, "index.html says hello\n"},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.r, func(t *testing.T) {
+ req, err := NewRequest("GET", cst.URL+"/index.html", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Range", tt.r)
+ res, err := cst.Client().Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, w := res.StatusCode, tt.wantCode; g != w {
+ t.Errorf("StatusCode mismatch: got %d want %d", g, w)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, w := string(slurp), tt.wantBody; g != w {
+ t.Fatalf("Content mismatch:\nGot: %q\nWant: %q", g, w)
+ }
+ })
+ }
+}
+
+func TestFileServerMethods(t *testing.T) {
+ run(t, testFileServerMethods)
+}
+func testFileServerMethods(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts
+
+ file, err := os.ReadFile(testFile)
+ if err != nil {
+ t.Fatal("reading file:", err)
+ }
+
+ // Get contents via various methods.
+ //
+ // See https://go.dev/issue/59471 for a proposal to limit the set of methods handled.
+ // For now, test the historical behavior.
+ for _, method := range []string{
+ MethodGet,
+ MethodHead,
+ MethodPost,
+ MethodPut,
+ MethodPatch,
+ MethodDelete,
+ MethodOptions,
+ MethodTrace,
+ } {
+ req, _ := NewRequest(method, ts.URL+"/file", nil)
+ t.Log(req.URL)
+ res, err := ts.Client().Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantBody := file
+ if method == MethodHead {
+ wantBody = nil
+ }
+ if !bytes.Equal(body, wantBody) {
+ t.Fatalf("%v: got body %q, want %q", method, body, wantBody)
+ }
+ if got, want := res.Header.Get("Content-Length"), fmt.Sprint(len(file)); got != want {
+ t.Fatalf("%v: got Content-Length %q, want %q", method, got, want)
+ }
+ }
+}
diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go
new file mode 100644
index 0000000..dd59e1f
--- /dev/null
+++ b/src/net/http/h2_bundle.go
@@ -0,0 +1,11493 @@
+//go:build !nethttpomithttp2
+// +build !nethttpomithttp2
+
+// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
+// $ bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2
+
+// Package http2 implements the HTTP/2 protocol.
+//
+// This package is low-level and intended to be used directly by very
+// few people. Most users will use it indirectly through the automatic
+// use by the net/http package (from Go 1.6 and later).
+// For use in earlier Go versions see ConfigureServer. (Transport support
+// requires Go 1.6 or later)
+//
+// See https://http2.github.io/ for more information on HTTP/2.
+//
+// See https://http2.golang.org/ for a test server running this code.
+//
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/rand"
+ "crypto/tls"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "io/fs"
+ "log"
+ "math"
+ mathrand "math/rand"
+ "net"
+ "net/http/httptrace"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http2/hpack"
+ "golang.org/x/net/idna"
+)
+
+// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
+// contains helper functions which may use Unicode-aware functions which would
+// otherwise be unsafe and could introduce vulnerabilities if used improperly.
+
+// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func http2asciiEqualFold(s, t string) bool {
+ if len(s) != len(t) {
+ return false
+ }
+ for i := 0; i < len(s); i++ {
+ if http2lower(s[i]) != http2lower(t[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// lower returns the ASCII lowercase version of b.
+func http2lower(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+}
+
+// isASCIIPrint returns whether s is ASCII and printable according to
+// https://tools.ietf.org/html/rfc20#section-4.2.
+func http2isASCIIPrint(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] < ' ' || s[i] > '~' {
+ return false
+ }
+ }
+ return true
+}
+
+// asciiToLower returns the lowercase version of s if s is ASCII and printable,
+// and whether or not it was.
+func http2asciiToLower(s string) (lower string, ok bool) {
+ if !http2isASCIIPrint(s) {
+ return "", false
+ }
+ return strings.ToLower(s), true
+}
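+
+// Editorial sketch (not part of the generated bundle): a hypothetical helper
+// illustrating why the ASCII-only fold above is preferred for protocol
+// elements. strings.EqualFold applies Unicode simple case folding, so the
+// Kelvin sign U+212A compares equal to "k"; asciiEqualFold does not.
+func http2exampleASCIIFold() {
+	unicodeMatch := strings.EqualFold("\u212aEEP-ALIVE", "keep-alive") // true
+	asciiMatch := http2asciiEqualFold("\u212aEEP-ALIVE", "keep-alive") // false
+	fmt.Println(unicodeMatch, asciiMatch)
+}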
+
+// A list of the possible cipher suite ids. Taken from
+// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt
+
+const (
+ http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000
+ http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001
+ http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002
+ http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003
+ http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004
+ http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005
+ http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006
+ http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007
+ http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008
+ http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009
+ http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A
+ http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B
+ http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C
+ http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D
+ http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E
+ http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F
+ http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010
+ http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011
+ http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012
+ http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013
+ http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014
+ http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015
+ http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016
+ http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017
+ http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018
+ http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019
+ http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A
+ http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B
+ // Reserved uint16 = 0x001C-1D
+ http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E
+ http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F
+ http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020
+ http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021
+ http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022
+ http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023
+ http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024
+ http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025
+ http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028
+ http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B
+ http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E
+ http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F
+ http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030
+ http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031
+ http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033
+ http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034
+ http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035
+ http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036
+ http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037
+ http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039
+ http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A
+ http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B
+ http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C
+ http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D
+ http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E
+ http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F
+ http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046
+ // Reserved uint16 = 0x0047-4F
+ // Reserved uint16 = 0x0050-58
+ // Reserved uint16 = 0x0059-5C
+ // Unassigned uint16 = 0x005D-5F
+ // Reserved uint16 = 0x0060-66
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067
+ http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068
+ http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069
+ http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B
+ http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C
+ http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D
+ // Unassigned uint16 = 0x006E-83
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089
+ http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A
+ http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B
+ http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C
+ http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D
+ http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E
+ http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091
+ http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092
+ http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095
+ http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096
+ http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097
+ http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098
+ http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099
+ http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A
+ http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B
+ http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C
+ http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F
+ http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0
+ http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1
+ http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2
+ http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3
+ http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4
+ http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5
+ http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6
+ http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7
+ http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8
+ http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD
+ http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE
+ http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF
+ http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0
+ http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5
+ // Unassigned uint16 = 0x00C6-FE
+ http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF
+ // Unassigned uint16 = 0x01-55,*
+ http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600
+ // Unassigned uint16 = 0x5601 - 0xC000
+ http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001
+ http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002
+ http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005
+ http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006
+ http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007
+ http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A
+ http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B
+ http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C
+ http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F
+ http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010
+ http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011
+ http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014
+ http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015
+ http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016
+ http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017
+ http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018
+ http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019
+ http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A
+ http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B
+ http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C
+ http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D
+ http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E
+ http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F
+ http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020
+ http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021
+ http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032
+ http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033
+ http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B
+ http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C
+ http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D
+ http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E
+ http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F
+ http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040
+ http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045
+ http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046
+ http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F
+ http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050
+ http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053
+ http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054
+ http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057
+ http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058
+ http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059
+ http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A
+ http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063
+ http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064
+ http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069
+ http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A
+ http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F
+ http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070
+ http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D
+ http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E
+ http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093
+ http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094
+ http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099
+ http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A
+ http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B
+ http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C
+ http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F
+ http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0
+ http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3
+ http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4
+ http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7
+ http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8
+ http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9
+ http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA
+ http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF
+ // Unassigned uint16 = 0xC0B0-FF
+ // Unassigned uint16 = 0xC1-CB,*
+ // Unassigned uint16 = 0xCC00-A7
+ http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9
+ http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA
+ http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB
+ http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC
+ http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD
+ http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE
+)
+
+// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec.
+// References:
+// https://tools.ietf.org/html/rfc7540#appendix-A
+// Reject cipher suites from Appendix A.
+// "This list includes those cipher suites that do not
+// offer an ephemeral key exchange and those that are
+// based on the TLS null, stream or block cipher type"
+func http2isBadCipher(cipher uint16) bool {
+ switch cipher {
+ case http2cipher_TLS_NULL_WITH_NULL_NULL,
+ http2cipher_TLS_RSA_WITH_NULL_MD5,
+ http2cipher_TLS_RSA_WITH_NULL_SHA,
+ http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5,
+ http2cipher_TLS_RSA_WITH_RC4_128_MD5,
+ http2cipher_TLS_RSA_WITH_RC4_128_SHA,
+ http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5,
+ http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA,
+ http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_DES_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5,
+ http2cipher_TLS_DH_anon_WITH_RC4_128_MD5,
+ http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_KRB5_WITH_DES_CBC_SHA,
+ http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_KRB5_WITH_RC4_128_SHA,
+ http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA,
+ http2cipher_TLS_KRB5_WITH_DES_CBC_MD5,
+ http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5,
+ http2cipher_TLS_KRB5_WITH_RC4_128_MD5,
+ http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5,
+ http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA,
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA,
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA,
+ http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5,
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5,
+ http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5,
+ http2cipher_TLS_PSK_WITH_NULL_SHA,
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA,
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA,
+ http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_NULL_SHA256,
+ http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA,
+ http2cipher_TLS_PSK_WITH_RC4_128_SHA,
+ http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA,
+ http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA,
+ http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA,
+ http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_PSK_WITH_NULL_SHA256,
+ http2cipher_TLS_PSK_WITH_NULL_SHA384,
+ http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256,
+ http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256,
+ http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV,
+ http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA,
+ http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA,
+ http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA,
+ http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDH_anon_WITH_NULL_SHA,
+ http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384,
+ http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA,
+ http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA,
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA,
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA,
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA,
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256,
+ http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384,
+ http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384,
+ http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384,
+ http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256,
+ http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384,
+ http2cipher_TLS_RSA_WITH_AES_128_CCM,
+ http2cipher_TLS_RSA_WITH_AES_256_CCM,
+ http2cipher_TLS_RSA_WITH_AES_128_CCM_8,
+ http2cipher_TLS_RSA_WITH_AES_256_CCM_8,
+ http2cipher_TLS_PSK_WITH_AES_128_CCM,
+ http2cipher_TLS_PSK_WITH_AES_256_CCM,
+ http2cipher_TLS_PSK_WITH_AES_128_CCM_8,
+ http2cipher_TLS_PSK_WITH_AES_256_CCM_8:
+ return true
+ default:
+ return false
+ }
+}
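+
+// Editorial sketch (not part of the generated bundle): how a caller might
+// consult isBadCipher when vetting a completed TLS handshake. The function
+// name is hypothetical; the bundle's own enforcement lives elsewhere.
+func http2exampleRejectBadCipher(state tls.ConnectionState) error {
+	if http2isBadCipher(state.CipherSuite) {
+		return fmt.Errorf("http2: cipher suite 0x%04x is forbidden by RFC 7540 Appendix A", state.CipherSuite)
+	}
+	return nil
+}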
+
+// ClientConnPool manages a pool of HTTP/2 client connections.
+type http2ClientConnPool interface {
+ // GetClientConn returns a specific HTTP/2 connection (usually
+ // a TLS-TCP connection) to an HTTP/2 server. On success, the
+ // returned ClientConn accounts for the upcoming RoundTrip
+ // call, so the caller should not omit it. If the caller needs
+ // to, ClientConn.RoundTrip can be called with a bogus
+ // new(http.Request) to release the stream reservation.
+ GetClientConn(req *Request, addr string) (*http2ClientConn, error)
+ MarkDead(*http2ClientConn)
+}
+
+// clientConnPoolIdleCloser is the interface implemented by ClientConnPool
+// implementations which can close their idle connections.
+type http2clientConnPoolIdleCloser interface {
+ http2ClientConnPool
+ closeIdleConnections()
+}
+
+var (
+ _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil)
+ _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{}
+)
+
+// TODO: use singleflight for dialing and addConnCalls?
+type http2clientConnPool struct {
+ t *http2Transport
+
+ mu sync.Mutex // TODO: maybe switch to RWMutex
+ // TODO: add support for sharing conns based on cert names
+ // (e.g. share conn for googleapis.com and appspot.com)
+ conns map[string][]*http2ClientConn // key is host:port
+ dialing map[string]*http2dialCall // currently in-flight dials
+ keys map[*http2ClientConn][]string
+ addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls
+}
+
+func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ return p.getClientConn(req, addr, http2dialOnMiss)
+}
+
+const (
+ http2dialOnMiss = true
+ http2noDialOnMiss = false
+)
+
+func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+ // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
+ if http2isConnectionCloseRequest(req) && dialOnMiss {
+ // It gets its own connection.
+ http2traceGetConn(req, addr)
+ const singleUse = true
+ cc, err := p.t.dialClientConn(req.Context(), addr, singleUse)
+ if err != nil {
+ return nil, err
+ }
+ return cc, nil
+ }
+ for {
+ p.mu.Lock()
+ for _, cc := range p.conns[addr] {
+ if cc.ReserveNewRequest() {
+ // When a connection is presented to us by the net/http package,
+ // the GetConn hook has already been called.
+ // Don't call it a second time here.
+ if !cc.getConnCalled {
+ http2traceGetConn(req, addr)
+ }
+ cc.getConnCalled = false
+ p.mu.Unlock()
+ return cc, nil
+ }
+ }
+ if !dialOnMiss {
+ p.mu.Unlock()
+ return nil, http2ErrNoCachedConn
+ }
+ http2traceGetConn(req, addr)
+ call := p.getStartDialLocked(req.Context(), addr)
+ p.mu.Unlock()
+ <-call.done
+ if http2shouldRetryDial(call, req) {
+ continue
+ }
+ cc, err := call.res, call.err
+ if err != nil {
+ return nil, err
+ }
+ if cc.ReserveNewRequest() {
+ return cc, nil
+ }
+ }
+}
+
+// dialCall is an in-flight Transport dial call to a host.
+type http2dialCall struct {
+ _ http2incomparable
+ p *http2clientConnPool
+ // the context associated with the request
+ // that created this dialCall
+ ctx context.Context
+ done chan struct{} // closed when done
+ res *http2ClientConn // valid after done is closed
+ err error // valid after done is closed
+}
+
+// requires p.mu is held.
+func (p *http2clientConnPool) getStartDialLocked(ctx context.Context, addr string) *http2dialCall {
+ if call, ok := p.dialing[addr]; ok {
+ // A dial is already in-flight. Don't start another.
+ return call
+ }
+ call := &http2dialCall{p: p, done: make(chan struct{}), ctx: ctx}
+ if p.dialing == nil {
+ p.dialing = make(map[string]*http2dialCall)
+ }
+ p.dialing[addr] = call
+ go call.dial(call.ctx, addr)
+ return call
+}
+
+// run in its own goroutine.
+func (c *http2dialCall) dial(ctx context.Context, addr string) {
+ const singleUse = false // shared conn
+ c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
+
+ c.p.mu.Lock()
+ delete(c.p.dialing, addr)
+ if c.err == nil {
+ c.p.addConnLocked(addr, c.res)
+ }
+ c.p.mu.Unlock()
+
+ close(c.done)
+}
+
+// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
+// already exist. It coalesces concurrent calls with the same key.
+// This is used by the http1 Transport code when it creates a new connection. Because
+// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
+// the protocol), it can get into a situation where it has multiple TLS connections.
+// This code decides which ones live or die.
+// The return value reports whether c was used.
+// c is never closed.
+func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) {
+ p.mu.Lock()
+ for _, cc := range p.conns[key] {
+ if cc.CanTakeNewRequest() {
+ p.mu.Unlock()
+ return false, nil
+ }
+ }
+ call, dup := p.addConnCalls[key]
+ if !dup {
+ if p.addConnCalls == nil {
+ p.addConnCalls = make(map[string]*http2addConnCall)
+ }
+ call = &http2addConnCall{
+ p: p,
+ done: make(chan struct{}),
+ }
+ p.addConnCalls[key] = call
+ go call.run(t, key, c)
+ }
+ p.mu.Unlock()
+
+ <-call.done
+ if call.err != nil {
+ return false, call.err
+ }
+ return !dup, nil
+}
+
+type http2addConnCall struct {
+ _ http2incomparable
+ p *http2clientConnPool
+ done chan struct{} // closed when done
+ err error
+}
+
+func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) {
+ cc, err := t.NewClientConn(tc)
+
+ p := c.p
+ p.mu.Lock()
+ if err != nil {
+ c.err = err
+ } else {
+ cc.getConnCalled = true // already called by the net/http package
+ p.addConnLocked(key, cc)
+ }
+ delete(p.addConnCalls, key)
+ p.mu.Unlock()
+ close(c.done)
+}
+
+// p.mu must be held
+func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) {
+ for _, v := range p.conns[key] {
+ if v == cc {
+ return
+ }
+ }
+ if p.conns == nil {
+ p.conns = make(map[string][]*http2ClientConn)
+ }
+ if p.keys == nil {
+ p.keys = make(map[*http2ClientConn][]string)
+ }
+ p.conns[key] = append(p.conns[key], cc)
+ p.keys[cc] = append(p.keys[cc], key)
+}
+
+func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for _, key := range p.keys[cc] {
+ vv, ok := p.conns[key]
+ if !ok {
+ continue
+ }
+ newList := http2filterOutClientConn(vv, cc)
+ if len(newList) > 0 {
+ p.conns[key] = newList
+ } else {
+ delete(p.conns, key)
+ }
+ }
+ delete(p.keys, cc)
+}
+
+func (p *http2clientConnPool) closeIdleConnections() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ // TODO: don't close a cc if it was just added to the pool
+ // milliseconds ago and has never been used. There's currently
+ // a small race window with the HTTP/1 Transport's integration
+ // where it can add an idle conn just before using it, and
+ // somebody else can concurrently call CloseIdleConns and
+ // break some caller's RoundTrip.
+ for _, vv := range p.conns {
+ for _, cc := range vv {
+ cc.closeIfIdle()
+ }
+ }
+}
+
+func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
+ out := in[:0]
+ for _, v := range in {
+ if v != exclude {
+ out = append(out, v)
+ }
+ }
+	// If we filtered an entry out, zero the now-unused last slot so the
+	// connection it referenced can be garbage collected.
+ if len(in) != len(out) {
+ in[len(in)-1] = nil
+ }
+ return out
+}
+
+// noDialClientConnPool is an implementation of http2.ClientConnPool
+// which never dials. We let the HTTP/1.1 client dial and use its TLS
+// connection instead.
+type http2noDialClientConnPool struct{ *http2clientConnPool }
+
+func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ return p.getClientConn(req, addr, http2noDialOnMiss)
+}
+
+// shouldRetryDial reports whether the current request should
+// retry dialing after the call finished unsuccessfully, for example
+// if the dial was canceled because of a context cancellation or
+// deadline expiry.
+func http2shouldRetryDial(call *http2dialCall, req *Request) bool {
+ if call.err == nil {
+ // No error, no need to retry
+ return false
+ }
+ if call.ctx == req.Context() {
+ // If the call has the same context as the request, the dial
+ // should not be retried, since any cancellation will have come
+ // from this request.
+ return false
+ }
+ if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
+ // If the call error is not because of a context cancellation or a deadline expiry,
+ // the dial should not be retried.
+ return false
+ }
+ // Only retry if the error is a context cancellation error or deadline expiry
+ // and the context associated with the call was canceled or expired.
+ return call.ctx.Err() != nil
+}
+
+// Buffer chunks are allocated from a pool to reduce pressure on GC.
+// The maximum wasted space per dataBuffer is 2x the largest size class,
+// which happens when the dataBuffer has multiple chunks and there is
+// one unread byte in both the first and last chunks. We use a few size
+// classes to minimize overheads for servers that typically receive very
+// small request bodies.
+//
+// TODO: Benchmark to determine if the pools are necessary. The GC may have
+// improved enough that we can instead allocate chunks like this:
+// make([]byte, max(16<<10, expectedBytesRemaining))
+var (
+ http2dataChunkSizeClasses = []int{
+ 1 << 10,
+ 2 << 10,
+ 4 << 10,
+ 8 << 10,
+ 16 << 10,
+ }
+ http2dataChunkPools = [...]sync.Pool{
+ {New: func() interface{} { return make([]byte, 1<<10) }},
+ {New: func() interface{} { return make([]byte, 2<<10) }},
+ {New: func() interface{} { return make([]byte, 4<<10) }},
+ {New: func() interface{} { return make([]byte, 8<<10) }},
+ {New: func() interface{} { return make([]byte, 16<<10) }},
+ }
+)
+
+func http2getDataBufferChunk(size int64) []byte {
+ i := 0
+ for ; i < len(http2dataChunkSizeClasses)-1; i++ {
+ if size <= int64(http2dataChunkSizeClasses[i]) {
+ break
+ }
+ }
+ return http2dataChunkPools[i].Get().([]byte)
+}
+
+func http2putDataBufferChunk(p []byte) {
+ for i, n := range http2dataChunkSizeClasses {
+ if len(p) == n {
+ http2dataChunkPools[i].Put(p)
+ return
+ }
+ }
+ panic(fmt.Sprintf("unexpected buffer len=%v", len(p)))
+}
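+
+// Editorial sketch (not part of the generated bundle): how the size classes
+// above map requested sizes to pooled chunks. A request for 3000 bytes skips
+// the 1 KiB and 2 KiB classes and receives a 4 KiB chunk; anything larger
+// than 16 KiB falls into the largest class.
+func http2exampleChunkSizes() {
+	a := http2getDataBufferChunk(3000)     // len(a) == 4096
+	b := http2getDataBufferChunk(64 << 10) // len(b) == 16384
+	fmt.Println(len(a), len(b))
+	http2putDataBufferChunk(a)
+	http2putDataBufferChunk(b)
+}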
+
+// dataBuffer is an io.ReadWriter backed by a list of data chunks.
+// Each dataBuffer is used to read DATA frames on a single stream.
+// The buffer is divided into chunks so the server can limit the
+// total memory used by a single connection without limiting the
+// request body size on any single stream.
+type http2dataBuffer struct {
+ chunks [][]byte
+ r int // next byte to read is chunks[0][r]
+ w int // next byte to write is chunks[len(chunks)-1][w]
+ size int // total buffered bytes
+ expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0)
+}
+
+var http2errReadEmpty = errors.New("read from empty dataBuffer")
+
+// Read copies bytes from the buffer into p.
+// It is an error to read when no data is available.
+func (b *http2dataBuffer) Read(p []byte) (int, error) {
+ if b.size == 0 {
+ return 0, http2errReadEmpty
+ }
+ var ntotal int
+ for len(p) > 0 && b.size > 0 {
+ readFrom := b.bytesFromFirstChunk()
+ n := copy(p, readFrom)
+ p = p[n:]
+ ntotal += n
+ b.r += n
+ b.size -= n
+ // If the first chunk has been consumed, advance to the next chunk.
+ if b.r == len(b.chunks[0]) {
+ http2putDataBufferChunk(b.chunks[0])
+ end := len(b.chunks) - 1
+ copy(b.chunks[:end], b.chunks[1:])
+ b.chunks[end] = nil
+ b.chunks = b.chunks[:end]
+ b.r = 0
+ }
+ }
+ return ntotal, nil
+}
+
+func (b *http2dataBuffer) bytesFromFirstChunk() []byte {
+ if len(b.chunks) == 1 {
+ return b.chunks[0][b.r:b.w]
+ }
+ return b.chunks[0][b.r:]
+}
+
+// Len returns the number of bytes of the unread portion of the buffer.
+func (b *http2dataBuffer) Len() int {
+ return b.size
+}
+
+// Write appends p to the buffer.
+func (b *http2dataBuffer) Write(p []byte) (int, error) {
+ ntotal := len(p)
+ for len(p) > 0 {
+ // If the last chunk is empty, allocate a new chunk. Try to allocate
+ // enough to fully copy p plus any additional bytes we expect to
+ // receive. However, this may allocate less than len(p).
+ want := int64(len(p))
+ if b.expected > want {
+ want = b.expected
+ }
+ chunk := b.lastChunkOrAlloc(want)
+ n := copy(chunk[b.w:], p)
+ p = p[n:]
+ b.w += n
+ b.size += n
+ b.expected -= int64(n)
+ }
+ return ntotal, nil
+}
+
+func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte {
+ if len(b.chunks) != 0 {
+ last := b.chunks[len(b.chunks)-1]
+ if b.w < len(last) {
+ return last
+ }
+ }
+ chunk := http2getDataBufferChunk(want)
+ b.chunks = append(b.chunks, chunk)
+ b.w = 0
+ return chunk
+}
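+
+// Editorial sketch (not part of the generated bundle): a dataBuffer round
+// trip that crosses a chunk boundary. With no expected-size hint, a
+// 20000-byte write spans a 16 KiB chunk plus a 4 KiB chunk, and a single
+// Read call drains both in order.
+func http2exampleDataBuffer() {
+	var buf http2dataBuffer
+	buf.Write(bytes.Repeat([]byte{'x'}, 20000))
+	out := make([]byte, 20000)
+	n, _ := buf.Read(out)
+	fmt.Println(n, buf.Len()) // 20000 0
+}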
+
+// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
+type http2ErrCode uint32
+
+const (
+ http2ErrCodeNo http2ErrCode = 0x0
+ http2ErrCodeProtocol http2ErrCode = 0x1
+ http2ErrCodeInternal http2ErrCode = 0x2
+ http2ErrCodeFlowControl http2ErrCode = 0x3
+ http2ErrCodeSettingsTimeout http2ErrCode = 0x4
+ http2ErrCodeStreamClosed http2ErrCode = 0x5
+ http2ErrCodeFrameSize http2ErrCode = 0x6
+ http2ErrCodeRefusedStream http2ErrCode = 0x7
+ http2ErrCodeCancel http2ErrCode = 0x8
+ http2ErrCodeCompression http2ErrCode = 0x9
+ http2ErrCodeConnect http2ErrCode = 0xa
+ http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb
+ http2ErrCodeInadequateSecurity http2ErrCode = 0xc
+ http2ErrCodeHTTP11Required http2ErrCode = 0xd
+)
+
+var http2errCodeName = map[http2ErrCode]string{
+ http2ErrCodeNo: "NO_ERROR",
+ http2ErrCodeProtocol: "PROTOCOL_ERROR",
+ http2ErrCodeInternal: "INTERNAL_ERROR",
+ http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
+ http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
+ http2ErrCodeStreamClosed: "STREAM_CLOSED",
+ http2ErrCodeFrameSize: "FRAME_SIZE_ERROR",
+ http2ErrCodeRefusedStream: "REFUSED_STREAM",
+ http2ErrCodeCancel: "CANCEL",
+ http2ErrCodeCompression: "COMPRESSION_ERROR",
+ http2ErrCodeConnect: "CONNECT_ERROR",
+ http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
+ http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
+ http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
+}
+
+func (e http2ErrCode) String() string {
+ if s, ok := http2errCodeName[e]; ok {
+ return s
+ }
+ return fmt.Sprintf("unknown error code 0x%x", uint32(e))
+}
+
+func (e http2ErrCode) stringToken() string {
+ if s, ok := http2errCodeName[e]; ok {
+ return s
+ }
+ return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e))
+}
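+
+// Editorial sketch (not part of the generated bundle): String falls back to
+// a hex form for codes missing from the table above, while stringToken uses
+// a decimal ERR_UNKNOWN_ token.
+func http2exampleErrCodeStrings() {
+	fmt.Println(http2ErrCodeFlowControl.String()) // FLOW_CONTROL_ERROR
+	fmt.Println(http2ErrCode(0x99).String())      // unknown error code 0x99
+	fmt.Println(http2ErrCode(0x99).stringToken()) // ERR_UNKNOWN_153
+}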
+
+// ConnectionError is an error that results in the termination of the
+// entire connection.
+type http2ConnectionError http2ErrCode
+
+func (e http2ConnectionError) Error() string {
+ return fmt.Sprintf("connection error: %s", http2ErrCode(e))
+}
+
+// StreamError is an error that only affects one stream within an
+// HTTP/2 connection.
+type http2StreamError struct {
+ StreamID uint32
+ Code http2ErrCode
+ Cause error // optional additional detail
+}
+
+// errFromPeer is a sentinel error value for StreamError.Cause to
+// indicate that the StreamError was sent from the peer over the wire
+// and wasn't locally generated in the Transport.
+var http2errFromPeer = errors.New("received from peer")
+
+func http2streamError(id uint32, code http2ErrCode) http2StreamError {
+ return http2StreamError{StreamID: id, Code: code}
+}
+
+func (e http2StreamError) Error() string {
+ if e.Cause != nil {
+ return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause)
+ }
+ return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
+}
+
+// 6.9.1 The Flow Control Window
+// "If a sender receives a WINDOW_UPDATE that causes a flow control
+// window to exceed this maximum it MUST terminate either the stream
+// or the connection, as appropriate. For streams, [...]; for the
+// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
+type http2goAwayFlowError struct{}
+
+func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
+
+// connError represents an HTTP/2 ConnectionError error code, along
+// with a string (for debugging) explaining why.
+//
+// Errors of this type are only returned by the frame parser functions
+// and converted into ConnectionError(Code), after stashing away
+// the Reason into the Framer's errDetail field, accessible via
+// the (*Framer).ErrorDetail method.
+type http2connError struct {
+ Code http2ErrCode // the ConnectionError error code
+ Reason string // additional reason
+}
+
+func (e http2connError) Error() string {
+ return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
+}
+
+type http2pseudoHeaderError string
+
+func (e http2pseudoHeaderError) Error() string {
+ return fmt.Sprintf("invalid pseudo-header %q", string(e))
+}
+
+type http2duplicatePseudoHeaderError string
+
+func (e http2duplicatePseudoHeaderError) Error() string {
+ return fmt.Sprintf("duplicate pseudo-header %q", string(e))
+}
+
+type http2headerFieldNameError string
+
+func (e http2headerFieldNameError) Error() string {
+ return fmt.Sprintf("invalid header field name %q", string(e))
+}
+
+type http2headerFieldValueError string
+
+func (e http2headerFieldValueError) Error() string {
+ return fmt.Sprintf("invalid header field value for %q", string(e))
+}
+
+var (
+ http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers")
+ http2errPseudoAfterRegular = errors.New("pseudo header field after regular")
+)
+
+// inflowMinRefresh is the minimum number of bytes we'll send for a
+// flow control window update.
+const http2inflowMinRefresh = 4 << 10
+
+// inflow accounts for an inbound flow control window.
+// It tracks both the latest window sent to the peer (used for enforcement)
+// and the accumulated unsent window.
+type http2inflow struct {
+ avail int32
+ unsent int32
+}
+
+// init sets the initial window.
+func (f *http2inflow) init(n int32) {
+ f.avail = n
+}
+
+// add adds n bytes to the window,
+// indicating that the peer can now send us more data.
+// For example, the user read from a {Request,Response} body and consumed
+// some of the buffered data, so the peer can now send more.
+// It returns the number of bytes to send in a WINDOW_UPDATE frame to the peer.
+// Window updates are accumulated and sent when the unsent capacity
+// is at least inflowMinRefresh or will at least double the peer's available window.
+func (f *http2inflow) add(n int) (connAdd int32) {
+ if n < 0 {
+ panic("negative update")
+ }
+ unsent := int64(f.unsent) + int64(n)
+ // "A sender MUST NOT allow a flow-control window to exceed 2^31-1 octets."
+ // RFC 7540 Section 6.9.1.
+ const maxWindow = 1<<31 - 1
+ if unsent+int64(f.avail) > maxWindow {
+ panic("flow control update exceeds maximum window size")
+ }
+ f.unsent = int32(unsent)
+ if f.unsent < http2inflowMinRefresh && f.unsent < f.avail {
+ // If there aren't at least inflowMinRefresh bytes of window to send,
+ // and this update won't at least double the window, buffer the update for later.
+ return 0
+ }
+ f.avail += f.unsent
+ f.unsent = 0
+ return int32(unsent)
+}
+
+// take attempts to take n bytes from the peer's flow control window.
+// It reports whether the window has available capacity.
+func (f *http2inflow) take(n uint32) bool {
+ if n > uint32(f.avail) {
+ return false
+ }
+ f.avail -= int32(n)
+ return true
+}
+
+// takeInflows attempts to take n bytes from two inflows,
+// typically connection-level and stream-level flows.
+// It reports whether both windows have available capacity.
+func http2takeInflows(f1, f2 *http2inflow, n uint32) bool {
+ if n > uint32(f1.avail) || n > uint32(f2.avail) {
+ return false
+ }
+ f1.avail -= int32(n)
+ f2.avail -= int32(n)
+ return true
+}
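+
+// A minimal accounting sketch for the inflow type above, assuming an
+// advertised initial window of 64 KB (the names in and incr exist only
+// for this illustration):
+//
+//	var in http2inflow
+//	in.init(64 << 10)    // window advertised to the peer
+//	_ = in.take(1024)    // 1 KB of DATA arrived; false would mean the
+//	                     // peer overran the window
+//	incr := in.add(1024) // the application consumed that 1 KB
+//	// incr is 0 here: updates are buffered until at least
+//	// inflowMinRefresh (4 KB) is pending or the window would double.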
+
+// outflow is the outbound flow control window's size.
+type http2outflow struct {
+ _ http2incomparable
+
+ // n is the number of DATA bytes we're allowed to send.
+ // An outflow is kept both on a conn and a per-stream.
+ n int32
+
+ // conn points to the connection-level outflow that is
+ // shared by all streams on that conn. It is nil for the outflow
+ // that's on the conn directly.
+ conn *http2outflow
+}
+
+func (f *http2outflow) setConnFlow(cf *http2outflow) { f.conn = cf }
+
+func (f *http2outflow) available() int32 {
+ n := f.n
+ if f.conn != nil && f.conn.n < n {
+ n = f.conn.n
+ }
+ return n
+}
+
+func (f *http2outflow) take(n int32) {
+ if n > f.available() {
+ panic("internal error: took too much")
+ }
+ f.n -= n
+ if f.conn != nil {
+ f.conn.n -= n
+ }
+}
+
+// add adds n bytes (positive or negative) to the flow control window.
+// It returns false if the sum would exceed 2^31-1.
+func (f *http2outflow) add(n int32) bool {
+ sum := f.n + n
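+ // (sum > n) == (f.n > 0) holds exactly when the signed 32-bit addition
+ // did not overflow in either direction, i.e. the new window still fits
+ // in an int32.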
+ if (sum > n) == (f.n > 0) {
+ f.n = sum
+ return true
+ }
+ return false
+}
+
+const http2frameHeaderLen = 9
+
+var http2padZeros = make([]byte, 255) // zeros for padding
+
+// A FrameType is a registered frame type as defined in
+// https://httpwg.org/specs/rfc7540.html#rfc.section.11.2
+type http2FrameType uint8
+
+const (
+ http2FrameData http2FrameType = 0x0
+ http2FrameHeaders http2FrameType = 0x1
+ http2FramePriority http2FrameType = 0x2
+ http2FrameRSTStream http2FrameType = 0x3
+ http2FrameSettings http2FrameType = 0x4
+ http2FramePushPromise http2FrameType = 0x5
+ http2FramePing http2FrameType = 0x6
+ http2FrameGoAway http2FrameType = 0x7
+ http2FrameWindowUpdate http2FrameType = 0x8
+ http2FrameContinuation http2FrameType = 0x9
+)
+
+var http2frameName = map[http2FrameType]string{
+ http2FrameData: "DATA",
+ http2FrameHeaders: "HEADERS",
+ http2FramePriority: "PRIORITY",
+ http2FrameRSTStream: "RST_STREAM",
+ http2FrameSettings: "SETTINGS",
+ http2FramePushPromise: "PUSH_PROMISE",
+ http2FramePing: "PING",
+ http2FrameGoAway: "GOAWAY",
+ http2FrameWindowUpdate: "WINDOW_UPDATE",
+ http2FrameContinuation: "CONTINUATION",
+}
+
+func (t http2FrameType) String() string {
+ if s, ok := http2frameName[t]; ok {
+ return s
+ }
+ return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
+}
+
+// Flags is a bitmask of HTTP/2 flags.
+// The meaning of flags varies depending on the frame type.
+type http2Flags uint8
+
+// Has reports whether f contains all (0 or more) flags in v.
+func (f http2Flags) Has(v http2Flags) bool {
+ return (f & v) == v
+}
+
+// Frame-specific FrameHeader flag bits.
+const (
+ // Data Frame
+ http2FlagDataEndStream http2Flags = 0x1
+ http2FlagDataPadded http2Flags = 0x8
+
+ // Headers Frame
+ http2FlagHeadersEndStream http2Flags = 0x1
+ http2FlagHeadersEndHeaders http2Flags = 0x4
+ http2FlagHeadersPadded http2Flags = 0x8
+ http2FlagHeadersPriority http2Flags = 0x20
+
+ // Settings Frame
+ http2FlagSettingsAck http2Flags = 0x1
+
+ // Ping Frame
+ http2FlagPingAck http2Flags = 0x1
+
+ // Continuation Frame
+ http2FlagContinuationEndHeaders http2Flags = 0x4
+
+ http2FlagPushPromiseEndHeaders http2Flags = 0x4
+ http2FlagPushPromisePadded http2Flags = 0x8
+)
+
+var http2flagName = map[http2FrameType]map[http2Flags]string{
+ http2FrameData: {
+ http2FlagDataEndStream: "END_STREAM",
+ http2FlagDataPadded: "PADDED",
+ },
+ http2FrameHeaders: {
+ http2FlagHeadersEndStream: "END_STREAM",
+ http2FlagHeadersEndHeaders: "END_HEADERS",
+ http2FlagHeadersPadded: "PADDED",
+ http2FlagHeadersPriority: "PRIORITY",
+ },
+ http2FrameSettings: {
+ http2FlagSettingsAck: "ACK",
+ },
+ http2FramePing: {
+ http2FlagPingAck: "ACK",
+ },
+ http2FrameContinuation: {
+ http2FlagContinuationEndHeaders: "END_HEADERS",
+ },
+ http2FramePushPromise: {
+ http2FlagPushPromiseEndHeaders: "END_HEADERS",
+ http2FlagPushPromisePadded: "PADDED",
+ },
+}
+
+// a frameParser parses a frame given its FrameHeader and payload
+// bytes. The length of payload will always equal fh.Length (which
+// might be 0).
+type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error)
+
+var http2frameParsers = map[http2FrameType]http2frameParser{
+ http2FrameData: http2parseDataFrame,
+ http2FrameHeaders: http2parseHeadersFrame,
+ http2FramePriority: http2parsePriorityFrame,
+ http2FrameRSTStream: http2parseRSTStreamFrame,
+ http2FrameSettings: http2parseSettingsFrame,
+ http2FramePushPromise: http2parsePushPromise,
+ http2FramePing: http2parsePingFrame,
+ http2FrameGoAway: http2parseGoAwayFrame,
+ http2FrameWindowUpdate: http2parseWindowUpdateFrame,
+ http2FrameContinuation: http2parseContinuationFrame,
+}
+
+func http2typeFrameParser(t http2FrameType) http2frameParser {
+ if f := http2frameParsers[t]; f != nil {
+ return f
+ }
+ return http2parseUnknownFrame
+}
+
+// A FrameHeader is the 9 byte header of all HTTP/2 frames.
+//
+// See https://httpwg.org/specs/rfc7540.html#FrameHeader
+type http2FrameHeader struct {
+ valid bool // caller can access []byte fields in the Frame
+
+ // Type is the 1 byte frame type. There are ten standard frame
+ // types, but extension frame types may be written by WriteRawFrame
+ // and will be returned by ReadFrame (as UnknownFrame).
+ Type http2FrameType
+
+ // Flags are the 1 byte of 8 potential bit flags per frame.
+ // They are specific to the frame type.
+ Flags http2Flags
+
+ // Length is the length of the frame, not including the 9 byte header.
+ // The maximum size is one byte less than 16MB (uint24), but only
+ // frames up to 16KB are allowed without peer agreement.
+ Length uint32
+
+ // StreamID is which stream this frame is for. Certain frames
+ // are not stream-specific, in which case this field is 0.
+ StreamID uint32
+}
+
+// Header returns h. It exists so FrameHeaders can be embedded in other
+// specific frame types and implement the Frame interface.
+func (h http2FrameHeader) Header() http2FrameHeader { return h }
+
+func (h http2FrameHeader) String() string {
+ var buf bytes.Buffer
+ buf.WriteString("[FrameHeader ")
+ h.writeDebug(&buf)
+ buf.WriteByte(']')
+ return buf.String()
+}
+
+func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) {
+ buf.WriteString(h.Type.String())
+ if h.Flags != 0 {
+ buf.WriteString(" flags=")
+ set := 0
+ for i := uint8(0); i < 8; i++ {
+ if h.Flags&(1<<i) == 0 {
+ continue
+ }
+ set++
+ if set > 1 {
+ buf.WriteByte('|')
+ }
+ name := http2flagName[h.Type][http2Flags(1<<i)]
+ if name != "" {
+ buf.WriteString(name)
+ } else {
+ fmt.Fprintf(buf, "0x%x", 1<<i)
+ }
+ }
+ }
+ if h.StreamID != 0 {
+ fmt.Fprintf(buf, " stream=%d", h.StreamID)
+ }
+ fmt.Fprintf(buf, " len=%d", h.Length)
+}
+
+func (h *http2FrameHeader) checkValid() {
+ if !h.valid {
+ panic("Frame accessor called on non-owned Frame")
+ }
+}
+
+func (h *http2FrameHeader) invalidate() { h.valid = false }
+
+// frame header bytes.
+// Used only by ReadFrameHeader.
+var http2fhBytes = sync.Pool{
+ New: func() interface{} {
+ buf := make([]byte, http2frameHeaderLen)
+ return &buf
+ },
+}
+
+// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader.
+// Most users should use Framer.ReadFrame instead.
+func http2ReadFrameHeader(r io.Reader) (http2FrameHeader, error) {
+ bufp := http2fhBytes.Get().(*[]byte)
+ defer http2fhBytes.Put(bufp)
+ return http2readFrameHeader(*bufp, r)
+}
+
+func http2readFrameHeader(buf []byte, r io.Reader) (http2FrameHeader, error) {
+ _, err := io.ReadFull(r, buf[:http2frameHeaderLen])
+ if err != nil {
+ return http2FrameHeader{}, err
+ }
+ return http2FrameHeader{
+ Length: (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])),
+ Type: http2FrameType(buf[3]),
+ Flags: http2Flags(buf[4]),
+ StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
+ valid: true,
+ }, nil
+}
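+
+// A worked example of the decoding above: the 9 header bytes of a
+// SETTINGS ACK frame on stream 0 (hdr is a name used only in this sketch):
+//
+//	hdr := []byte{0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x00, 0x00}
+//	fh, _ := http2readFrameHeader(make([]byte, http2frameHeaderLen), bytes.NewReader(hdr))
+//	// fh.Length == 0, fh.Type == http2FrameSettings,
+//	// fh.Flags.Has(http2FlagSettingsAck) == true, fh.StreamID == 0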
+
+// A Frame is the base interface implemented by all frame types.
+// Callers will generally type-assert the specific frame type:
+// *HeadersFrame, *SettingsFrame, *WindowUpdateFrame, etc.
+//
+// Frames are only valid until the next call to Framer.ReadFrame.
+type http2Frame interface {
+ Header() http2FrameHeader
+
+ // invalidate is called by Framer.ReadFrame to mark this
+ // frame's buffers as invalid, since the subsequent
+ // frame will reuse them.
+ invalidate()
+}
+
+// A Framer reads and writes Frames.
+type http2Framer struct {
+ r io.Reader
+ lastFrame http2Frame
+ errDetail error
+
+ // countError is a non-nil func that's called on a frame parse
+ // error with some unique error path token. It's initialized
+ // from Transport.CountError or Server.CountError.
+ countError func(errToken string)
+
+ // lastHeaderStream is non-zero if the last frame was an
+ // unfinished HEADERS/CONTINUATION.
+ lastHeaderStream uint32
+
+ maxReadSize uint32
+ headerBuf [http2frameHeaderLen]byte
+
+ // TODO: let getReadBuf be configurable, and use a less memory-pinning
+ // allocator in server.go to minimize memory pinned for many idle conns.
+ // Will probably also need to make frame invalidation have a hook too.
+ getReadBuf func(size uint32) []byte
+ readBuf []byte // cache for default getReadBuf
+
+ maxWriteSize uint32 // zero means unlimited; TODO: implement
+
+ w io.Writer
+ wbuf []byte
+
+ // AllowIllegalWrites permits the Framer's Write methods to
+ // write frames that do not conform to the HTTP/2 spec. This
+ // permits using the Framer to test other HTTP/2
+ // implementations' conformance to the spec.
+ // If false, the Write methods will prefer to return an error
+ // rather than write a non-conforming frame.
+ AllowIllegalWrites bool
+
+ // AllowIllegalReads permits the Framer's ReadFrame method
+ // to return non-compliant frames or frame orders.
+ // This is for testing and permits using the Framer to test
+ // other HTTP/2 implementations' conformance to the spec.
+ // It is not compatible with ReadMetaHeaders.
+ AllowIllegalReads bool
+
+ // ReadMetaHeaders if non-nil causes ReadFrame to merge
+ // HEADERS and CONTINUATION frames together and return
+ // MetaHeadersFrame instead.
+ ReadMetaHeaders *hpack.Decoder
+
+ // MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
+ // It's used only if ReadMetaHeaders is set; 0 means a sane default
+ // (currently 16MB).
+ // If the limit is hit, MetaHeadersFrame.Truncated is set true.
+ MaxHeaderListSize uint32
+
+ // TODO: track which type of frame & with which flags was sent
+ // last. Then return an error (unless AllowIllegalWrites) if
+ // we're in the middle of a header block and a
+ // non-Continuation or Continuation on a different stream is
+ // attempted to be written.
+
+ logReads, logWrites bool
+
+ debugFramer *http2Framer // only used for logging written frames
+ debugFramerBuf *bytes.Buffer
+ debugReadLoggerf func(string, ...interface{})
+ debugWriteLoggerf func(string, ...interface{})
+
+ frameCache *http2frameCache // nil if frames aren't reused (default)
+}
+
+func (fr *http2Framer) maxHeaderListSize() uint32 {
+ if fr.MaxHeaderListSize == 0 {
+ return 16 << 20 // sane default, per docs
+ }
+ return fr.MaxHeaderListSize
+}
+
+func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) {
+ // Write the FrameHeader.
+ f.wbuf = append(f.wbuf[:0],
+ 0, // 3 bytes of length, filled in by endWrite
+ 0,
+ 0,
+ byte(ftype),
+ byte(flags),
+ byte(streamID>>24),
+ byte(streamID>>16),
+ byte(streamID>>8),
+ byte(streamID))
+}
+
+func (f *http2Framer) endWrite() error {
+ // Now that we know the final size, fill in the FrameHeader in
+ // the space previously reserved for it. Abuse append.
+ length := len(f.wbuf) - http2frameHeaderLen
+ if length >= (1 << 24) {
+ return http2ErrFrameTooLarge
+ }
+ _ = append(f.wbuf[:0],
+ byte(length>>16),
+ byte(length>>8),
+ byte(length))
+ if f.logWrites {
+ f.logWrite()
+ }
+
+ n, err := f.w.Write(f.wbuf)
+ if err == nil && n != len(f.wbuf) {
+ err = io.ErrShortWrite
+ }
+ return err
+}
+
+func (f *http2Framer) logWrite() {
+ if f.debugFramer == nil {
+ f.debugFramerBuf = new(bytes.Buffer)
+ f.debugFramer = http2NewFramer(nil, f.debugFramerBuf)
+ f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below
+ // Let us read anything, even if we accidentally wrote it
+ // in the wrong order:
+ f.debugFramer.AllowIllegalReads = true
+ }
+ f.debugFramerBuf.Write(f.wbuf)
+ fr, err := f.debugFramer.ReadFrame()
+ if err != nil {
+ f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f)
+ return
+ }
+ f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr))
+}
+
+func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
+
+func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) }
+
+func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) }
+
+func (f *http2Framer) writeUint32(v uint32) {
+ f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
+
+const (
+ http2minMaxFrameSize = 1 << 14
+ http2maxFrameSize = 1<<24 - 1
+)
+
+// SetReuseFrames allows the Framer to reuse Frames.
+// If called on a Framer, Frames returned by calls to ReadFrame are only
+// valid until the next call to ReadFrame.
+func (fr *http2Framer) SetReuseFrames() {
+ if fr.frameCache != nil {
+ return
+ }
+ fr.frameCache = &http2frameCache{}
+}
+
+type http2frameCache struct {
+ dataFrame http2DataFrame
+}
+
+func (fc *http2frameCache) getDataFrame() *http2DataFrame {
+ if fc == nil {
+ return &http2DataFrame{}
+ }
+ return &fc.dataFrame
+}
+
+// NewFramer returns a Framer that writes frames to w and reads them from r.
+func http2NewFramer(w io.Writer, r io.Reader) *http2Framer {
+ fr := &http2Framer{
+ w: w,
+ r: r,
+ countError: func(string) {},
+ logReads: http2logFrameReads,
+ logWrites: http2logFrameWrites,
+ debugReadLoggerf: log.Printf,
+ debugWriteLoggerf: log.Printf,
+ }
+ fr.getReadBuf = func(size uint32) []byte {
+ if cap(fr.readBuf) >= int(size) {
+ return fr.readBuf[:size]
+ }
+ fr.readBuf = make([]byte, size)
+ return fr.readBuf
+ }
+ fr.SetMaxReadFrameSize(http2maxFrameSize)
+ return fr
+}
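+
+// A minimal read-loop sketch for the Framer, assuming rw is a net.Conn
+// carrying an already-established HTTP/2 session (connection preface and
+// initial SETTINGS exchange omitted):
+//
+//	fr := http2NewFramer(rw, rw)
+//	fr.SetReuseFrames()
+//	for {
+//		f, err := fr.ReadFrame()
+//		if err != nil {
+//			// ErrFrameTooLarge, a ConnectionError/StreamError, or an I/O error.
+//			break
+//		}
+//		switch f := f.(type) {
+//		case *http2SettingsFrame:
+//			if !f.IsAck() {
+//				fr.WriteSettingsAck()
+//			}
+//		case *http2DataFrame:
+//			_ = f.Data() // valid only until the next ReadFrame
+//		}
+//	}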
+
+// SetMaxReadFrameSize sets the maximum size of a frame
+// that will be read by a subsequent call to ReadFrame.
+// It is the caller's responsibility to advertise this
+// limit with a SETTINGS frame.
+func (fr *http2Framer) SetMaxReadFrameSize(v uint32) {
+ if v > http2maxFrameSize {
+ v = http2maxFrameSize
+ }
+ fr.maxReadSize = v
+}
+
+// ErrorDetail returns a more detailed error of the last error
+// returned by Framer.ReadFrame. For instance, if ReadFrame
+// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail
+// will say exactly what was invalid. ErrorDetail is not guaranteed
+// to return a non-nil value and like the rest of the http2 package,
+// its return value is not protected by an API compatibility promise.
+// ErrorDetail is reset after the next call to ReadFrame.
+func (fr *http2Framer) ErrorDetail() error {
+ return fr.errDetail
+}
+
+// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer
+// sends a frame that is larger than declared with SetMaxReadFrameSize.
+var http2ErrFrameTooLarge = errors.New("http2: frame too large")
+
+// terminalReadFrameError reports whether err is an unrecoverable
+// error from ReadFrame and no other frames should be read.
+func http2terminalReadFrameError(err error) bool {
+ if _, ok := err.(http2StreamError); ok {
+ return false
+ }
+ return err != nil
+}
+
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame.
+//
+// If the frame is larger than previously set with SetMaxReadFrameSize, the
+// returned error is ErrFrameTooLarge. Other errors may be of type
+// ConnectionError, StreamError, or anything else from the underlying
+// reader.
+func (fr *http2Framer) ReadFrame() (http2Frame, error) {
+ fr.errDetail = nil
+ if fr.lastFrame != nil {
+ fr.lastFrame.invalidate()
+ }
+ fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r)
+ if err != nil {
+ return nil, err
+ }
+ if fh.Length > fr.maxReadSize {
+ return nil, http2ErrFrameTooLarge
+ }
+ payload := fr.getReadBuf(fh.Length)
+ if _, err := io.ReadFull(fr.r, payload); err != nil {
+ return nil, err
+ }
+ f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload)
+ if err != nil {
+ if ce, ok := err.(http2connError); ok {
+ return nil, fr.connError(ce.Code, ce.Reason)
+ }
+ return nil, err
+ }
+ if err := fr.checkFrameOrder(f); err != nil {
+ return nil, err
+ }
+ if fr.logReads {
+ fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
+ }
+ if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil {
+ return fr.readMetaFrame(f.(*http2HeadersFrame))
+ }
+ return f, nil
+}
+
+// connError returns ConnectionError(code) but first
+// stashes away a public reason so the caller can optionally relay it
+// to the peer before hanging up on them. This might help others debug
+// their implementations.
+func (fr *http2Framer) connError(code http2ErrCode, reason string) error {
+ fr.errDetail = errors.New(reason)
+ return http2ConnectionError(code)
+}
+
+// checkFrameOrder reports an error if f is an invalid frame to return
+// next from ReadFrame. Mostly it checks whether HEADERS and
+// CONTINUATION frames are contiguous.
+func (fr *http2Framer) checkFrameOrder(f http2Frame) error {
+ last := fr.lastFrame
+ fr.lastFrame = f
+ if fr.AllowIllegalReads {
+ return nil
+ }
+
+ fh := f.Header()
+ if fr.lastHeaderStream != 0 {
+ if fh.Type != http2FrameContinuation {
+ return fr.connError(http2ErrCodeProtocol,
+ fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
+ fh.Type, fh.StreamID,
+ last.Header().Type, fr.lastHeaderStream))
+ }
+ if fh.StreamID != fr.lastHeaderStream {
+ return fr.connError(http2ErrCodeProtocol,
+ fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
+ fh.StreamID, fr.lastHeaderStream))
+ }
+ } else if fh.Type == http2FrameContinuation {
+ return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
+ }
+
+ switch fh.Type {
+ case http2FrameHeaders, http2FrameContinuation:
+ if fh.Flags.Has(http2FlagHeadersEndHeaders) {
+ fr.lastHeaderStream = 0
+ } else {
+ fr.lastHeaderStream = fh.StreamID
+ }
+ }
+
+ return nil
+}
+
+// A DataFrame conveys arbitrary, variable-length sequences of octets
+// associated with a stream.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.1
+type http2DataFrame struct {
+ http2FrameHeader
+ data []byte
+}
+
+func (f *http2DataFrame) StreamEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream)
+}
+
+// Data returns the frame's data octets, not including any padding
+// size byte or padding suffix bytes.
+// The caller must not retain the returned memory past the next
+// call to ReadFrame.
+func (f *http2DataFrame) Data() []byte {
+ f.checkValid()
+ return f.data
+}
+
+func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+ // DATA frames MUST be associated with a stream. If a
+ // DATA frame is received whose stream identifier
+ // field is 0x0, the recipient MUST respond with a
+ // connection error (Section 5.4.1) of type
+ // PROTOCOL_ERROR.
+ countError("frame_data_stream_0")
+ return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"}
+ }
+ f := fc.getDataFrame()
+ f.http2FrameHeader = fh
+
+ var padSize byte
+ if fh.Flags.Has(http2FlagDataPadded) {
+ var err error
+ payload, padSize, err = http2readByte(payload)
+ if err != nil {
+ countError("frame_data_pad_byte_short")
+ return nil, err
+ }
+ }
+ if int(padSize) > len(payload) {
+ // If the length of the padding is greater than the
+ // length of the frame payload, the recipient MUST
+ // treat this as a connection error.
+ // Filed: https://github.com/http2/http2-spec/issues/610
+ countError("frame_data_pad_too_big")
+ return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"}
+ }
+ f.data = payload[:len(payload)-int(padSize)]
+ return f, nil
+}
+
+var (
+ http2errStreamID = errors.New("invalid stream ID")
+ http2errDepStreamID = errors.New("invalid dependent stream ID")
+ http2errPadLength = errors.New("pad length too large")
+ http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled")
+)
+
+func http2validStreamIDOrZero(streamID uint32) bool {
+ return streamID&(1<<31) == 0
+}
+
+func http2validStreamID(streamID uint32) bool {
+ return streamID != 0 && streamID&(1<<31) == 0
+}
+
+// WriteData writes a DATA frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility not to violate the maximum frame size
+// and to not call other Write methods concurrently.
+func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error {
+ return f.WriteDataPadded(streamID, endStream, data, nil)
+}
+
+// WriteDataPadded writes a DATA frame with optional padding.
+//
+// If pad is nil, the padding bit is not sent.
+// The length of pad must not exceed 255 bytes.
+// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility not to violate the maximum frame size
+// and to not call other Write methods concurrently.
+func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error {
+ if err := f.startWriteDataPadded(streamID, endStream, data, pad); err != nil {
+ return err
+ }
+ return f.endWrite()
+}
+
+// startWriteDataPadded is WriteDataPadded, but only writes the frame to the Framer's internal buffer.
+// The caller should call endWrite to flush the frame to the underlying writer.
+func (f *http2Framer) startWriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ if len(pad) > 0 {
+ if len(pad) > 255 {
+ return http2errPadLength
+ }
+ if !f.AllowIllegalWrites {
+ for _, b := range pad {
+ if b != 0 {
+ // "Padding octets MUST be set to zero when sending."
+ return http2errPadBytes
+ }
+ }
+ }
+ }
+ var flags http2Flags
+ if endStream {
+ flags |= http2FlagDataEndStream
+ }
+ if pad != nil {
+ flags |= http2FlagDataPadded
+ }
+ f.startWrite(http2FrameData, flags, streamID)
+ if pad != nil {
+ f.wbuf = append(f.wbuf, byte(len(pad)))
+ }
+ f.wbuf = append(f.wbuf, data...)
+ f.wbuf = append(f.wbuf, pad...)
+ return nil
+}
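+
+// A small sketch of the bytes produced above, assuming buf is a
+// bytes.Buffer standing in for the connection:
+//
+//	fr := http2NewFramer(&buf, nil)
+//	_ = fr.WriteDataPadded(1, false, []byte("hi"), make([]byte, 3))
+//	// buf now holds a 9-byte header (Length=6, Type=DATA, Flags=PADDED,
+//	// StreamID=1) followed by the payload: 0x03, 'h', 'i', 0x00, 0x00, 0x00.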
+
+// A SettingsFrame conveys configuration parameters that affect how
+// endpoints communicate, such as preferences and constraints on peer
+// behavior.
+//
+// See https://httpwg.org/specs/rfc7540.html#SETTINGS
+type http2SettingsFrame struct {
+ http2FrameHeader
+ p []byte
+}
+
+func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 {
+ // When this (ACK 0x1) bit is set, the payload of the
+ // SETTINGS frame MUST be empty. Receipt of a
+ // SETTINGS frame with the ACK flag set and a length
+ // field value other than 0 MUST be treated as a
+ // connection error (Section 5.4.1) of type
+ // FRAME_SIZE_ERROR.
+ countError("frame_settings_ack_with_length")
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID != 0 {
+ // SETTINGS frames always apply to a connection,
+ // never a single stream. The stream identifier for a
+ // SETTINGS frame MUST be zero (0x0). If an endpoint
+ // receives a SETTINGS frame whose stream identifier
+ // field is anything other than 0x0, the endpoint MUST
+ // respond with a connection error (Section 5.4.1) of
+ // type PROTOCOL_ERROR.
+ countError("frame_settings_has_stream")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if len(p)%6 != 0 {
+ countError("frame_settings_mod_6")
+ // Expecting a whole number of 6-byte settings.
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ f := &http2SettingsFrame{http2FrameHeader: fh, p: p}
+ if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 {
+ countError("frame_settings_window_size_too_big")
+ // Values above the maximum flow control window size of 2^31 - 1 MUST
+ // be treated as a connection error (Section 5.4.1) of type
+ // FLOW_CONTROL_ERROR.
+ return nil, http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ return f, nil
+}
+
+func (f *http2SettingsFrame) IsAck() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck)
+}
+
+func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) {
+ f.checkValid()
+ for i := 0; i < f.NumSettings(); i++ {
+ if s := f.Setting(i); s.ID == id {
+ return s.Val, true
+ }
+ }
+ return 0, false
+}
+
+// Setting returns the setting from the frame at the given 0-based index.
+// The index must be >= 0 and less than f.NumSettings().
+func (f *http2SettingsFrame) Setting(i int) http2Setting {
+ buf := f.p
+ return http2Setting{
+ ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])),
+ Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]),
+ }
+}
+
+func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 }
+
+// HasDuplicates reports whether f contains any duplicate setting IDs.
+func (f *http2SettingsFrame) HasDuplicates() bool {
+ num := f.NumSettings()
+ if num == 0 {
+ return false
+ }
+ // If it's small enough (the common case), just do the n^2
+ // thing and avoid a map allocation.
+ if num < 10 {
+ for i := 0; i < num; i++ {
+ idi := f.Setting(i).ID
+ for j := i + 1; j < num; j++ {
+ idj := f.Setting(j).ID
+ if idi == idj {
+ return true
+ }
+ }
+ }
+ return false
+ }
+ seen := map[http2SettingID]bool{}
+ for i := 0; i < num; i++ {
+ id := f.Setting(i).ID
+ if seen[id] {
+ return true
+ }
+ seen[id] = true
+ }
+ return false
+}
+
+// ForeachSetting runs fn for each setting.
+// It stops and returns the first error.
+func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error {
+ f.checkValid()
+ for i := 0; i < f.NumSettings(); i++ {
+ if err := fn(f.Setting(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
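+
+// A short sketch of inspecting a received SETTINGS frame with the
+// accessors above, assuming sf came from ReadFrame:
+//
+//	if v, ok := sf.Value(http2SettingInitialWindowSize); ok {
+//		// the peer's SETTINGS_INITIAL_WINDOW_SIZE is v
+//	}
+//	_ = sf.ForeachSetting(func(s http2Setting) error {
+//		// s.ID and s.Val for each 6-byte entry, in wire order
+//		return nil
+//	})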
+
+// WriteSettings writes a SETTINGS frame with zero or more settings
+// specified and the ACK bit not set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteSettings(settings ...http2Setting) error {
+ f.startWrite(http2FrameSettings, 0, 0)
+ for _, s := range settings {
+ f.writeUint16(uint16(s.ID))
+ f.writeUint32(s.Val)
+ }
+ return f.endWrite()
+}
+
+// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteSettingsAck() error {
+ f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0)
+ return f.endWrite()
+}
+
+// A PingFrame is a mechanism for measuring a minimal round trip time
+// from the sender, as well as determining whether an idle connection
+// is still functional.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.7
+type http2PingFrame struct {
+ http2FrameHeader
+ Data [8]byte
+}
+
+func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) }
+
+func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
+ if len(payload) != 8 {
+ countError("frame_ping_length")
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID != 0 {
+ countError("frame_ping_has_stream")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ f := &http2PingFrame{http2FrameHeader: fh}
+ copy(f.Data[:], payload)
+ return f, nil
+}
+
+func (f *http2Framer) WritePing(ack bool, data [8]byte) error {
+ var flags http2Flags
+ if ack {
+ flags = http2FlagPingAck
+ }
+ f.startWrite(http2FramePing, flags, 0)
+ f.writeBytes(data[:])
+ return f.endWrite()
+}
+
+// A GoAwayFrame informs the remote peer to stop creating streams on this connection.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.8
+type http2GoAwayFrame struct {
+ http2FrameHeader
+ LastStreamID uint32
+ ErrCode http2ErrCode
+ debugData []byte
+}
+
+// DebugData returns any debug data in the GOAWAY frame. Its contents
+// are not defined.
+// The caller must not retain the returned memory past the next
+// call to ReadFrame.
+func (f *http2GoAwayFrame) DebugData() []byte {
+ f.checkValid()
+ return f.debugData
+}
+
+func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ if fh.StreamID != 0 {
+ countError("frame_goaway_has_stream")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if len(p) < 8 {
+ countError("frame_goaway_short")
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ return &http2GoAwayFrame{
+ http2FrameHeader: fh,
+ LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1),
+ ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])),
+ debugData: p[8:],
+ }, nil
+}
+
+func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error {
+ f.startWrite(http2FrameGoAway, 0, 0)
+ f.writeUint32(maxStreamID & (1<<31 - 1))
+ f.writeUint32(uint32(code))
+ f.writeBytes(debugData)
+ return f.endWrite()
+}
+
+// An UnknownFrame is the frame type returned when the frame type is unknown
+// or no specific frame type parser exists.
+type http2UnknownFrame struct {
+ http2FrameHeader
+ p []byte
+}
+
+// Payload returns the frame's payload (after the header). It is not
+// valid to call this method after a subsequent call to
+// Framer.ReadFrame, nor is it valid to retain the returned slice.
+// The memory is owned by the Framer and is invalidated when the next
+// frame is read.
+func (f *http2UnknownFrame) Payload() []byte {
+ f.checkValid()
+ return f.p
+}
+
+func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ return &http2UnknownFrame{fh, p}, nil
+}
+
+// A WindowUpdateFrame is used to implement flow control.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.9
+type http2WindowUpdateFrame struct {
+ http2FrameHeader
+ Increment uint32 // never read with high bit set
+}
+
+func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ if len(p) != 4 {
+ countError("frame_windowupdate_bad_len")
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit
+ if inc == 0 {
+ // A receiver MUST treat the receipt of a
+ // WINDOW_UPDATE frame with a flow control window
+ // increment of 0 as a stream error (Section 5.4.2) of
+ // type PROTOCOL_ERROR; errors on the connection flow
+ // control window MUST be treated as a connection
+ // error (Section 5.4.1).
+ if fh.StreamID == 0 {
+ countError("frame_windowupdate_zero_inc_conn")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ countError("frame_windowupdate_zero_inc_stream")
+ return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
+ }
+ return &http2WindowUpdateFrame{
+ http2FrameHeader: fh,
+ Increment: inc,
+ }, nil
+}
+
+// WriteWindowUpdate writes a WINDOW_UPDATE frame.
+// The increment value must be between 1 and 2,147,483,647, inclusive.
+// If the Stream ID is zero, the window update applies to the
+// connection as a whole.
+func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error {
+ // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets."
+ if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites {
+ return errors.New("illegal window increment value")
+ }
+ f.startWrite(http2FrameWindowUpdate, 0, streamID)
+ f.writeUint32(incr)
+ return f.endWrite()
+}
+
+// A HeadersFrame is used to open a stream and additionally carries a
+// header block fragment.
+type http2HeadersFrame struct {
+ http2FrameHeader
+
+ // Priority is set if FlagHeadersPriority is set in the FrameHeader.
+ Priority http2PriorityParam
+
+ headerFragBuf []byte // not owned
+}
+
+func (f *http2HeadersFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2HeadersFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders)
+}
+
+func (f *http2HeadersFrame) StreamEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream)
+}
+
+func (f *http2HeadersFrame) HasPriority() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority)
+}
+
+func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
+ hf := &http2HeadersFrame{
+ http2FrameHeader: fh,
+ }
+ if fh.StreamID == 0 {
+ // HEADERS frames MUST be associated with a stream. If a HEADERS frame
+ // is received whose stream identifier field is 0x0, the recipient MUST
+ // respond with a connection error (Section 5.4.1) of type
+ // PROTOCOL_ERROR.
+ countError("frame_headers_zero_stream")
+ return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"}
+ }
+ var padLength uint8
+ if fh.Flags.Has(http2FlagHeadersPadded) {
+ if p, padLength, err = http2readByte(p); err != nil {
+ countError("frame_headers_pad_short")
+ return
+ }
+ }
+ if fh.Flags.Has(http2FlagHeadersPriority) {
+ var v uint32
+ p, v, err = http2readUint32(p)
+ if err != nil {
+ countError("frame_headers_prio_short")
+ return nil, err
+ }
+ hf.Priority.StreamDep = v & 0x7fffffff
+ hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set
+ p, hf.Priority.Weight, err = http2readByte(p)
+ if err != nil {
+ countError("frame_headers_prio_weight_short")
+ return nil, err
+ }
+ }
+ if len(p)-int(padLength) < 0 {
+ countError("frame_headers_pad_too_big")
+ return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol)
+ }
+ hf.headerFragBuf = p[:len(p)-int(padLength)]
+ return hf, nil
+}
+
+// HeadersFrameParam are the parameters for writing a HEADERS frame.
+type http2HeadersFrameParam struct {
+ // StreamID is the required Stream ID to initiate.
+ StreamID uint32
+ // BlockFragment is part (or all) of a Header Block.
+ BlockFragment []byte
+
+ // EndStream indicates that the header block is the last that
+ // the endpoint will send for the identified stream. Setting
+ // this flag causes the stream to enter one of "half closed"
+ // states.
+ EndStream bool
+
+ // EndHeaders indicates that this frame contains an entire
+ // header block and is not followed by any
+ // CONTINUATION frames.
+ EndHeaders bool
+
+ // PadLength is the optional number of bytes of zeros to add
+ // to this frame.
+ PadLength uint8
+
+ // Priority, if non-zero, includes stream priority information
+ // in the HEADERS frame.
+ Priority http2PriorityParam
+}
+
+// WriteHeaders writes a single HEADERS frame.
+//
+// This is a low-level header writing method. Encoding headers and
+// splitting them into any necessary CONTINUATION frames is handled
+// elsewhere.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error {
+ if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if p.PadLength != 0 {
+ flags |= http2FlagHeadersPadded
+ }
+ if p.EndStream {
+ flags |= http2FlagHeadersEndStream
+ }
+ if p.EndHeaders {
+ flags |= http2FlagHeadersEndHeaders
+ }
+ if !p.Priority.IsZero() {
+ flags |= http2FlagHeadersPriority
+ }
+ f.startWrite(http2FrameHeaders, flags, p.StreamID)
+ if p.PadLength != 0 {
+ f.writeByte(p.PadLength)
+ }
+ if !p.Priority.IsZero() {
+ v := p.Priority.StreamDep
+ if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites {
+ return http2errDepStreamID
+ }
+ if p.Priority.Exclusive {
+ v |= 1 << 31
+ }
+ f.writeUint32(v)
+ f.writeByte(p.Priority.Weight)
+ }
+ f.wbuf = append(f.wbuf, p.BlockFragment...)
+ f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
+ return f.endWrite()
+}
+
+// A PriorityFrame specifies the sender-advised priority of a stream.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.3
+type http2PriorityFrame struct {
+ http2FrameHeader
+ http2PriorityParam
+}
+
+// PriorityParam are the stream prioritization parameters.
+type http2PriorityParam struct {
+ // StreamDep is a 31-bit stream identifier for the
+ // stream that this stream depends on. Zero means no
+ // dependency.
+ StreamDep uint32
+
+ // Exclusive is whether the dependency is exclusive.
+ Exclusive bool
+
+ // Weight is the stream's zero-indexed weight. It should be
+ // set together with StreamDep, or neither should be set. Per
+ // the spec, "Add one to the value to obtain a weight between
+ // 1 and 256."
+ Weight uint8
+}
+
+func (p http2PriorityParam) IsZero() bool {
+ return p == http2PriorityParam{}
+}
+
+func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+ countError("frame_priority_zero_stream")
+ return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
+ }
+ if len(payload) != 5 {
+ countError("frame_priority_bad_length")
+ return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
+ }
+ v := binary.BigEndian.Uint32(payload[:4])
+ streamID := v & 0x7fffffff // mask off high bit
+ return &http2PriorityFrame{
+ http2FrameHeader: fh,
+ http2PriorityParam: http2PriorityParam{
+ Weight: payload[4],
+ StreamDep: streamID,
+ Exclusive: streamID != v, // was high bit set?
+ },
+ }, nil
+}
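+
+// A worked example of the 5-byte payload decoded above: an exclusive
+// dependency on stream 3 with wire weight 15 (effective weight 16):
+//
+//	payload := []byte{0x80, 0x00, 0x00, 0x03, 0x0f}
+//	// high bit of the first 4 bytes -> Exclusive = true
+//	// remaining 31 bits             -> StreamDep = 3
+//	// final byte                    -> Weight    = 15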
+
+// WritePriority writes a PRIORITY frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ if !http2validStreamIDOrZero(p.StreamDep) {
+ return http2errDepStreamID
+ }
+ f.startWrite(http2FramePriority, 0, streamID)
+ v := p.StreamDep
+ if p.Exclusive {
+ v |= 1 << 31
+ }
+ f.writeUint32(v)
+ f.writeByte(p.Weight)
+ return f.endWrite()
+}
+
+// An RSTStreamFrame allows for abnormal termination of a stream.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.4
+type http2RSTStreamFrame struct {
+ http2FrameHeader
+ ErrCode http2ErrCode
+}
+
+func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ if len(p) != 4 {
+ countError("frame_rststream_bad_len")
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID == 0 {
+ countError("frame_rststream_zero_stream")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
+}
+
+// WriteRSTStream writes a RST_STREAM frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ f.startWrite(http2FrameRSTStream, 0, streamID)
+ f.writeUint32(uint32(code))
+ return f.endWrite()
+}
+
+// A ContinuationFrame is used to continue a sequence of header block fragments.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.10
+type http2ContinuationFrame struct {
+ http2FrameHeader
+ headerFragBuf []byte
+}
+
+func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+ countError("frame_continuation_zero_stream")
+ return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
+ }
+ return &http2ContinuationFrame{fh, p}, nil
+}
+
+func (f *http2ContinuationFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2ContinuationFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders)
+}
+
+// WriteContinuation writes a CONTINUATION frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if endHeaders {
+ flags |= http2FlagContinuationEndHeaders
+ }
+ f.startWrite(http2FrameContinuation, flags, streamID)
+ f.wbuf = append(f.wbuf, headerBlockFragment...)
+ return f.endWrite()
+}
+
+// A PushPromiseFrame is used to initiate a server stream.
+// See https://httpwg.org/specs/rfc7540.html#rfc.section.6.6
+type http2PushPromiseFrame struct {
+ http2FrameHeader
+ PromiseID uint32
+ headerFragBuf []byte // not owned
+}
+
+func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2PushPromiseFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders)
+}
+
+func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) {
+ pp := &http2PushPromiseFrame{
+ http2FrameHeader: fh,
+ }
+ if pp.StreamID == 0 {
+ // PUSH_PROMISE frames MUST be associated with an existing,
+ // peer-initiated stream. The stream identifier of a
+ // PUSH_PROMISE frame indicates the stream it is associated
+ // with. If the stream identifier field specifies the value
+ // 0x0, a recipient MUST respond with a connection error
+ // (Section 5.4.1) of type PROTOCOL_ERROR.
+ countError("frame_pushpromise_zero_stream")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ // The PUSH_PROMISE frame includes optional padding.
+ // Padding fields and flags are identical to those defined for DATA frames.
+ var padLength uint8
+ if fh.Flags.Has(http2FlagPushPromisePadded) {
+ if p, padLength, err = http2readByte(p); err != nil {
+ countError("frame_pushpromise_pad_short")
+ return
+ }
+ }
+
+ p, pp.PromiseID, err = http2readUint32(p)
+ if err != nil {
+ countError("frame_pushpromise_promiseid_short")
+ return
+ }
+ pp.PromiseID = pp.PromiseID & (1<<31 - 1)
+
+ if int(padLength) > len(p) {
+ // like the DATA frame, error out if padding is longer than the body.
+ countError("frame_pushpromise_pad_too_big")
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ pp.headerFragBuf = p[:len(p)-int(padLength)]
+ return pp, nil
+}
+
+// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame.
+type http2PushPromiseParam struct {
+ // StreamID is the required Stream ID to initiate.
+ StreamID uint32
+
+ // PromiseID is the required Stream ID being promised
+ // (reserved) by this PUSH_PROMISE frame.
+ PromiseID uint32
+
+ // BlockFragment is part (or all) of a Header Block.
+ BlockFragment []byte
+
+ // EndHeaders indicates that this frame contains an entire
+ // header block and is not followed by any
+ // CONTINUATION frames.
+ EndHeaders bool
+
+ // PadLength is the optional number of bytes of zeros to add
+ // to this frame.
+ PadLength uint8
+}
+
+// WritePushPromise writes a single PUSH_PROMISE frame.
+//
+// As with HEADERS frames, this is a low-level call for writing
+// individual frames. CONTINUATION frames are handled elsewhere.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error {
+ if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if p.PadLength != 0 {
+ flags |= http2FlagPushPromisePadded
+ }
+ if p.EndHeaders {
+ flags |= http2FlagPushPromiseEndHeaders
+ }
+ f.startWrite(http2FramePushPromise, flags, p.StreamID)
+ if p.PadLength != 0 {
+ f.writeByte(p.PadLength)
+ }
+ if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ f.writeUint32(p.PromiseID)
+ f.wbuf = append(f.wbuf, p.BlockFragment...)
+ f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
+ return f.endWrite()
+}
+
+// WriteRawFrame writes a raw frame. This can be used to write
+// extension frames unknown to this package.
+func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error {
+ f.startWrite(t, flags, streamID)
+ f.writeBytes(payload)
+ return f.endWrite()
+}
+
+func http2readByte(p []byte) (remain []byte, b byte, err error) {
+ if len(p) == 0 {
+ return nil, 0, io.ErrUnexpectedEOF
+ }
+ return p[1:], p[0], nil
+}
+
+func http2readUint32(p []byte) (remain []byte, v uint32, err error) {
+ if len(p) < 4 {
+ return nil, 0, io.ErrUnexpectedEOF
+ }
+ return p[4:], binary.BigEndian.Uint32(p[:4]), nil
+}
+
+type http2streamEnder interface {
+ StreamEnded() bool
+}
+
+type http2headersEnder interface {
+ HeadersEnded() bool
+}
+
+type http2headersOrContinuation interface {
+ http2headersEnder
+ HeaderBlockFragment() []byte
+}
+
+// A MetaHeadersFrame is the representation of one HEADERS frame and
+// zero or more contiguous CONTINUATION frames and the decoding of
+// their HPACK-encoded contents.
+//
+// This type of frame does not appear on the wire and is only returned
+// by the Framer when Framer.ReadMetaHeaders is set.
+type http2MetaHeadersFrame struct {
+ *http2HeadersFrame
+
+ // Fields are the fields contained in the HEADERS and
+ // CONTINUATION frames. The underlying slice is owned by the
+ // Framer and must not be retained after the next call to
+ // ReadFrame.
+ //
+ // Fields are guaranteed to be in the correct http2 order and
+ // not have unknown pseudo header fields or invalid header
+ // field names or values. Required pseudo header fields may be
+ // missing, however. Use the MetaHeadersFrame.PseudoValue and
+ // PseudoFields accessor methods to access pseudo headers.
+ Fields []hpack.HeaderField
+
+ // Truncated is whether the max header list size limit was hit
+ // and Fields is incomplete. The hpack decoder state is still
+ // valid, however.
+ Truncated bool
+}
+
+// PseudoValue returns the given pseudo header field's value.
+// The provided pseudo field should not contain the leading colon.
+func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string {
+ for _, hf := range mh.Fields {
+ if !hf.IsPseudo() {
+ return ""
+ }
+ if hf.Name[1:] == pseudo {
+ return hf.Value
+ }
+ }
+ return ""
+}
+
+// RegularFields returns the regular (non-pseudo) header fields of mh.
+// The caller does not own the returned slice.
+func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField {
+ for i, hf := range mh.Fields {
+ if !hf.IsPseudo() {
+ return mh.Fields[i:]
+ }
+ }
+ return nil
+}
+
+// PseudoFields returns the pseudo header fields of mh.
+// The caller does not own the returned slice.
+func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField {
+ for i, hf := range mh.Fields {
+ if !hf.IsPseudo() {
+ return mh.Fields[:i]
+ }
+ }
+ return mh.Fields
+}
+
+func (mh *http2MetaHeadersFrame) checkPseudos() error {
+ var isRequest, isResponse bool
+ pf := mh.PseudoFields()
+ for i, hf := range pf {
+ switch hf.Name {
+ case ":method", ":path", ":scheme", ":authority":
+ isRequest = true
+ case ":status":
+ isResponse = true
+ default:
+ return http2pseudoHeaderError(hf.Name)
+ }
+ // Check for duplicates.
+ // This would be a bad algorithm, but N is 4.
+ // And this doesn't allocate.
+ for _, hf2 := range pf[:i] {
+ if hf.Name == hf2.Name {
+ return http2duplicatePseudoHeaderError(hf.Name)
+ }
+ }
+ }
+ if isRequest && isResponse {
+ return http2errMixPseudoHeaderTypes
+ }
+ return nil
+}
+
+func (fr *http2Framer) maxHeaderStringLen() int {
+ v := fr.maxHeaderListSize()
+ if uint32(int(v)) == v {
+ return int(v)
+ }
+ // They had a crazy big number for MaxHeaderListSize anyway,
+ // so give them unlimited header lengths:
+ return 0
+}
+
+// readMetaFrame reads 0 or more CONTINUATION frames from fr,
+// merges them into the provided hf, and returns a MetaHeadersFrame
+// with the decoded hpack values.
+func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) {
+ if fr.AllowIllegalReads {
+ return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
+ }
+ mh := &http2MetaHeadersFrame{
+ http2HeadersFrame: hf,
+ }
+ var remainSize = fr.maxHeaderListSize()
+ var sawRegular bool
+
+ var invalid error // pseudo header field errors
+ hdec := fr.ReadMetaHeaders
+ hdec.SetEmitEnabled(true)
+ hdec.SetMaxStringLength(fr.maxHeaderStringLen())
+ hdec.SetEmitFunc(func(hf hpack.HeaderField) {
+ if http2VerboseLogs && fr.logReads {
+ fr.debugReadLoggerf("http2: decoded hpack field %+v", hf)
+ }
+ if !httpguts.ValidHeaderFieldValue(hf.Value) {
+ // Don't include the value in the error, because it may be sensitive.
+ invalid = http2headerFieldValueError(hf.Name)
+ }
+ isPseudo := strings.HasPrefix(hf.Name, ":")
+ if isPseudo {
+ if sawRegular {
+ invalid = http2errPseudoAfterRegular
+ }
+ } else {
+ sawRegular = true
+ if !http2validWireHeaderFieldName(hf.Name) {
+ invalid = http2headerFieldNameError(hf.Name)
+ }
+ }
+
+ if invalid != nil {
+ hdec.SetEmitEnabled(false)
+ return
+ }
+
+ size := hf.Size()
+ if size > remainSize {
+ hdec.SetEmitEnabled(false)
+ mh.Truncated = true
+ return
+ }
+ remainSize -= size
+
+ mh.Fields = append(mh.Fields, hf)
+ })
+ // Lose reference to MetaHeadersFrame:
+ defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
+
+ var hc http2headersOrContinuation = hf
+ for {
+ frag := hc.HeaderBlockFragment()
+ if _, err := hdec.Write(frag); err != nil {
+ return nil, http2ConnectionError(http2ErrCodeCompression)
+ }
+
+ if hc.HeadersEnded() {
+ break
+ }
+ if f, err := fr.ReadFrame(); err != nil {
+ return nil, err
+ } else {
+ hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder
+ }
+ }
+
+ mh.http2HeadersFrame.headerFragBuf = nil
+ mh.http2HeadersFrame.invalidate()
+
+ if err := hdec.Close(); err != nil {
+ return nil, http2ConnectionError(http2ErrCodeCompression)
+ }
+ if invalid != nil {
+ fr.errDetail = invalid
+ if http2VerboseLogs {
+ log.Printf("http2: invalid header: %v", invalid)
+ }
+ return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid}
+ }
+ if err := mh.checkPseudos(); err != nil {
+ fr.errDetail = err
+ if http2VerboseLogs {
+ log.Printf("http2: invalid pseudo headers: %v", err)
+ }
+ return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err}
+ }
+ return mh, nil
+}
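+
+// A minimal sketch of enabling merged header frames, assuming rw is the
+// connection and using hpack's conventional 4 KB dynamic table size:
+//
+//	fr := http2NewFramer(rw, rw)
+//	fr.ReadMetaHeaders = hpack.NewDecoder(4096, nil)
+//	fr.MaxHeaderListSize = 1 << 20 // 1 MB cap; Truncated reports overflow
+//	f, err := fr.ReadFrame()
+//	if mh, ok := f.(*http2MetaHeadersFrame); err == nil && ok {
+//		_ = mh.PseudoValue("path") // e.g. "/" for a request
+//		_ = mh.RegularFields()     // valid only until the next ReadFrame
+//	}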
+
+func http2summarizeFrame(f http2Frame) string {
+ var buf bytes.Buffer
+ f.Header().writeDebug(&buf)
+ switch f := f.(type) {
+ case *http2SettingsFrame:
+ n := 0
+ f.ForeachSetting(func(s http2Setting) error {
+ n++
+ if n == 1 {
+ buf.WriteString(", settings:")
+ }
+ fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val)
+ return nil
+ })
+ if n > 0 {
+ buf.Truncate(buf.Len() - 1) // remove trailing comma
+ }
+ case *http2DataFrame:
+ data := f.Data()
+ const max = 256
+ if len(data) > max {
+ data = data[:max]
+ }
+ fmt.Fprintf(&buf, " data=%q", data)
+ if len(f.Data()) > max {
+ fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max)
+ }
+ case *http2WindowUpdateFrame:
+ if f.StreamID == 0 {
+ buf.WriteString(" (conn)")
+ }
+ fmt.Fprintf(&buf, " incr=%v", f.Increment)
+ case *http2PingFrame:
+ fmt.Fprintf(&buf, " ping=%q", f.Data[:])
+ case *http2GoAwayFrame:
+ fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q",
+ f.LastStreamID, f.ErrCode, f.debugData)
+ case *http2RSTStreamFrame:
+ fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode)
+ }
+ return buf.String()
+}
+
+func http2traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
+ return trace != nil && trace.WroteHeaderField != nil
+}
+
+func http2traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(k, []string{v})
+ }
+}
+
+func http2traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
+ if trace != nil {
+ return trace.Got1xxResponse
+ }
+ return nil
+}
+
+// dialTLSWithContext uses tls.Dialer, added in Go 1.15, to open a TLS
+// connection.
+func (t *http2Transport) dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
+ dialer := &tls.Dialer{
+ Config: cfg,
+ }
+ cn, err := dialer.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
+ return tlsCn, nil
+}
+
+func http2tlsUnderlyingConn(tc *tls.Conn) net.Conn {
+ return tc.NetConn()
+}
+
+var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
+
+type http2goroutineLock uint64
+
+func http2newGoroutineLock() http2goroutineLock {
+ if !http2DebugGoroutines {
+ return 0
+ }
+ return http2goroutineLock(http2curGoroutineID())
+}
+
+func (g http2goroutineLock) check() {
+ if !http2DebugGoroutines {
+ return
+ }
+ if http2curGoroutineID() != uint64(g) {
+ panic("running on the wrong goroutine")
+ }
+}
+
+func (g http2goroutineLock) checkNotOn() {
+ if !http2DebugGoroutines {
+ return
+ }
+ if http2curGoroutineID() == uint64(g) {
+ panic("running on the wrong goroutine")
+ }
+}
+
+var http2goroutineSpace = []byte("goroutine ")
+
+func http2curGoroutineID() uint64 {
+ bp := http2littleBuf.Get().(*[]byte)
+ defer http2littleBuf.Put(bp)
+ b := *bp
+ b = b[:runtime.Stack(b, false)]
+ // Parse the 4707 out of "goroutine 4707 ["
+ b = bytes.TrimPrefix(b, http2goroutineSpace)
+ i := bytes.IndexByte(b, ' ')
+ if i < 0 {
+ panic(fmt.Sprintf("No space found in %q", b))
+ }
+ b = b[:i]
+ n, err := http2parseUintBytes(b, 10, 64)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
+ }
+ return n
+}
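+
+// For reference, runtime.Stack for the current goroutine begins with a
+// line such as
+//
+//	goroutine 4707 [running]:
+//
+// so trimming the "goroutine " prefix and cutting at the next space
+// leaves the decimal ID that http2parseUintBytes converts.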
+
+var http2littleBuf = sync.Pool{
+ New: func() interface{} {
+ buf := make([]byte, 64)
+ return &buf
+ },
+}
+
+// parseUintBytes is like strconv.ParseUint, but using a []byte.
+func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
+ var cutoff, maxVal uint64
+
+ if bitSize == 0 {
+ bitSize = int(strconv.IntSize)
+ }
+
+ s0 := s
+ switch {
+ case len(s) < 1:
+ err = strconv.ErrSyntax
+ goto Error
+
+ case 2 <= base && base <= 36:
+ // valid base; nothing to do
+
+ case base == 0:
+ // Look for octal, hex prefix.
+ switch {
+ case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
+ base = 16
+ s = s[2:]
+ if len(s) < 1 {
+ err = strconv.ErrSyntax
+ goto Error
+ }
+ case s[0] == '0':
+ base = 8
+ default:
+ base = 10
+ }
+
+ default:
+ err = errors.New("invalid base " + strconv.Itoa(base))
+ goto Error
+ }
+
+ n = 0
+ cutoff = http2cutoff64(base)
+ maxVal = 1<<uint(bitSize) - 1
+
+ for i := 0; i < len(s); i++ {
+ var v byte
+ d := s[i]
+ switch {
+ case '0' <= d && d <= '9':
+ v = d - '0'
+ case 'a' <= d && d <= 'z':
+ v = d - 'a' + 10
+ case 'A' <= d && d <= 'Z':
+ v = d - 'A' + 10
+ default:
+ n = 0
+ err = strconv.ErrSyntax
+ goto Error
+ }
+ if int(v) >= base {
+ n = 0
+ err = strconv.ErrSyntax
+ goto Error
+ }
+
+ if n >= cutoff {
+ // n*base overflows
+ n = 1<<64 - 1
+ err = strconv.ErrRange
+ goto Error
+ }
+ n *= uint64(base)
+
+ n1 := n + uint64(v)
+ if n1 < n || n1 > maxVal {
+ // n+v overflows
+ n = 1<<64 - 1
+ err = strconv.ErrRange
+ goto Error
+ }
+ n = n1
+ }
+
+ return n, nil
+
+Error:
+ return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
+}
+
+// Return the first number n such that n*base >= 1<<64.
+func http2cutoff64(base int) uint64 {
+ if base < 2 {
+ return 0
+ }
+ return (1<<64-1)/uint64(base) + 1
+}
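+
+// Worked example (illustrative, derived from the code above):
+// http2cutoff64(10) == (1<<64-1)/10 + 1 == 1844674407370955162, the first
+// n for which n*10 no longer fits in a uint64. parseUintBytes then behaves
+// like strconv.ParseUint without converting the []byte to a string:
+//
+//	n, err := http2parseUintBytes([]byte("4707"), 10, 64) // n == 4707, err == nil
+//	m, err := http2parseUintBytes([]byte("0x1f"), 0, 64)  // base inferred as 16, m == 31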
+
+var (
+ http2commonBuildOnce sync.Once
+ http2commonLowerHeader map[string]string // Go-Canonical-Case -> lower-case
+ http2commonCanonHeader map[string]string // lower-case -> Go-Canonical-Case
+)
+
+func http2buildCommonHeaderMapsOnce() {
+ http2commonBuildOnce.Do(http2buildCommonHeaderMaps)
+}
+
+func http2buildCommonHeaderMaps() {
+ common := []string{
+ "accept",
+ "accept-charset",
+ "accept-encoding",
+ "accept-language",
+ "accept-ranges",
+ "age",
+ "access-control-allow-credentials",
+ "access-control-allow-headers",
+ "access-control-allow-methods",
+ "access-control-allow-origin",
+ "access-control-expose-headers",
+ "access-control-max-age",
+ "access-control-request-headers",
+ "access-control-request-method",
+ "allow",
+ "authorization",
+ "cache-control",
+ "content-disposition",
+ "content-encoding",
+ "content-language",
+ "content-length",
+ "content-location",
+ "content-range",
+ "content-type",
+ "cookie",
+ "date",
+ "etag",
+ "expect",
+ "expires",
+ "from",
+ "host",
+ "if-match",
+ "if-modified-since",
+ "if-none-match",
+ "if-unmodified-since",
+ "last-modified",
+ "link",
+ "location",
+ "max-forwards",
+ "origin",
+ "proxy-authenticate",
+ "proxy-authorization",
+ "range",
+ "referer",
+ "refresh",
+ "retry-after",
+ "server",
+ "set-cookie",
+ "strict-transport-security",
+ "trailer",
+ "transfer-encoding",
+ "user-agent",
+ "vary",
+ "via",
+ "www-authenticate",
+ "x-forwarded-for",
+ "x-forwarded-proto",
+ }
+ http2commonLowerHeader = make(map[string]string, len(common))
+ http2commonCanonHeader = make(map[string]string, len(common))
+ for _, v := range common {
+ chk := CanonicalHeaderKey(v)
+ http2commonLowerHeader[chk] = v
+ http2commonCanonHeader[v] = chk
+ }
+}
+
+func http2lowerHeader(v string) (lower string, ascii bool) {
+ http2buildCommonHeaderMapsOnce()
+ if s, ok := http2commonLowerHeader[v]; ok {
+ return s, true
+ }
+ return http2asciiToLower(v)
+}
+
+func http2canonicalHeader(v string) string {
+ http2buildCommonHeaderMapsOnce()
+ if s, ok := http2commonCanonHeader[v]; ok {
+ return s
+ }
+ return CanonicalHeaderKey(v)
+}
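+
+// Example lookups (illustrative):
+//
+//	lower, ascii := http2lowerHeader("Content-Type") // "content-type", true
+//	canon := http2canonicalHeader("content-type")    // "Content-Type"
+//
+// Keys outside the common set above fall back to http2asciiToLower and
+// CanonicalHeaderKey, respectively.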
+
+var (
+ http2VerboseLogs bool
+ http2logFrameWrites bool
+ http2logFrameReads bool
+ http2inTests bool
+)
+
+func init() {
+ e := os.Getenv("GODEBUG")
+ if strings.Contains(e, "http2debug=1") {
+ http2VerboseLogs = true
+ }
+ if strings.Contains(e, "http2debug=2") {
+ http2VerboseLogs = true
+ http2logFrameWrites = true
+ http2logFrameReads = true
+ }
+}
+
+const (
+ // ClientPreface is the string that must be sent by new
+ // connections from clients.
+ http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
+
+ // SETTINGS_MAX_FRAME_SIZE default
+ // https://httpwg.org/specs/rfc7540.html#rfc.section.6.5.2
+ http2initialMaxFrameSize = 16384
+
+ // NextProtoTLS is the NPN/ALPN protocol negotiated during
+ // HTTP/2's TLS setup.
+ http2NextProtoTLS = "h2"
+
+ // https://httpwg.org/specs/rfc7540.html#SettingValues
+ http2initialHeaderTableSize = 4096
+
+ http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
+
+ http2defaultMaxReadFrameSize = 1 << 20
+)
+
+var (
+ http2clientPreface = []byte(http2ClientPreface)
+)
+
+type http2streamState int
+
+// HTTP/2 stream states.
+//
+// See http://tools.ietf.org/html/rfc7540#section-5.1.
+//
+// For simplicity, the server code merges "reserved (local)" into
+// "half-closed (remote)". This is one less state transition to track.
+// The only downside is that we send PUSH_PROMISEs slightly less
+// liberally than allowable. More discussion here:
+// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html
+//
+// "reserved (remote)" is omitted since the client code does not
+// support server push.
+const (
+ http2stateIdle http2streamState = iota
+ http2stateOpen
+ http2stateHalfClosedLocal
+ http2stateHalfClosedRemote
+ http2stateClosed
+)
+
+var http2stateName = [...]string{
+ http2stateIdle: "Idle",
+ http2stateOpen: "Open",
+ http2stateHalfClosedLocal: "HalfClosedLocal",
+ http2stateHalfClosedRemote: "HalfClosedRemote",
+ http2stateClosed: "Closed",
+}
+
+func (st http2streamState) String() string {
+ return http2stateName[st]
+}
+
+// Setting is a setting parameter: which setting it is, and its value.
+type http2Setting struct {
+ // ID is which setting is being set.
+ // See https://httpwg.org/specs/rfc7540.html#SettingFormat
+ ID http2SettingID
+
+ // Val is the value.
+ Val uint32
+}
+
+func (s http2Setting) String() string {
+ return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
+}
+
+// Valid reports whether the setting is valid.
+func (s http2Setting) Valid() error {
+ // Limits and error codes from 6.5.2 Defined SETTINGS Parameters
+ switch s.ID {
+ case http2SettingEnablePush:
+ if s.Val != 1 && s.Val != 0 {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ case http2SettingInitialWindowSize:
+ if s.Val > 1<<31-1 {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ case http2SettingMaxFrameSize:
+ if s.Val < 16384 || s.Val > 1<<24-1 {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ }
+ return nil
+}
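+
+// For example (illustrative), ENABLE_PUSH only admits the values 0 and 1:
+//
+//	s := http2Setting{ID: http2SettingEnablePush, Val: 2}
+//	err := s.Valid() // http2ConnectionError(http2ErrCodeProtocol)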
+
+// A SettingID is an HTTP/2 setting as defined in
+// https://httpwg.org/specs/rfc7540.html#iana-settings
+type http2SettingID uint16
+
+const (
+ http2SettingHeaderTableSize http2SettingID = 0x1
+ http2SettingEnablePush http2SettingID = 0x2
+ http2SettingMaxConcurrentStreams http2SettingID = 0x3
+ http2SettingInitialWindowSize http2SettingID = 0x4
+ http2SettingMaxFrameSize http2SettingID = 0x5
+ http2SettingMaxHeaderListSize http2SettingID = 0x6
+)
+
+var http2settingName = map[http2SettingID]string{
+ http2SettingHeaderTableSize: "HEADER_TABLE_SIZE",
+ http2SettingEnablePush: "ENABLE_PUSH",
+ http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
+ http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
+ http2SettingMaxFrameSize: "MAX_FRAME_SIZE",
+ http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+}
+
+func (s http2SettingID) String() string {
+ if v, ok := http2settingName[s]; ok {
+ return v
+ }
+ return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
+}
+
+// validWireHeaderFieldName reports whether v is a valid header field
+// name (key). See httpguts.ValidHeaderName for the base rules.
+//
+// Further, http2 says:
+//
+// "Just as in HTTP/1.x, header field names are strings of ASCII
+// characters that are compared in a case-insensitive
+// fashion. However, header field names MUST be converted to
+// lowercase prior to their encoding in HTTP/2. "
+func http2validWireHeaderFieldName(v string) bool {
+ if len(v) == 0 {
+ return false
+ }
+ for _, r := range v {
+ if !httpguts.IsTokenRune(r) {
+ return false
+ }
+ if 'A' <= r && r <= 'Z' {
+ return false
+ }
+ }
+ return true
+}
+
+func http2httpCodeString(code int) string {
+ switch code {
+ case 200:
+ return "200"
+ case 404:
+ return "404"
+ }
+ return strconv.Itoa(code)
+}
+
+// from pkg io
+type http2stringWriter interface {
+ WriteString(s string) (n int, err error)
+}
+
+// A gate lets two goroutines coordinate their activities.
+type http2gate chan struct{}
+
+func (g http2gate) Done() { g <- struct{}{} }
+
+func (g http2gate) Wait() { <-g }
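+
+// A minimal sketch of the gate pattern (doStep is a hypothetical
+// placeholder); readFrames further below uses it to hand the serve loop
+// exactly one frame at a time:
+//
+//	g := make(http2gate)
+//	go func() { doStep(); g.Done() }()
+//	g.Wait() // returns once doStep has signaled completion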
+
+// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
+type http2closeWaiter chan struct{}
+
+// Init makes a closeWaiter usable.
+// It exists so that a closeWaiter value can be placed inside a
+// larger struct and initialized in place, since the zero value of a
+// channel is not usable.
+func (cw *http2closeWaiter) Init() {
+ *cw = make(chan struct{})
+}
+
+// Close marks the closeWaiter as closed and unblocks any waiters.
+func (cw http2closeWaiter) Close() {
+ close(cw)
+}
+
+// Wait waits for the closeWaiter to become closed.
+func (cw http2closeWaiter) Wait() {
+ <-cw
+}
+
+// bufferedWriter is a buffered writer that writes to w.
+// Its buffered writer is lazily allocated as needed, to minimize
+// idle memory usage with many connections.
+type http2bufferedWriter struct {
+ _ http2incomparable
+ w io.Writer // immutable
+ bw *bufio.Writer // non-nil when data is buffered
+}
+
+func http2newBufferedWriter(w io.Writer) *http2bufferedWriter {
+ return &http2bufferedWriter{w: w}
+}
+
+// bufWriterPoolBufferSize is the size of bufio.Writer's
+// buffers created using bufWriterPool.
+//
+// TODO: pick a less arbitrary value? this is a bit under
+// (3 x typical 1500 byte MTU) at least. Other than that,
+// not much thought went into it.
+const http2bufWriterPoolBufferSize = 4 << 10
+
+var http2bufWriterPool = sync.Pool{
+ New: func() interface{} {
+ return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize)
+ },
+}
+
+func (w *http2bufferedWriter) Available() int {
+ if w.bw == nil {
+ return http2bufWriterPoolBufferSize
+ }
+ return w.bw.Available()
+}
+
+func (w *http2bufferedWriter) Write(p []byte) (n int, err error) {
+ if w.bw == nil {
+ bw := http2bufWriterPool.Get().(*bufio.Writer)
+ bw.Reset(w.w)
+ w.bw = bw
+ }
+ return w.bw.Write(p)
+}
+
+func (w *http2bufferedWriter) Flush() error {
+ bw := w.bw
+ if bw == nil {
+ return nil
+ }
+ err := bw.Flush()
+ bw.Reset(nil)
+ http2bufWriterPool.Put(bw)
+ w.bw = nil
+ return err
+}
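+
+// Illustrative lifecycle (conn and frameBytes are placeholders): the
+// buffer is only pinned while a frame is in flight.
+//
+//	w := http2newBufferedWriter(conn) // no bufio.Writer allocated yet
+//	w.Write(frameBytes)               // borrows a 4KB writer from bufWriterPool
+//	w.Flush()                         // flushes to conn and returns the writer to the pool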
+
+func http2mustUint31(v int32) uint32 {
+ if v < 0 || v > 2147483647 {
+ panic("out of range")
+ }
+ return uint32(v)
+}
+
+// bodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC 7230, section 3.3.
+func http2bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+type http2httpError struct {
+ _ http2incomparable
+ msg string
+ timeout bool
+}
+
+func (e *http2httpError) Error() string { return e.msg }
+
+func (e *http2httpError) Timeout() bool { return e.timeout }
+
+func (e *http2httpError) Temporary() bool { return true }
+
+var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
+
+type http2connectionStater interface {
+ ConnectionState() tls.ConnectionState
+}
+
+var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }}
+
+type http2sorter struct {
+ v []string // owned by sorter
+}
+
+func (s *http2sorter) Len() int { return len(s.v) }
+
+func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] }
+
+func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] }
+
+// Keys returns the sorted keys of h.
+//
+// The returned slice is only valid until s is used again or returned to
+// its pool.
+func (s *http2sorter) Keys(h Header) []string {
+ keys := s.v[:0]
+ for k := range h {
+ keys = append(keys, k)
+ }
+ s.v = keys
+ sort.Sort(s)
+ return keys
+}
+
+func (s *http2sorter) SortStrings(ss []string) {
+ // Our sorter works on s.v, which sorter owns, so
+ // stash it away while we sort the user's buffer.
+ save := s.v
+ s.v = ss
+ sort.Sort(s)
+ s.v = save
+}
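+
+// Illustrative use (h is a placeholder Header):
+//
+//	s := http2sorterPool.Get().(*http2sorter)
+//	keys := s.Keys(h) // sorted header names, valid only until s is reused
+//	// ... write the headers ...
+//	http2sorterPool.Put(s)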
+
+// validPseudoPath reports whether v is a valid :path pseudo-header
+// value. It must be either:
+//
+// - a non-empty string starting with '/'
+// - the string '*', for OPTIONS requests.
+//
+// For now this is only used as a quick check for deciding when to clean
+// up Opaque URLs before sending requests from the Transport.
+// See golang.org/issue/16847
+//
+// We used to enforce that the path also didn't start with "//", but
+// Google's GFE accepts such paths and Chrome sends them, so ignore
+// that part of the spec. See golang.org/issue/19103.
+func http2validPseudoPath(v string) bool {
+ return (len(v) > 0 && v[0] == '/') || v == "*"
+}
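+
+// For example (illustrative): "/index.html" and "*" are accepted, while
+// "", "index.html", and "https://example.com/" are rejected.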
+
+// incomparable is a zero-width, non-comparable type. Adding it to a struct
+// makes that struct also non-comparable, and generally doesn't add
+// any size (as long as it's first).
+type http2incomparable [0]func()
+
+// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
+// io.Pipe except there are no PipeReader/PipeWriter halves, and the
+// underlying buffer is an interface. (io.Pipe is always unbuffered)
+type http2pipe struct {
+ mu sync.Mutex
+ c sync.Cond // c.L lazily initialized to &p.mu
+ b http2pipeBuffer // nil when done reading
+ unread int // bytes unread when done
+ err error // read error once empty. non-nil means closed.
+ breakErr error // immediate read error (caller doesn't see rest of b)
+ donec chan struct{} // closed on error
+ readFn func() // optional code to run in Read before error
+}
+
+type http2pipeBuffer interface {
+ Len() int
+ io.Writer
+ io.Reader
+}
+
+// setBuffer initializes the pipe buffer.
+// It has no effect if the pipe is already closed.
+func (p *http2pipe) setBuffer(b http2pipeBuffer) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.err != nil || p.breakErr != nil {
+ return
+ }
+ p.b = b
+}
+
+func (p *http2pipe) Len() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.b == nil {
+ return p.unread
+ }
+ return p.b.Len()
+}
+
+// Read waits until data is available and copies bytes
+// from the buffer into d.
+func (p *http2pipe) Read(d []byte) (n int, err error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ for {
+ if p.breakErr != nil {
+ return 0, p.breakErr
+ }
+ if p.b != nil && p.b.Len() > 0 {
+ return p.b.Read(d)
+ }
+ if p.err != nil {
+ if p.readFn != nil {
+ p.readFn() // e.g. copy trailers
+ p.readFn = nil // not sticky like p.err
+ }
+ p.b = nil
+ return 0, p.err
+ }
+ p.c.Wait()
+ }
+}
+
+var http2errClosedPipeWrite = errors.New("write on closed buffer")
+
+// Write copies bytes from d into the buffer and wakes a reader.
+// It is an error to write more data than the buffer can hold.
+func (p *http2pipe) Write(d []byte) (n int, err error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ defer p.c.Signal()
+ if p.err != nil || p.breakErr != nil {
+ return 0, http2errClosedPipeWrite
+ }
+ return p.b.Write(d)
+}
+
+// CloseWithError causes the next Read (waking up a current blocked
+// Read if needed) to return the provided err after all data has been
+// read.
+//
+// The error must be non-nil.
+func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) }
+
+// BreakWithError causes the next Read (waking up a current blocked
+// Read if needed) to return the provided err immediately, without
+// waiting for unread data.
+func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) }
+
+// closeWithErrorAndCode is like CloseWithError but also sets some code to run
+// in the caller's goroutine before returning the error.
+func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) }
+
+func (p *http2pipe) closeWithError(dst *error, err error, fn func()) {
+ if err == nil {
+ panic("err must be non-nil")
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ defer p.c.Signal()
+ if *dst != nil {
+ // Already been done.
+ return
+ }
+ p.readFn = fn
+ if dst == &p.breakErr {
+ if p.b != nil {
+ p.unread += p.b.Len()
+ }
+ p.b = nil
+ }
+ *dst = err
+ p.closeDoneLocked()
+}
+
+// requires p.mu be held.
+func (p *http2pipe) closeDoneLocked() {
+ if p.donec == nil {
+ return
+ }
+ // Close if unclosed. This isn't racy since we always
+ // hold p.mu while closing.
+ select {
+ case <-p.donec:
+ default:
+ close(p.donec)
+ }
+}
+
+// Err returns the error (if any) first set by BreakWithError or CloseWithError.
+func (p *http2pipe) Err() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.breakErr != nil {
+ return p.breakErr
+ }
+ return p.err
+}
+
+// Done returns a channel which is closed if and when this pipe is closed
+// with CloseWithError.
+func (p *http2pipe) Done() <-chan struct{} {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.donec == nil {
+ p.donec = make(chan struct{})
+ if p.err != nil || p.breakErr != nil {
+ // Already hit an error.
+ p.closeDoneLocked()
+ }
+ }
+ return p.donec
+}
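+
+// Typical lifecycle (a sketch; buf and data are placeholders for a
+// pipeBuffer implementation and a payload):
+//
+//	var p http2pipe
+//	p.setBuffer(buf)
+//	go func() { p.Write(data); p.CloseWithError(io.EOF) }() // producer
+//	body, err := io.ReadAll(&p) // consumer drains buffered data, then stops at io.EOF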
+
+const (
+ http2prefaceTimeout = 10 * time.Second
+ http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
+ http2handlerChunkWriteSize = 4 << 10
+ http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+ http2maxQueuedControlFrames = 10000
+)
+
+var (
+ http2errClientDisconnected = errors.New("client disconnected")
+ http2errClosedBody = errors.New("body closed by handler")
+ http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
+ http2errStreamClosed = errors.New("http2: stream closed")
+)
+
+var http2responseWriterStatePool = sync.Pool{
+ New: func() interface{} {
+ rws := &http2responseWriterState{}
+ rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize)
+ return rws
+ },
+}
+
+// Test hooks.
+var (
+ http2testHookOnConn func()
+ http2testHookGetServerConn func(*http2serverConn)
+ http2testHookOnPanicMu *sync.Mutex // nil except in tests
+ http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool)
+)
+
+// Server is an HTTP/2 server.
+type http2Server struct {
+ // MaxHandlers limits the number of http.Handler ServeHTTP goroutines
+ // which may run at a time over all connections.
+ // Negative or zero means no limit.
+ // TODO: implement
+ MaxHandlers int
+
+ // MaxConcurrentStreams optionally specifies the number of
+ // concurrent streams that each client may have open at a
+ // time. This is unrelated to the number of http.Handler goroutines
+ // which may be active globally, which is MaxHandlers.
+ // If zero, MaxConcurrentStreams defaults to at least 100, per
+ // the HTTP/2 spec's recommendations.
+ MaxConcurrentStreams uint32
+
+ // MaxDecoderHeaderTableSize optionally specifies the http2
+ // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
+ // informs the remote endpoint of the maximum size of the header compression
+ // table used to decode header blocks, in octets. If zero, the default value
+ // of 4096 is used.
+ MaxDecoderHeaderTableSize uint32
+
+ // MaxEncoderHeaderTableSize optionally specifies an upper limit for the
+ // header compression table used for encoding request headers. Received
+ // SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
+ // the default value of 4096 is used.
+ MaxEncoderHeaderTableSize uint32
+
+ // MaxReadFrameSize optionally specifies the largest frame
+ // this server is willing to read. A valid value is between
+ // 16k and 16M, inclusive. If zero or otherwise invalid, a
+ // default value is used.
+ MaxReadFrameSize uint32
+
+ // PermitProhibitedCipherSuites, if true, permits the use of
+ // cipher suites prohibited by the HTTP/2 spec.
+ PermitProhibitedCipherSuites bool
+
+ // IdleTimeout specifies how long until idle clients should be
+ // closed with a GOAWAY frame. PING frames are not considered
+ // activity for the purposes of IdleTimeout.
+ IdleTimeout time.Duration
+
+ // MaxUploadBufferPerConnection is the size of the initial flow
+ // control window for each connection. The HTTP/2 spec does not
+ // allow this to be smaller than 65535 or larger than 2^32-1.
+ // If the value is outside this range, a default value will be
+ // used instead.
+ MaxUploadBufferPerConnection int32
+
+ // MaxUploadBufferPerStream is the size of the initial flow control
+ // window for each stream. The HTTP/2 spec does not allow this to
+ // be larger than 2^32-1. If the value is zero or larger than the
+ // maximum, a default value will be used instead.
+ MaxUploadBufferPerStream int32
+
+ // NewWriteScheduler constructs a write scheduler for a connection.
+ // If nil, a default scheduler is chosen.
+ NewWriteScheduler func() http2WriteScheduler
+
+ // CountError, if non-nil, is called on HTTP/2 server errors.
+ // It's intended to increment a metric for monitoring, such
+ // as an expvar or Prometheus metric.
+ // The errType consists of only ASCII word characters.
+ CountError func(errType string)
+
+ // Internal state. This is a pointer (rather than embedded directly)
+ // so that we don't embed a Mutex in this struct, which would make the
+ // struct non-copyable and might break some callers.
+ state *http2serverInternalState
+}
+
+func (s *http2Server) initialConnRecvWindowSize() int32 {
+ if s.MaxUploadBufferPerConnection >= http2initialWindowSize {
+ return s.MaxUploadBufferPerConnection
+ }
+ return 1 << 20
+}
+
+func (s *http2Server) initialStreamRecvWindowSize() int32 {
+ if s.MaxUploadBufferPerStream > 0 {
+ return s.MaxUploadBufferPerStream
+ }
+ return 1 << 20
+}
+
+func (s *http2Server) maxReadFrameSize() uint32 {
+ if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize {
+ return v
+ }
+ return http2defaultMaxReadFrameSize
+}
+
+func (s *http2Server) maxConcurrentStreams() uint32 {
+ if v := s.MaxConcurrentStreams; v > 0 {
+ return v
+ }
+ return http2defaultMaxStreams
+}
+
+func (s *http2Server) maxDecoderHeaderTableSize() uint32 {
+ if v := s.MaxDecoderHeaderTableSize; v > 0 {
+ return v
+ }
+ return http2initialHeaderTableSize
+}
+
+func (s *http2Server) maxEncoderHeaderTableSize() uint32 {
+ if v := s.MaxEncoderHeaderTableSize; v > 0 {
+ return v
+ }
+ return http2initialHeaderTableSize
+}
+
+// maxQueuedControlFrames is the maximum number of control frames like
+// SETTINGS, PING and RST_STREAM that will be queued for writing before
+// the connection is closed to prevent memory exhaustion attacks.
+func (s *http2Server) maxQueuedControlFrames() int {
+ // TODO: if anybody asks, add a Server field, and remember to define the
+ // behavior of negative values.
+ return http2maxQueuedControlFrames
+}
+
+type http2serverInternalState struct {
+ mu sync.Mutex
+ activeConns map[*http2serverConn]struct{}
+}
+
+func (s *http2serverInternalState) registerConn(sc *http2serverConn) {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ s.activeConns[sc] = struct{}{}
+ s.mu.Unlock()
+}
+
+func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ delete(s.activeConns, sc)
+ s.mu.Unlock()
+}
+
+func (s *http2serverInternalState) startGracefulShutdown() {
+ if s == nil {
+ return // if the Server was used without calling ConfigureServer
+ }
+ s.mu.Lock()
+ for sc := range s.activeConns {
+ sc.startGracefulShutdown()
+ }
+ s.mu.Unlock()
+}
+
+// ConfigureServer adds HTTP/2 support to a net/http Server.
+//
+// The configuration conf may be nil.
+//
+// ConfigureServer must be called before s begins serving.
+func http2ConfigureServer(s *Server, conf *http2Server) error {
+ if s == nil {
+ panic("nil *http.Server")
+ }
+ if conf == nil {
+ conf = new(http2Server)
+ }
+ conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})}
+ if h1, h2 := s, conf; h2.IdleTimeout == 0 {
+ if h1.IdleTimeout != 0 {
+ h2.IdleTimeout = h1.IdleTimeout
+ } else {
+ h2.IdleTimeout = h1.ReadTimeout
+ }
+ }
+ s.RegisterOnShutdown(conf.state.startGracefulShutdown)
+
+ if s.TLSConfig == nil {
+ s.TLSConfig = new(tls.Config)
+ } else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 {
+ // If they already provided a TLS 1.0–1.2 CipherSuite list, return an
+ // error if it is missing ECDHE_RSA_WITH_AES_128_GCM_SHA256 or
+ // ECDHE_ECDSA_WITH_AES_128_GCM_SHA256.
+ haveRequired := false
+ for _, cs := range s.TLSConfig.CipherSuites {
+ switch cs {
+ case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ // Alternative MTI cipher to not discourage ECDSA-only servers.
+ // See http://golang.org/cl/30721 for further information.
+ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
+ haveRequired = true
+ }
+ }
+ if !haveRequired {
+ return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
+ }
+ }
+
+ // Note: not setting MinVersion to tls.VersionTLS12,
+ // as we don't want to interfere with HTTP/1.1 traffic
+ // on the user's server. We enforce TLS 1.2 later once
+ // we accept a connection. Ideally this should be done
+ // during next-proto selection, but using TLS <1.2 with
+ // HTTP/2 is still the client's bug.
+
+ s.TLSConfig.PreferServerCipherSuites = true
+
+ if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) {
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS)
+ }
+ if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") {
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1")
+ }
+
+ if s.TLSNextProto == nil {
+ s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
+ }
+ protoHandler := func(hs *Server, c *tls.Conn, h Handler) {
+ if http2testHookOnConn != nil {
+ http2testHookOnConn()
+ }
+ // The TLSNextProto interface predates contexts, so
+ // the net/http package passes down its per-connection
+ // base context via an exported but unadvertised
+ // method on the Handler. This is for internal
+ // net/http<=>http2 use only.
+ var ctx context.Context
+ type baseContexter interface {
+ BaseContext() context.Context
+ }
+ if bc, ok := h.(baseContexter); ok {
+ ctx = bc.BaseContext()
+ }
+ conf.ServeConn(c, &http2ServeConnOpts{
+ Context: ctx,
+ Handler: h,
+ BaseConfig: hs,
+ })
+ }
+ s.TLSNextProto[http2NextProtoTLS] = protoHandler
+ return nil
+}
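+
+// Illustrative call (a sketch; mux is a placeholder handler):
+//
+//	srv := &Server{Addr: ":8443", Handler: mux}
+//	if err := http2ConfigureServer(srv, &http2Server{MaxConcurrentStreams: 250}); err != nil {
+//		log.Fatal(err)
+//	}
+//	// srv.TLSConfig.NextProtos now includes "h2" and "http/1.1", and
+//	// srv.TLSNextProto["h2"] dispatches to this package's HTTP/2 server.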
+
+// ServeConnOpts are options for the Server.ServeConn method.
+type http2ServeConnOpts struct {
+ // Context is the base context to use.
+ // If nil, context.Background is used.
+ Context context.Context
+
+ // BaseConfig optionally sets the base configuration
+ // for values. If nil, defaults are used.
+ BaseConfig *Server
+
+ // Handler specifies which handler to use for processing
+ // requests. If nil, BaseConfig.Handler is used. If BaseConfig
+ // or BaseConfig.Handler is nil, http.DefaultServeMux is used.
+ Handler Handler
+
+ // UpgradeRequest is an initial request received on a connection
+ // undergoing an h2c upgrade. The request body must have been
+ // completely read from the connection before calling ServeConn,
+ // and the 101 Switching Protocols response written.
+ UpgradeRequest *Request
+
+ // Settings is the decoded contents of the HTTP2-Settings header
+ // in an h2c upgrade request.
+ Settings []byte
+
+ // SawClientPreface is set if the HTTP/2 connection preface
+ // has already been read from the connection.
+ SawClientPreface bool
+}
+
+func (o *http2ServeConnOpts) context() context.Context {
+ if o != nil && o.Context != nil {
+ return o.Context
+ }
+ return context.Background()
+}
+
+func (o *http2ServeConnOpts) baseConfig() *Server {
+ if o != nil && o.BaseConfig != nil {
+ return o.BaseConfig
+ }
+ return new(Server)
+}
+
+func (o *http2ServeConnOpts) handler() Handler {
+ if o != nil {
+ if o.Handler != nil {
+ return o.Handler
+ }
+ if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
+ return o.BaseConfig.Handler
+ }
+ }
+ return DefaultServeMux
+}
+
+// ServeConn serves HTTP/2 requests on the provided connection and
+// blocks until the connection is no longer readable.
+//
+// ServeConn starts speaking HTTP/2 assuming that c has not had any
+// reads or writes. It writes its initial settings frame and expects
+// to be able to read the preface and settings frame from the
+// client. If c has a ConnectionState method like a *tls.Conn, the
+// ConnectionState is used to verify the TLS ciphersuite and to set
+// the Request.TLS field in Handlers.
+//
+// ServeConn does not support h2c by itself. Any h2c support must be
+// implemented in terms of providing a suitably-behaving net.Conn.
+//
+// The opts parameter is optional. If nil, default values are used.
+func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
+ baseCtx, cancel := http2serverConnBaseContext(c, opts)
+ defer cancel()
+
+ sc := &http2serverConn{
+ srv: s,
+ hs: opts.baseConfig(),
+ conn: c,
+ baseCtx: baseCtx,
+ remoteAddrStr: c.RemoteAddr().String(),
+ bw: http2newBufferedWriter(c),
+ handler: opts.handler(),
+ streams: make(map[uint32]*http2stream),
+ readFrameCh: make(chan http2readFrameResult),
+ wantWriteFrameCh: make(chan http2FrameWriteRequest, 8),
+ serveMsgCh: make(chan interface{}, 8),
+ wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync
+ bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way
+ doneServing: make(chan struct{}),
+ clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
+ advMaxStreams: s.maxConcurrentStreams(),
+ initialStreamSendWindowSize: http2initialWindowSize,
+ maxFrameSize: http2initialMaxFrameSize,
+ serveG: http2newGoroutineLock(),
+ pushEnabled: true,
+ sawClientPreface: opts.SawClientPreface,
+ }
+
+ s.state.registerConn(sc)
+ defer s.state.unregisterConn(sc)
+
+ // The net/http package sets the write deadline from the
+ // http.Server.WriteTimeout during the TLS handshake, but then
+ // passes the connection off to us with the deadline already set.
+ // Write deadlines are set per stream in serverConn.newStream.
+ // Disarm the net.Conn write deadline here.
+ if sc.hs.WriteTimeout != 0 {
+ sc.conn.SetWriteDeadline(time.Time{})
+ }
+
+ if s.NewWriteScheduler != nil {
+ sc.writeSched = s.NewWriteScheduler()
+ } else {
+ sc.writeSched = http2newRoundRobinWriteScheduler()
+ }
+
+ // These start at the RFC-specified defaults. If there is a higher
+ // configured value for inflow, that will be updated when we send a
+ // WINDOW_UPDATE shortly after sending SETTINGS.
+ sc.flow.add(http2initialWindowSize)
+ sc.inflow.init(http2initialWindowSize)
+ sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
+ sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
+
+ fr := http2NewFramer(sc.bw, c)
+ if s.CountError != nil {
+ fr.countError = s.CountError
+ }
+ fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
+ fr.MaxHeaderListSize = sc.maxHeaderListSize()
+ fr.SetMaxReadFrameSize(s.maxReadFrameSize())
+ sc.framer = fr
+
+ if tc, ok := c.(http2connectionStater); ok {
+ sc.tlsState = new(tls.ConnectionState)
+ *sc.tlsState = tc.ConnectionState()
+ // 9.2 Use of TLS Features
+ // An implementation of HTTP/2 over TLS MUST use TLS
+ // 1.2 or higher with the restrictions on feature set
+ // and cipher suite described in this section. Due to
+ // implementation limitations, it might not be
+ // possible to fail TLS negotiation. An endpoint MUST
+ // immediately terminate an HTTP/2 connection that
+ // does not meet the TLS requirements described in
+ // this section with a connection error (Section
+ // 5.4.1) of type INADEQUATE_SECURITY.
+ if sc.tlsState.Version < tls.VersionTLS12 {
+ sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low")
+ return
+ }
+
+ if sc.tlsState.ServerName == "" {
+ // Client must use SNI, but we don't enforce that anymore,
+ // since it was causing problems when connecting to bare IP
+ // addresses during development.
+ //
+ // TODO: optionally enforce? Or enforce at the time we receive
+ // a new request, and verify the ServerName matches the :authority?
+ // But that precludes proxy situations, perhaps.
+ //
+ // So for now, do nothing here.
+ }
+
+ if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) {
+ // "Endpoints MAY choose to generate a connection error
+ // (Section 5.4.1) of type INADEQUATE_SECURITY if one of
+ // the prohibited cipher suites are negotiated."
+ //
+ // We choose that. In my opinion, the spec is weak
+ // here. It also says both parties must support at least
+ // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no
+ // excuses here. If we really must, we could allow an
+ // "AllowInsecureWeakCiphers" option on the server later.
+ // Let's see how it plays out first.
+ sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
+ return
+ }
+ }
+
+ if opts.Settings != nil {
+ fr := &http2SettingsFrame{
+ http2FrameHeader: http2FrameHeader{valid: true},
+ p: opts.Settings,
+ }
+ if err := fr.ForeachSetting(sc.processSetting); err != nil {
+ sc.rejectConn(http2ErrCodeProtocol, "invalid settings")
+ return
+ }
+ opts.Settings = nil
+ }
+
+ if hook := http2testHookGetServerConn; hook != nil {
+ hook(sc)
+ }
+
+ if opts.UpgradeRequest != nil {
+ sc.upgradeRequest(opts.UpgradeRequest)
+ opts.UpgradeRequest = nil
+ }
+
+ sc.serve()
+}
+
+func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) {
+ ctx, cancel = context.WithCancel(opts.context())
+ ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
+ if hs := opts.baseConfig(); hs != nil {
+ ctx = context.WithValue(ctx, ServerContextKey, hs)
+ }
+ return
+}
+
+func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) {
+ sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
+ // ignoring errors. hanging up anyway.
+ sc.framer.WriteGoAway(0, err, []byte(debug))
+ sc.bw.Flush()
+ sc.conn.Close()
+}
+
+type http2serverConn struct {
+ // Immutable:
+ srv *http2Server
+ hs *Server
+ conn net.Conn
+ bw *http2bufferedWriter // writing to conn
+ handler Handler
+ baseCtx context.Context
+ framer *http2Framer
+ doneServing chan struct{} // closed when serverConn.serve ends
+ readFrameCh chan http2readFrameResult // written by serverConn.readFrames
+ wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve
+ wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
+ bodyReadCh chan http2bodyReadMsg // from handlers -> serve
+ serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop
+ flow http2outflow // conn-wide (not stream-specific) outbound flow control
+ inflow http2inflow // conn-wide inbound flow control
+ tlsState *tls.ConnectionState // shared by all handlers, like net/http
+ remoteAddrStr string
+ writeSched http2WriteScheduler
+
+ // Everything following is owned by the serve loop; use serveG.check():
+ serveG http2goroutineLock // used to verify funcs are on serve()
+ pushEnabled bool
+ sawClientPreface bool // preface has already been read, used in h2c upgrade
+ sawFirstSettings bool // got the initial SETTINGS frame after the preface
+ needToSendSettingsAck bool
+ unackedSettings int // how many SETTINGS have we sent without ACKs?
+ queuedControlFrames int // control frames in the writeSched queue
+ clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
+ advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised to the client
+ curClientStreams uint32 // number of open streams initiated by the client
+ curPushedStreams uint32 // number of open streams initiated by server push
+ curHandlers uint32 // number of running handler goroutines
+ maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
+ maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
+ streams map[uint32]*http2stream
+ unstartedHandlers []http2unstartedHandler
+ initialStreamSendWindowSize int32
+ maxFrameSize int32
+ peerMaxHeaderListSize uint32 // zero means unknown (default)
+ canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
+ canonHeaderKeysSize int // canonHeader keys size in bytes
+ writingFrame bool // started writing a frame (on serve goroutine or separate)
+ writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh
+ needsFrameFlush bool // last frame write wasn't a flush
+ inGoAway bool // we've started to or sent GOAWAY
+ inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
+ needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ goAwayCode http2ErrCode
+ shutdownTimer *time.Timer // nil until used
+ idleTimer *time.Timer // nil if unused
+
+ // Owned by the writeFrameAsync goroutine:
+ headerWriteBuf bytes.Buffer
+ hpackEncoder *hpack.Encoder
+
+ // Used by startGracefulShutdown.
+ shutdownOnce sync.Once
+}
+
+func (sc *http2serverConn) maxHeaderListSize() uint32 {
+ n := sc.hs.MaxHeaderBytes
+ if n <= 0 {
+ n = DefaultMaxHeaderBytes
+ }
+ // http2's count is in a slightly different unit and includes 32 bytes per pair.
+ // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
+ const perFieldOverhead = 32 // per http2 spec
+ const typicalHeaders = 10 // conservative
+ return uint32(n + typicalHeaders*perFieldOverhead)
+}
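+
+// Worked example (illustrative): with the net/http default of
+// DefaultMaxHeaderBytes = 1 << 20, the advertised
+// SETTINGS_MAX_HEADER_LIST_SIZE is 1048576 + 10*32 = 1048896 octets.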
+
+func (sc *http2serverConn) curOpenStreams() uint32 {
+ sc.serveG.check()
+ return sc.curClientStreams + sc.curPushedStreams
+}
+
+// stream represents a stream. This is the minimal metadata needed by
+// the serve goroutine. Most of the actual stream state is owned by
+// the http.Handler's goroutine in the responseWriter. Because the
+// responseWriter's responseWriterState is recycled at the end of a
+// handler, this struct intentionally has no pointer to the
+// *responseWriter{,State} itself, as the Handler ending nils out the
+// responseWriter's state field.
+type http2stream struct {
+ // immutable:
+ sc *http2serverConn
+ id uint32
+ body *http2pipe // non-nil if expecting DATA frames
+ cw http2closeWaiter // closed once the stream transitions to the closed state
+ ctx context.Context
+ cancelCtx func()
+
+ // owned by serverConn's serve loop:
+ bodyBytes int64 // body bytes seen so far
+ declBodyBytes int64 // or -1 if undeclared
+ flow http2outflow // limits writing from Handler to client
+ inflow http2inflow // what the client is allowed to POST/etc to us
+ state http2streamState
+ resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
+ gotTrailerHeader bool // HEADER frame for trailers was seen
+ wroteHeaders bool // whether we wrote headers (not status 100)
+ readDeadline *time.Timer // nil if unused
+ writeDeadline *time.Timer // nil if unused
+ closeErr error // set before cw is closed
+
+ trailer Header // accumulated trailers
+ reqTrailer Header // handler's Request.Trailer
+}
+
+func (sc *http2serverConn) Framer() *http2Framer { return sc.framer }
+
+func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() }
+
+func (sc *http2serverConn) Flush() error { return sc.bw.Flush() }
+
+func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
+ return sc.hpackEncoder, &sc.headerWriteBuf
+}
+
+func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) {
+ sc.serveG.check()
+ // http://tools.ietf.org/html/rfc7540#section-5.1
+ if st, ok := sc.streams[streamID]; ok {
+ return st.state, st
+ }
+ // "The first use of a new stream identifier implicitly closes all
+ // streams in the "idle" state that might have been initiated by
+ // that peer with a lower-valued stream identifier. For example, if
+ // a client sends a HEADERS frame on stream 7 without ever sending a
+ // frame on stream 5, then stream 5 transitions to the "closed"
+ // state when the first frame for stream 7 is sent or received."
+ if streamID%2 == 1 {
+ if streamID <= sc.maxClientStreamID {
+ return http2stateClosed, nil
+ }
+ } else {
+ if streamID <= sc.maxPushPromiseID {
+ return http2stateClosed, nil
+ }
+ }
+ return http2stateIdle, nil
+}
+
+// setConnState calls the net/http ConnState hook for this connection, if configured.
+// Note that the net/http package does StateNew and StateClosed for us.
+// There is currently no plan for StateHijacked or hijacking HTTP/2 connections.
+func (sc *http2serverConn) setConnState(state ConnState) {
+ if sc.hs.ConnState != nil {
+ sc.hs.ConnState(sc.conn, state)
+ }
+}
+
+func (sc *http2serverConn) vlogf(format string, args ...interface{}) {
+ if http2VerboseLogs {
+ sc.logf(format, args...)
+ }
+}
+
+func (sc *http2serverConn) logf(format string, args ...interface{}) {
+ if lg := sc.hs.ErrorLog; lg != nil {
+ lg.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+// errno returns v's underlying uintptr, else 0.
+//
+// TODO: remove this helper function once http2 can use build
+// tags. See comment in isClosedConnError.
+func http2errno(v error) uintptr {
+ if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr {
+ return uintptr(rv.Uint())
+ }
+ return 0
+}
+
+// isClosedConnError reports whether err is an error from use of a closed
+// network connection.
+func http2isClosedConnError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ // TODO: remove this string search and be more like the Windows
+ // case below. That might involve modifying the standard library
+ // to return better error types.
+ str := err.Error()
+ if strings.Contains(str, "use of closed network connection") {
+ return true
+ }
+
+ // TODO(bradfitz): x/tools/cmd/bundle doesn't really support
+ // build tags, so I can't make an http2_windows.go file with
+ // Windows-specific stuff. Fix that and move this, once we
+ // have a way to bundle this into std's net/http somehow.
+ if runtime.GOOS == "windows" {
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
+ if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
+ const WSAECONNABORTED = 10053
+ const WSAECONNRESET = 10054
+ if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) {
+ if err == nil {
+ return
+ }
+ if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) || err == http2errPrefaceTimeout {
+ // Boring, expected errors.
+ sc.vlogf(format, args...)
+ } else {
+ sc.logf(format, args...)
+ }
+}
+
+// maxCachedCanonicalHeadersKeysSize is an arbitrarily-chosen limit on the size
+// of the entries in the canonHeader cache.
+// This should be larger than the size of unique, uncommon header keys likely to
+// be sent by the peer, while not so high as to permit unreasonable memory usage
+// if the peer sends an unbounded number of unique header keys.
+const http2maxCachedCanonicalHeadersKeysSize = 2048
+
+func (sc *http2serverConn) canonicalHeader(v string) string {
+ sc.serveG.check()
+ http2buildCommonHeaderMapsOnce()
+ cv, ok := http2commonCanonHeader[v]
+ if ok {
+ return cv
+ }
+ cv, ok = sc.canonHeader[v]
+ if ok {
+ return cv
+ }
+ if sc.canonHeader == nil {
+ sc.canonHeader = make(map[string]string)
+ }
+ cv = CanonicalHeaderKey(v)
+ size := 100 + len(v)*2 // 100 bytes of map overhead + key + value
+ if sc.canonHeaderKeysSize+size <= http2maxCachedCanonicalHeadersKeysSize {
+ sc.canonHeader[v] = cv
+ sc.canonHeaderKeysSize += size
+ }
+ return cv
+}
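+
+// Worked example (illustrative): caching an uncommon key such as
+// "x-request-id" charges 100 + 2*len("x-request-id") = 124 bytes against
+// the 2048-byte budget, so roughly 16 distinct uncommon keys are cached
+// per connection before falling back to CanonicalHeaderKey each time.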
+
+type http2readFrameResult struct {
+ f http2Frame // valid until readMore is called
+ err error
+
+ // readMore should be called once the consumer no longer needs or
+ // retains f. After readMore, f is invalid and more frames can be
+ // read.
+ readMore func()
+}
+
+// readFrames is the loop that reads incoming frames.
+// It takes care to only read one frame at a time, blocking until the
+// consumer is done with the frame.
+// It's run on its own goroutine.
+func (sc *http2serverConn) readFrames() {
+ gate := make(http2gate)
+ gateDone := gate.Done
+ for {
+ f, err := sc.framer.ReadFrame()
+ select {
+ case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}:
+ case <-sc.doneServing:
+ return
+ }
+ select {
+ case <-gate:
+ case <-sc.doneServing:
+ return
+ }
+ if http2terminalReadFrameError(err) {
+ return
+ }
+ }
+}
+
+// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
+type http2frameWriteResult struct {
+ _ http2incomparable
+ wr http2FrameWriteRequest // what was written (or attempted)
+ err error // result of the writeFrame call
+}
+
+// writeFrameAsync runs in its own goroutine and writes a single frame
+// and then reports when it's done.
+// If wd is non-nil, startFrameWrite has already encoded the DATA frame
+// into the framer's buffer on the serve goroutine, and only the
+// (possibly blocking) write to the connection remains.
+// At most one goroutine can be running writeFrameAsync at a time per
+// serverConn.
+func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest, wd *http2writeData) {
+ var err error
+ if wd == nil {
+ err = wr.write.writeFrame(sc)
+ } else {
+ err = sc.framer.endWrite()
+ }
+ sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err}
+}
+
+func (sc *http2serverConn) closeAllStreamsOnConnClose() {
+ sc.serveG.check()
+ for _, st := range sc.streams {
+ sc.closeStream(st, http2errClientDisconnected)
+ }
+}
+
+func (sc *http2serverConn) stopShutdownTimer() {
+ sc.serveG.check()
+ if t := sc.shutdownTimer; t != nil {
+ t.Stop()
+ }
+}
+
+func (sc *http2serverConn) notePanic() {
+ // Note: this is for serverConn.serve panicking, not http.Handler code.
+ if http2testHookOnPanicMu != nil {
+ http2testHookOnPanicMu.Lock()
+ defer http2testHookOnPanicMu.Unlock()
+ }
+ if http2testHookOnPanic != nil {
+ if e := recover(); e != nil {
+ if http2testHookOnPanic(sc, e) {
+ panic(e)
+ }
+ }
+ }
+}
+
+func (sc *http2serverConn) serve() {
+ sc.serveG.check()
+ defer sc.notePanic()
+ defer sc.conn.Close()
+ defer sc.closeAllStreamsOnConnClose()
+ defer sc.stopShutdownTimer()
+ defer close(sc.doneServing) // unblocks handlers trying to send
+
+ if http2VerboseLogs {
+ sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+ }
+
+ sc.writeFrame(http2FrameWriteRequest{
+ write: http2writeSettings{
+ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
+ {http2SettingMaxConcurrentStreams, sc.advMaxStreams},
+ {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()},
+ {http2SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
+ {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
+ },
+ })
+ sc.unackedSettings++
+
+ // Each connection starts with initialWindowSize inflow tokens.
+ // If a higher value is configured, we add more tokens.
+ if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 {
+ sc.sendWindowUpdate(nil, int(diff))
+ }
+
+ if err := sc.readPreface(); err != nil {
+ sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
+ return
+ }
+ // Now that we've got the preface, get us out of the
+ // "StateNew" state. We can't go directly to idle, though.
+ // Active means we read some data and anticipate a request. We'll
+ // do another Active when we get a HEADERS frame.
+ sc.setConnState(StateActive)
+ sc.setConnState(StateIdle)
+
+ if sc.srv.IdleTimeout != 0 {
+ sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
+ defer sc.idleTimer.Stop()
+ }
+
+ go sc.readFrames() // closed by defer sc.conn.Close above
+
+ settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer)
+ defer settingsTimer.Stop()
+
+ loopNum := 0
+ for {
+ loopNum++
+ select {
+ case wr := <-sc.wantWriteFrameCh:
+ if se, ok := wr.write.(http2StreamError); ok {
+ sc.resetStream(se)
+ break
+ }
+ sc.writeFrame(wr)
+ case res := <-sc.wroteFrameCh:
+ sc.wroteFrame(res)
+ case res := <-sc.readFrameCh:
+ // Process any written frames before reading new frames from the client since a
+ // written frame could have triggered a new stream to be started.
+ if sc.writingFrameAsync {
+ select {
+ case wroteRes := <-sc.wroteFrameCh:
+ sc.wroteFrame(wroteRes)
+ default:
+ }
+ }
+ if !sc.processFrameFromReader(res) {
+ return
+ }
+ res.readMore()
+ if settingsTimer != nil {
+ settingsTimer.Stop()
+ settingsTimer = nil
+ }
+ case m := <-sc.bodyReadCh:
+ sc.noteBodyRead(m.st, m.n)
+ case msg := <-sc.serveMsgCh:
+ switch v := msg.(type) {
+ case func(int):
+ v(loopNum) // for testing
+ case *http2serverMessage:
+ switch v {
+ case http2settingsTimerMsg:
+ sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
+ return
+ case http2idleTimerMsg:
+ sc.vlogf("connection is idle")
+ sc.goAway(http2ErrCodeNo)
+ case http2shutdownTimerMsg:
+ sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
+ return
+ case http2gracefulShutdownMsg:
+ sc.startGracefulShutdownInternal()
+ case http2handlerDoneMsg:
+ sc.handlerDone()
+ default:
+ panic("unknown timer")
+ }
+ case *http2startPushRequest:
+ sc.startPush(v)
+ case func(*http2serverConn):
+ v(sc)
+ default:
+ panic(fmt.Sprintf("unexpected type %T", v))
+ }
+ }
+
+ // If the peer is causing us to generate a lot of control frames,
+ // but not reading them from us, assume they are trying to make us
+ // run out of memory.
+ if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
+ sc.vlogf("http2: too many control frames in send queue, closing connection")
+ return
+ }
+
+ // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY
+ // with no error code (graceful shutdown), don't start the timer until
+ // all open streams have been completed.
+ sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
+ gracefulShutdownComplete := sc.goAwayCode == http2ErrCodeNo && sc.curOpenStreams() == 0
+ if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != http2ErrCodeNo || gracefulShutdownComplete) {
+ sc.shutDownIn(http2goAwayTimeout)
+ }
+ }
+}
+
+func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, privateCh chan struct{}) {
+ select {
+ case <-sc.doneServing:
+ case <-sharedCh:
+ close(privateCh)
+ }
+}
+
+type http2serverMessage int
+
+// Message values sent to serveMsgCh.
+var (
+ http2settingsTimerMsg = new(http2serverMessage)
+ http2idleTimerMsg = new(http2serverMessage)
+ http2shutdownTimerMsg = new(http2serverMessage)
+ http2gracefulShutdownMsg = new(http2serverMessage)
+ http2handlerDoneMsg = new(http2serverMessage)
+)
+
+func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) }
+
+func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) }
+
+func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) }
+
+func (sc *http2serverConn) sendServeMsg(msg interface{}) {
+ sc.serveG.checkNotOn() // NOT
+ select {
+ case sc.serveMsgCh <- msg:
+ case <-sc.doneServing:
+ }
+}
+
+var http2errPrefaceTimeout = errors.New("timeout waiting for client preface")
+
+// readPreface reads the ClientPreface greeting from the peer or
+// returns errPrefaceTimeout on timeout, or an error if the greeting
+// is invalid.
+func (sc *http2serverConn) readPreface() error {
+ if sc.sawClientPreface {
+ return nil
+ }
+ errc := make(chan error, 1)
+ go func() {
+ // Read the client preface
+ buf := make([]byte, len(http2ClientPreface))
+ if _, err := io.ReadFull(sc.conn, buf); err != nil {
+ errc <- err
+ } else if !bytes.Equal(buf, http2clientPreface) {
+ errc <- fmt.Errorf("bogus greeting %q", buf)
+ } else {
+ errc <- nil
+ }
+ }()
+ timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server?
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ return http2errPrefaceTimeout
+ case err := <-errc:
+ if err == nil {
+ if http2VerboseLogs {
+ sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
+ }
+ }
+ return err
+ }
+}
+
+var http2errChanPool = sync.Pool{
+ New: func() interface{} { return make(chan error, 1) },
+}
+
+var http2writeDataPool = sync.Pool{
+ New: func() interface{} { return new(http2writeData) },
+}
+
+// writeDataFromHandler writes DATA response frames from a handler on
+// the given stream.
+func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error {
+ ch := http2errChanPool.Get().(chan error)
+ writeArg := http2writeDataPool.Get().(*http2writeData)
+ *writeArg = http2writeData{stream.id, data, endStream}
+ err := sc.writeFrameFromHandler(http2FrameWriteRequest{
+ write: writeArg,
+ stream: stream,
+ done: ch,
+ })
+ if err != nil {
+ return err
+ }
+ var frameWriteDone bool // the frame write is done (successfully or not)
+ select {
+ case err = <-ch:
+ frameWriteDone = true
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-stream.cw:
+ // If both ch and stream.cw were ready (as might
+ // happen on the final Write after an http.Handler
+ // ends), prefer the write result. Otherwise this
+ // might just be us successfully closing the stream.
+ // The writeFrameAsync and serve goroutines guarantee
+ // that the ch send will happen before the stream.cw
+ // close.
+ select {
+ case err = <-ch:
+ frameWriteDone = true
+ default:
+ return http2errStreamClosed
+ }
+ }
+ http2errChanPool.Put(ch)
+ if frameWriteDone {
+ http2writeDataPool.Put(writeArg)
+ }
+ return err
+}
+
+// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts
+// if the connection has gone away.
+//
+// This must not be run from the serve goroutine itself, else it might
+// deadlock writing to sc.wantWriteFrameCh (which is only mildly
+// buffered and is read by serve itself). If you're on the serve
+// goroutine, call writeFrame instead.
+func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error {
+ sc.serveG.checkNotOn() // NOT
+ select {
+ case sc.wantWriteFrameCh <- wr:
+ return nil
+ case <-sc.doneServing:
+ // Serve loop is gone.
+ // Client has closed their connection to the server.
+ return http2errClientDisconnected
+ }
+}
+
+// writeFrame schedules a frame to write and sends it if there's nothing
+// already being written.
+//
+// There is no pushback here (the serve goroutine never blocks). It's
+// the http.Handlers that block, waiting for their previous frames to
+// make it onto the wire.
+//
+// If you're not on the serve goroutine, use writeFrameFromHandler instead.
+func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) {
+ sc.serveG.check()
+
+ // If true, wr will not be written and wr.done will not be signaled.
+ var ignoreWrite bool
+
+ // We are not allowed to write frames on closed streams. RFC 7540 Section
+ // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on
+ // a closed stream." Our server never sends PRIORITY, so that exception
+ // does not apply.
+ //
+ // The serverConn might close an open stream while the stream's handler
+ // is still running. For example, the server might close a stream when it
+ // receives bad data from the client. If this happens, the handler might
+ // attempt to write a frame after the stream has been closed (since the
+ // handler hasn't yet been notified of the close). In this case, we simply
+ // ignore the frame. The handler will notice that the stream is closed when
+ // it waits for the frame to be written.
+ //
+ // As an exception to this rule, we allow sending RST_STREAM after close.
+ // This allows us to immediately reject new streams without tracking any
+ // state for those streams (except for the queued RST_STREAM frame). This
+ // may result in duplicate RST_STREAMs in some cases, but the client should
+ // ignore those.
+ if wr.StreamID() != 0 {
+ _, isReset := wr.write.(http2StreamError)
+ if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset {
+ ignoreWrite = true
+ }
+ }
+
+ // Don't send a 100-continue response if we've already sent headers.
+ // See golang.org/issue/14030.
+ switch wr.write.(type) {
+ case *http2writeResHeaders:
+ wr.stream.wroteHeaders = true
+ case http2write100ContinueHeadersFrame:
+ if wr.stream.wroteHeaders {
+ // We do not need to notify wr.done because this frame is
+ // never written with wr.done != nil.
+ if wr.done != nil {
+ panic("wr.done != nil for write100ContinueHeadersFrame")
+ }
+ ignoreWrite = true
+ }
+ }
+
+ if !ignoreWrite {
+ if wr.isControl() {
+ sc.queuedControlFrames++
+ // For extra safety, detect wraparounds, which should not happen,
+ // and pull the plug.
+ if sc.queuedControlFrames < 0 {
+ sc.conn.Close()
+ }
+ }
+ sc.writeSched.Push(wr)
+ }
+ sc.scheduleFrameWrite()
+}
+
+// startFrameWrite starts writing wr, on a separate goroutine when the
+// write might block on the network, and updates the serve goroutine's
+// state about the world from the info in wr.
+func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) {
+ sc.serveG.check()
+ if sc.writingFrame {
+ panic("internal error: can only be writing one frame at a time")
+ }
+
+ st := wr.stream
+ if st != nil {
+ switch st.state {
+ case http2stateHalfClosedLocal:
+ switch wr.write.(type) {
+ case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate:
+ // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE
+ // in this state. (We never send PRIORITY from the server, so that is not checked.)
+ default:
+ panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
+ }
+ case http2stateClosed:
+ panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
+ }
+ }
+ if wpp, ok := wr.write.(*http2writePushPromise); ok {
+ var err error
+ wpp.promisedID, err = wpp.allocatePromisedID()
+ if err != nil {
+ sc.writingFrameAsync = false
+ wr.replyToWriter(err)
+ return
+ }
+ }
+
+ sc.writingFrame = true
+ sc.needsFrameFlush = true
+ if wr.write.staysWithinBuffer(sc.bw.Available()) {
+ sc.writingFrameAsync = false
+ err := wr.write.writeFrame(sc)
+ sc.wroteFrame(http2frameWriteResult{wr: wr, err: err})
+ } else if wd, ok := wr.write.(*http2writeData); ok {
+ // Encode the frame in the serve goroutine, to ensure we don't have
+ // any lingering asynchronous references to data passed to Write.
+ // See https://go.dev/issue/58446.
+ sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil)
+ sc.writingFrameAsync = true
+ go sc.writeFrameAsync(wr, wd)
+ } else {
+ sc.writingFrameAsync = true
+ go sc.writeFrameAsync(wr, nil)
+ }
+}
+
+// errHandlerPanicked is the error given to any callers blocked in a read from
+// Request.Body when the main goroutine panics. Since most handlers read in the
+// main ServeHTTP goroutine, this will show up rarely.
+var http2errHandlerPanicked = errors.New("http2: handler panicked")
+
+// wroteFrame is called on the serve goroutine with the result of
+// whatever happened on writeFrameAsync.
+func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
+ sc.serveG.check()
+ if !sc.writingFrame {
+ panic("internal error: expected to be already writing a frame")
+ }
+ sc.writingFrame = false
+ sc.writingFrameAsync = false
+
+ wr := res.wr
+
+ if http2writeEndsStream(wr.write) {
+ st := wr.stream
+ if st == nil {
+ panic("internal error: expecting non-nil stream")
+ }
+ switch st.state {
+ case http2stateOpen:
+ // In theory we would only move to stateHalfClosedLocal
+ // here, but since our handler is done and the net/http
+ // package provides no mechanism for closing a
+ // ResponseWriter while still reading data (see possible
+ // TODO at top of this file), we tear the stream down
+ // anyway: we tell the peer we're hanging up on them by
+ // queueing a RST_STREAM below, and we'll transition to
+ // stateClosed once that frame is written.
+ st.state = http2stateHalfClosedLocal
+ // Section 8.1: a server MAY request that the client abort
+ // transmission of a request without error by sending a
+ // RST_STREAM with an error code of NO_ERROR after sending
+ // a complete response.
+ sc.resetStream(http2streamError(st.id, http2ErrCodeNo))
+ case http2stateHalfClosedRemote:
+ sc.closeStream(st, http2errHandlerComplete)
+ }
+ } else {
+ switch v := wr.write.(type) {
+ case http2StreamError:
+ // st may be unknown if the RST_STREAM was generated to reject bad input.
+ if st, ok := sc.streams[v.StreamID]; ok {
+ sc.closeStream(st, v)
+ }
+ case http2handlerPanicRST:
+ sc.closeStream(wr.stream, http2errHandlerPanicked)
+ }
+ }
+
+ // Reply (if requested) to unblock the ServeHTTP goroutine.
+ wr.replyToWriter(res.err)
+
+ sc.scheduleFrameWrite()
+}
+
+// scheduleFrameWrite tickles the frame writing scheduler.
+//
+// If a frame is already being written, nothing happens. This will be called again
+// when the frame is done being written.
+//
+// If a frame isn't being written and we need to send one, the best frame
+// to send is selected by writeSched.
+//
+// If a frame isn't being written and there's nothing else to send, we
+// flush the write buffer.
+func (sc *http2serverConn) scheduleFrameWrite() {
+ sc.serveG.check()
+ if sc.writingFrame || sc.inFrameScheduleLoop {
+ return
+ }
+ sc.inFrameScheduleLoop = true
+ for !sc.writingFrameAsync {
+ if sc.needToSendGoAway {
+ sc.needToSendGoAway = false
+ sc.startFrameWrite(http2FrameWriteRequest{
+ write: &http2writeGoAway{
+ maxStreamID: sc.maxClientStreamID,
+ code: sc.goAwayCode,
+ },
+ })
+ continue
+ }
+ if sc.needToSendSettingsAck {
+ sc.needToSendSettingsAck = false
+ sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
+ continue
+ }
+ if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
+ if wr, ok := sc.writeSched.Pop(); ok {
+ if wr.isControl() {
+ sc.queuedControlFrames--
+ }
+ sc.startFrameWrite(wr)
+ continue
+ }
+ }
+ if sc.needsFrameFlush {
+ sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}})
+ sc.needsFrameFlush = false // after startFrameWrite, since it sets this true
+ continue
+ }
+ break
+ }
+ sc.inFrameScheduleLoop = false
+}
+
+// startGracefulShutdown gracefully shuts down a connection. This
+// sends GOAWAY with ErrCodeNo to tell the client we're gracefully
+// shutting down. The connection isn't closed until all current
+// streams are done.
+//
+// startGracefulShutdown returns immediately; it does not wait until
+// the connection has shut down.
+func (sc *http2serverConn) startGracefulShutdown() {
+ sc.serveG.checkNotOn() // NOT
+ sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) })
+}
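+
+// In practice this path is usually driven by (*Server).Shutdown on the
+// enclosing http.Server, which runs the HTTP/2 shutdown hook registered
+// when the server was configured. A minimal sketch from the caller's
+// side (srv is a hypothetical *http.Server):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	if err := srv.Shutdown(ctx); err != nil {
+//		log.Printf("shutdown: %v", err) // e.g. ctx expired before active streams finished
+//	}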
+
+// After sending GOAWAY with an error code (non-graceful shutdown), the
+// connection will close after goAwayTimeout.
+//
+// If we close the connection immediately after sending GOAWAY, there may
+// be unsent data in our kernel receive buffer, which will cause the kernel
+// to send a TCP RST on close() instead of a FIN. This RST will abort the
+// connection immediately, whether or not the client had received the GOAWAY.
+//
+// Ideally we should delay for at least 1 RTT + epsilon so the client has
+// a chance to read the GOAWAY and stop sending messages. Measuring RTT
+// is hard, so we approximate with 1 second. See golang.org/issue/18701.
+//
+// This is a var so it can be shorter in tests, where all requests use the
+// loopback interface, making the expected RTT very small.
+//
+// TODO: configurable?
+var http2goAwayTimeout = 1 * time.Second
+
+func (sc *http2serverConn) startGracefulShutdownInternal() {
+ sc.goAway(http2ErrCodeNo)
+}
+
+func (sc *http2serverConn) goAway(code http2ErrCode) {
+ sc.serveG.check()
+ if sc.inGoAway {
+ if sc.goAwayCode == http2ErrCodeNo {
+ sc.goAwayCode = code
+ }
+ return
+ }
+ sc.inGoAway = true
+ sc.needToSendGoAway = true
+ sc.goAwayCode = code
+ sc.scheduleFrameWrite()
+}
+
+func (sc *http2serverConn) shutDownIn(d time.Duration) {
+ sc.serveG.check()
+ sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
+}
+
+func (sc *http2serverConn) resetStream(se http2StreamError) {
+ sc.serveG.check()
+ sc.writeFrame(http2FrameWriteRequest{write: se})
+ if st, ok := sc.streams[se.StreamID]; ok {
+ st.resetQueued = true
+ }
+}
+
+// processFrameFromReader processes the serve loop's read from readFrameCh from the
+// frame-reading goroutine.
+// processFrameFromReader returns whether the connection should be kept open.
+func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool {
+ sc.serveG.check()
+ err := res.err
+ if err != nil {
+ if err == http2ErrFrameTooLarge {
+ sc.goAway(http2ErrCodeFrameSize)
+ return true // goAway will close the loop
+ }
+ clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err)
+ if clientGone {
+ // TODO: could we also get into this state if
+ // the peer does a half close
+ // (e.g. CloseWrite) because they're done
+ // sending frames but they're still wanting
+ // our open replies? Investigate.
+ // TODO: add CloseWrite to crypto/tls.Conn first
+ // so we have a way to test this? I suppose
+ // just for testing we could have a non-TLS mode.
+ return false
+ }
+ } else {
+ f := res.f
+ if http2VerboseLogs {
+ sc.vlogf("http2: server read frame %v", http2summarizeFrame(f))
+ }
+ err = sc.processFrame(f)
+ if err == nil {
+ return true
+ }
+ }
+
+ switch ev := err.(type) {
+ case http2StreamError:
+ sc.resetStream(ev)
+ return true
+ case http2goAwayFlowError:
+ sc.goAway(http2ErrCodeFlowControl)
+ return true
+ case http2ConnectionError:
+ sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
+ sc.goAway(http2ErrCode(ev))
+ return true // goAway will handle shutdown
+ default:
+ if res.err != nil {
+ sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
+ } else {
+ sc.logf("http2: server closing client connection: %v", err)
+ }
+ return false
+ }
+}
+
+func (sc *http2serverConn) processFrame(f http2Frame) error {
+ sc.serveG.check()
+
+ // First frame received must be SETTINGS.
+ if !sc.sawFirstSettings {
+ if _, ok := f.(*http2SettingsFrame); !ok {
+ return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ sc.sawFirstSettings = true
+ }
+
+ // Discard frames for streams initiated after the identified last
+ // stream sent in a GOAWAY, or all frames after sending an error.
+ // We still need to return connection-level flow control for DATA frames.
+ // RFC 9113 Section 6.8.
+ if sc.inGoAway && (sc.goAwayCode != http2ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
+
+ if f, ok := f.(*http2DataFrame); ok {
+ if !sc.inflow.take(f.Length) {
+ return sc.countError("data_flow", http2streamError(f.Header().StreamID, http2ErrCodeFlowControl))
+ }
+ sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
+ }
+ return nil
+ }
+
+ switch f := f.(type) {
+ case *http2SettingsFrame:
+ return sc.processSettings(f)
+ case *http2MetaHeadersFrame:
+ return sc.processHeaders(f)
+ case *http2WindowUpdateFrame:
+ return sc.processWindowUpdate(f)
+ case *http2PingFrame:
+ return sc.processPing(f)
+ case *http2DataFrame:
+ return sc.processData(f)
+ case *http2RSTStreamFrame:
+ return sc.processResetStream(f)
+ case *http2PriorityFrame:
+ return sc.processPriority(f)
+ case *http2GoAwayFrame:
+ return sc.processGoAway(f)
+ case *http2PushPromiseFrame:
+ // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE
+ // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
+ return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol))
+ default:
+ sc.vlogf("http2: server ignoring frame: %v", f.Header())
+ return nil
+ }
+}
+
+func (sc *http2serverConn) processPing(f *http2PingFrame) error {
+ sc.serveG.check()
+ if f.IsAck() {
+ // 6.7 PING: " An endpoint MUST NOT respond to PING frames
+ // containing this flag."
+ return nil
+ }
+ if f.StreamID != 0 {
+ // "PING frames are not associated with any individual
+ // stream. If a PING frame is received with a stream
+ // identifier field value other than 0x0, the recipient MUST
+ // respond with a connection error (Section 5.4.1) of type
+ // PROTOCOL_ERROR."
+ return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}})
+ return nil
+}
+
+func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error {
+ sc.serveG.check()
+ switch {
+ case f.StreamID != 0: // stream-level flow control
+ state, st := sc.state(f.StreamID)
+ if state == http2stateIdle {
+ // Section 5.1: "Receiving any frame other than HEADERS
+ // or PRIORITY on a stream in this state MUST be
+ // treated as a connection error (Section 5.4.1) of
+ // type PROTOCOL_ERROR."
+ return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ if st == nil {
+ // "WINDOW_UPDATE can be sent by a peer that has sent a
+ // frame bearing the END_STREAM flag. This means that a
+ // receiver could receive a WINDOW_UPDATE frame on a "half
+ // closed (remote)" or "closed" stream. A receiver MUST
+ // NOT treat this as an error, see Section 5.1."
+ return nil
+ }
+ if !st.flow.add(int32(f.Increment)) {
+ return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl))
+ }
+ default: // connection-level flow control
+ if !sc.flow.add(int32(f.Increment)) {
+ return http2goAwayFlowError{}
+ }
+ }
+ sc.scheduleFrameWrite()
+ return nil
+}
+
+func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
+ sc.serveG.check()
+
+ state, st := sc.state(f.StreamID)
+ if state == http2stateIdle {
+ // 6.4 "RST_STREAM frames MUST NOT be sent for a
+ // stream in the "idle" state. If a RST_STREAM frame
+ // identifying an idle stream is received, the
+ // recipient MUST treat this as a connection error
+ // (Section 5.4.1) of type PROTOCOL_ERROR."
+ return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ if st != nil {
+ st.cancelCtx()
+ sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode))
+ }
+ return nil
+}
+
+func (sc *http2serverConn) closeStream(st *http2stream, err error) {
+ sc.serveG.check()
+ if st.state == http2stateIdle || st.state == http2stateClosed {
+ panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
+ }
+ st.state = http2stateClosed
+ if st.readDeadline != nil {
+ st.readDeadline.Stop()
+ }
+ if st.writeDeadline != nil {
+ st.writeDeadline.Stop()
+ }
+ if st.isPushed() {
+ sc.curPushedStreams--
+ } else {
+ sc.curClientStreams--
+ }
+ delete(sc.streams, st.id)
+ if len(sc.streams) == 0 {
+ sc.setConnState(StateIdle)
+ if sc.srv.IdleTimeout != 0 {
+ sc.idleTimer.Reset(sc.srv.IdleTimeout)
+ }
+ if http2h1ServerKeepAlivesDisabled(sc.hs) {
+ sc.startGracefulShutdownInternal()
+ }
+ }
+ if p := st.body; p != nil {
+ // Return any buffered unread bytes worth of conn-level flow control.
+ // See golang.org/issue/16481
+ sc.sendWindowUpdate(nil, p.Len())
+
+ p.CloseWithError(err)
+ }
+ if e, ok := err.(http2StreamError); ok {
+ if e.Cause != nil {
+ err = e.Cause
+ } else {
+ err = http2errStreamClosed
+ }
+ }
+ st.closeErr = err
+ st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
+ sc.writeSched.CloseStream(st.id)
+}
+
+func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
+ sc.serveG.check()
+ if f.IsAck() {
+ sc.unackedSettings--
+ if sc.unackedSettings < 0 {
+ // Why is the peer ACKing settings we never sent?
+ // The spec doesn't mention this case, but
+ // hang up on them anyway.
+ return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ return nil
+ }
+ if f.NumSettings() > 100 || f.HasDuplicates() {
+ // This isn't actually in the spec, but hang up on
+ // suspiciously large settings frames or those with
+ // duplicate entries.
+ return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ if err := f.ForeachSetting(sc.processSetting); err != nil {
+ return err
+ }
+ // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be
+ // acknowledged individually, even if multiple are received before the ACK.
+ sc.needToSendSettingsAck = true
+ sc.scheduleFrameWrite()
+ return nil
+}
+
+func (sc *http2serverConn) processSetting(s http2Setting) error {
+ sc.serveG.check()
+ if err := s.Valid(); err != nil {
+ return err
+ }
+ if http2VerboseLogs {
+ sc.vlogf("http2: server processing setting %v", s)
+ }
+ switch s.ID {
+ case http2SettingHeaderTableSize:
+ sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
+ case http2SettingEnablePush:
+ sc.pushEnabled = s.Val != 0
+ case http2SettingMaxConcurrentStreams:
+ sc.clientMaxStreams = s.Val
+ case http2SettingInitialWindowSize:
+ return sc.processSettingInitialWindowSize(s.Val)
+ case http2SettingMaxFrameSize:
+ sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
+ case http2SettingMaxHeaderListSize:
+ sc.peerMaxHeaderListSize = s.Val
+ default:
+ // Unknown setting: "An endpoint that receives a SETTINGS
+ // frame with any unknown or unsupported identifier MUST
+ // ignore that setting."
+ if http2VerboseLogs {
+ sc.vlogf("http2: server ignoring unknown setting %v", s)
+ }
+ }
+ return nil
+}
+
+func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
+ sc.serveG.check()
+ // Note: val already validated to be within range by
+ // processSetting's Valid call.
+
+ // "A SETTINGS frame can alter the initial flow control window
+ // size for all current streams. When the value of
+ // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST
+ // adjust the size of all stream flow control windows that it
+ // maintains by the difference between the new value and the
+ // old value."
+ old := sc.initialStreamSendWindowSize
+ sc.initialStreamSendWindowSize = int32(val)
+ growth := int32(val) - old // may be negative
+ for _, st := range sc.streams {
+ if !st.flow.add(growth) {
+ // 6.9.2 Initial Flow Control Window Size
+ // "An endpoint MUST treat a change to
+ // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow
+ // control window to exceed the maximum size as a
+ // connection error (Section 5.4.1) of type
+ // FLOW_CONTROL_ERROR."
+ return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl))
+ }
+ }
+ return nil
+}
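+
+// A quick sketch of the arithmetic above (the values are illustrative;
+// 65535 is the protocol's default initial window size):
+//
+//	old = 65535, val = 131070  =>  growth = +65535 added to each stream's send window
+//	old = 131070, val = 1024   =>  growth = -130046; windows may go negative,
+//	                               which RFC 7540 Section 6.9.2 permits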
+
+func (sc *http2serverConn) processData(f *http2DataFrame) error {
+ sc.serveG.check()
+ id := f.Header().StreamID
+
+ data := f.Data()
+ state, st := sc.state(id)
+ if id == 0 || state == http2stateIdle {
+ // Section 6.1: "DATA frames MUST be associated with a
+ // stream. If a DATA frame is received whose stream
+ // identifier field is 0x0, the recipient MUST respond
+ // with a connection error (Section 5.4.1) of type
+ // PROTOCOL_ERROR."
+ //
+ // Section 5.1: "Receiving any frame other than HEADERS
+ // or PRIORITY on a stream in this state MUST be
+ // treated as a connection error (Section 5.4.1) of
+ // type PROTOCOL_ERROR."
+ return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol))
+ }
+
+ // "If a DATA frame is received whose stream is not in "open"
+ // or "half closed (local)" state, the recipient MUST respond
+ // with a stream error (Section 5.4.2) of type STREAM_CLOSED."
+ if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued {
+ // This includes sending a RST_STREAM if the stream is
+ // in stateHalfClosedLocal (which currently means that
+ // the http.Handler returned, so it's done reading &
+ // done writing). Try to stop the client from sending
+ // more DATA.
+
+ // But still enforce their connection-level flow control,
+ // and return any flow control bytes since we're not going
+ // to consume them.
+ if !sc.inflow.take(f.Length) {
+ return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl))
+ }
+ sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
+
+ if st != nil && st.resetQueued {
+ // Already have a stream error in flight. Don't send another.
+ return nil
+ }
+ return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed))
+ }
+ if st.body == nil {
+ panic("internal error: should have a body in this state")
+ }
+
+ // Sender sending more than they'd declared?
+ if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
+ if !sc.inflow.take(f.Length) {
+ return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl))
+ }
+ sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
+
+ st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
+ // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
+ // value of a content-length header field does not equal the sum of the
+ // DATA frame payload lengths that form the body.
+ return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol))
+ }
+ if f.Length > 0 {
+ // Check whether the client has flow control quota.
+ if !http2takeInflows(&sc.inflow, &st.inflow, f.Length) {
+ return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl))
+ }
+
+ if len(data) > 0 {
+ st.bodyBytes += int64(len(data))
+ wrote, err := st.body.Write(data)
+ if err != nil {
+ // The handler has closed the request body.
+ // Return the connection-level flow control for the discarded data,
+ // but not the stream-level flow control.
+ sc.sendWindowUpdate(nil, int(f.Length)-wrote)
+ return nil
+ }
+ if wrote != len(data) {
+ panic("internal error: bad Writer")
+ }
+ }
+
+ // Return any padded flow control now, since we won't
+ // refund it later on body reads.
+ // Call sendWindowUpdate even if there is no padding,
+ // to return buffered flow control credit if the sent
+ // window has shrunk.
+ pad := int32(f.Length) - int32(len(data))
+ sc.sendWindowUpdate32(nil, pad)
+ sc.sendWindowUpdate32(st, pad)
+ }
+ if f.StreamEnded() {
+ st.endStream()
+ }
+ return nil
+}
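+
+// A sketch of the padding refund above: for a DATA frame whose Length is
+// 1029 and whose payload is 1024 bytes, pad is 5 (the Pad Length octet
+// plus four bytes of padding), and those 5 bytes are credited back to
+// both the connection-level and stream-level receive windows immediately.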
+
+func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error {
+ sc.serveG.check()
+ if f.ErrCode != http2ErrCodeNo {
+ sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
+ } else {
+ sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
+ }
+ sc.startGracefulShutdownInternal()
+ // http://tools.ietf.org/html/rfc7540#section-6.8
+ // We should not create any new streams, which means we should disable push.
+ sc.pushEnabled = false
+ return nil
+}
+
+// isPushed reports whether the stream is server-initiated.
+func (st *http2stream) isPushed() bool {
+ return st.id%2 == 0
+}
+
+// endStream closes a Request.Body's pipe. It is called when a DATA
+// frame says a request body is over (or after trailers).
+func (st *http2stream) endStream() {
+ sc := st.sc
+ sc.serveG.check()
+
+ if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+ st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
+ st.declBodyBytes, st.bodyBytes))
+ } else {
+ st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
+ st.body.CloseWithError(io.EOF)
+ }
+ st.state = http2stateHalfClosedRemote
+}
+
+// copyTrailersToHandlerRequest is run in the Handler's goroutine in
+// its Request.Body.Read just before it gets io.EOF.
+func (st *http2stream) copyTrailersToHandlerRequest() {
+ for k, vv := range st.trailer {
+ if _, ok := st.reqTrailer[k]; ok {
+ // Only copy it over if it was pre-declared.
+ st.reqTrailer[k] = vv
+ }
+ }
+}
+
+// onReadTimeout is run on its own goroutine (from time.AfterFunc)
+// when the stream's ReadTimeout has fired.
+func (st *http2stream) onReadTimeout() {
+ // Wrap the ErrDeadlineExceeded to avoid callers depending on us
+ // returning the bare error.
+ st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
+}
+
+// onWriteTimeout is run on its own goroutine (from time.AfterFunc)
+// when the stream's WriteTimeout has fired.
+func (st *http2stream) onWriteTimeout() {
+ st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2StreamError{
+ StreamID: st.id,
+ Code: http2ErrCodeInternal,
+ Cause: os.ErrDeadlineExceeded,
+ }})
+}
+
+func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
+ sc.serveG.check()
+ id := f.StreamID
+ // http://tools.ietf.org/html/rfc7540#section-5.1.1
+ // Streams initiated by a client MUST use odd-numbered stream
+ // identifiers. [...] An endpoint that receives an unexpected
+ // stream identifier MUST respond with a connection error
+ // (Section 5.4.1) of type PROTOCOL_ERROR.
+ if id%2 != 1 {
+ return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ // A HEADERS frame can be used to create a new stream or
+ // send a trailer for an open one. If we already have a stream
+ // open, let it process its own HEADERS frame (trailers at this
+ // point, if it's valid).
+ if st := sc.streams[f.StreamID]; st != nil {
+ if st.resetQueued {
+ // We're sending RST_STREAM to close the stream, so don't bother
+ // processing this frame.
+ return nil
+ }
+ // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than
+ // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in
+ // this state, it MUST respond with a stream error (Section 5.4.2) of
+ // type STREAM_CLOSED.
+ if st.state == http2stateHalfClosedRemote {
+ return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed))
+ }
+ return st.processTrailerHeaders(f)
+ }
+
+ // [...] The identifier of a newly established stream MUST be
+ // numerically greater than all streams that the initiating
+ // endpoint has opened or reserved. [...] An endpoint that
+ // receives an unexpected stream identifier MUST respond with
+ // a connection error (Section 5.4.1) of type PROTOCOL_ERROR.
+ if id <= sc.maxClientStreamID {
+ return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ sc.maxClientStreamID = id
+
+ if sc.idleTimer != nil {
+ sc.idleTimer.Stop()
+ }
+
+ // http://tools.ietf.org/html/rfc7540#section-5.1.2
+ // [...] Endpoints MUST NOT exceed the limit set by their peer. An
+ // endpoint that receives a HEADERS frame that causes their
+ // advertised concurrent stream limit to be exceeded MUST treat
+ // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR
+ // or REFUSED_STREAM.
+ if sc.curClientStreams+1 > sc.advMaxStreams {
+ if sc.unackedSettings == 0 {
+ // They should know better.
+ return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol))
+ }
+ // Assume it's a network race, where they just haven't
+ // received our last SETTINGS update. But actually
+ // this can't happen yet, because we don't yet provide
+ // a way for users to adjust server parameters at
+ // runtime.
+ return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream))
+ }
+
+ initialState := http2stateOpen
+ if f.StreamEnded() {
+ initialState = http2stateHalfClosedRemote
+ }
+ st := sc.newStream(id, 0, initialState)
+
+ if f.HasPriority() {
+ if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
+ return err
+ }
+ sc.writeSched.AdjustStream(st.id, f.Priority)
+ }
+
+ rw, req, err := sc.newWriterAndRequest(st, f)
+ if err != nil {
+ return err
+ }
+ st.reqTrailer = req.Trailer
+ if st.reqTrailer != nil {
+ st.trailer = make(Header)
+ }
+ st.body = req.Body.(*http2requestBody).pipe // may be nil
+ st.declBodyBytes = req.ContentLength
+
+ handler := sc.handler.ServeHTTP
+ if f.Truncated {
+ // Their header list was too long. Send a 431 error.
+ handler = http2handleHeaderListTooLong
+ } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil {
+ handler = http2new400Handler(err)
+ }
+
+ // The net/http package sets the read deadline from the
+ // http.Server.ReadTimeout during the TLS handshake, but then
+ // passes the connection off to us with the deadline already
+ // set. Disarm it here after the request headers are read,
+ // similar to how the http1 server works. Here it's
+ // technically more like the http1 Server's ReadHeaderTimeout
+ // (in Go 1.8), though. That's a more sane option anyway.
+ if sc.hs.ReadTimeout != 0 {
+ sc.conn.SetReadDeadline(time.Time{})
+ if st.body != nil {
+ st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
+ }
+ }
+
+ return sc.scheduleHandler(id, rw, req, handler)
+}
+
+func (sc *http2serverConn) upgradeRequest(req *Request) {
+ sc.serveG.check()
+ id := uint32(1)
+ sc.maxClientStreamID = id
+ st := sc.newStream(id, 0, http2stateHalfClosedRemote)
+ st.reqTrailer = req.Trailer
+ if st.reqTrailer != nil {
+ st.trailer = make(Header)
+ }
+ rw := sc.newResponseWriter(st, req)
+
+ // Disable any read deadline set by the net/http package
+ // prior to the upgrade.
+ if sc.hs.ReadTimeout != 0 {
+ sc.conn.SetReadDeadline(time.Time{})
+ }
+
+ // This is the first request on the connection,
+ // so start the handler directly rather than going
+ // through scheduleHandler.
+ sc.curHandlers++
+ go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+}
+
+func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
+ sc := st.sc
+ sc.serveG.check()
+ if st.gotTrailerHeader {
+ return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol))
+ }
+ st.gotTrailerHeader = true
+ if !f.StreamEnded() {
+ return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol))
+ }
+
+ if len(f.PseudoFields()) > 0 {
+ return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol))
+ }
+ if st.trailer != nil {
+ for _, hf := range f.RegularFields() {
+ key := sc.canonicalHeader(hf.Name)
+ if !httpguts.ValidTrailerHeader(key) {
+ // TODO: send more details to the peer somehow. But http2 has
+ // no way to send debug data at a stream level. Discuss with
+ // HTTP folk.
+ return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol))
+ }
+ st.trailer[key] = append(st.trailer[key], hf.Value)
+ }
+ }
+ st.endStream()
+ return nil
+}
+
+func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error {
+ if streamID == p.StreamDep {
+ // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat
+ // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR."
+ // Section 5.3.3 says that a stream can depend on one of its dependencies,
+ // so it's only self-dependencies that are forbidden.
+ return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol))
+ }
+ return nil
+}
+
+func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
+ if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil {
+ return err
+ }
+ sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam)
+ return nil
+}
+
+func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream {
+ sc.serveG.check()
+ if id == 0 {
+ panic("internal error: cannot create stream with id 0")
+ }
+
+ ctx, cancelCtx := context.WithCancel(sc.baseCtx)
+ st := &http2stream{
+ sc: sc,
+ id: id,
+ state: state,
+ ctx: ctx,
+ cancelCtx: cancelCtx,
+ }
+ st.cw.Init()
+ st.flow.conn = &sc.flow // link to conn-level counter
+ st.flow.add(sc.initialStreamSendWindowSize)
+ st.inflow.init(sc.srv.initialStreamRecvWindowSize())
+ if sc.hs.WriteTimeout != 0 {
+ st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
+ }
+
+ sc.streams[id] = st
+ sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID})
+ if st.isPushed() {
+ sc.curPushedStreams++
+ } else {
+ sc.curClientStreams++
+ }
+ if sc.curOpenStreams() == 1 {
+ sc.setConnState(StateActive)
+ }
+
+ return st
+}
+
+func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) {
+ sc.serveG.check()
+
+ rp := http2requestParam{
+ method: f.PseudoValue("method"),
+ scheme: f.PseudoValue("scheme"),
+ authority: f.PseudoValue("authority"),
+ path: f.PseudoValue("path"),
+ }
+
+ isConnect := rp.method == "CONNECT"
+ if isConnect {
+ if rp.path != "" || rp.scheme != "" || rp.authority == "" {
+ return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol))
+ }
+ } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
+ // See 8.1.2.6 Malformed Requests and Responses:
+ //
+ // "Malformed requests or responses that are detected
+ // MUST be treated as a stream error (Section 5.4.2)
+ // of type PROTOCOL_ERROR."
+ //
+ // 8.1.2.3 Request Pseudo-Header Fields
+ // "All HTTP/2 requests MUST include exactly one valid
+ // value for the :method, :scheme, and :path
+ // pseudo-header fields"
+ return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol))
+ }
+
+ rp.header = make(Header)
+ for _, hf := range f.RegularFields() {
+ rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ }
+ if rp.authority == "" {
+ rp.authority = rp.header.Get("Host")
+ }
+
+ rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
+ if err != nil {
+ return nil, nil, err
+ }
+ bodyOpen := !f.StreamEnded()
+ if bodyOpen {
+ if vv, ok := rp.header["Content-Length"]; ok {
+ if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
+ req.ContentLength = int64(cl)
+ } else {
+ req.ContentLength = 0
+ }
+ } else {
+ req.ContentLength = -1
+ }
+ req.Body.(*http2requestBody).pipe = &http2pipe{
+ b: &http2dataBuffer{expected: req.ContentLength},
+ }
+ }
+ return rw, req, nil
+}
+
+type http2requestParam struct {
+ method string
+ scheme, authority, path string
+ header Header
+}
+
+func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) {
+ sc.serveG.check()
+
+ var tlsState *tls.ConnectionState // nil if not scheme https
+ if rp.scheme == "https" {
+ tlsState = sc.tlsState
+ }
+
+ needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
+ if needsContinue {
+ rp.header.Del("Expect")
+ }
+ // Merge Cookie headers into one "; "-delimited value.
+ if cookies := rp.header["Cookie"]; len(cookies) > 1 {
+ rp.header.Set("Cookie", strings.Join(cookies, "; "))
+ }
+
+ // Setup Trailers
+ var trailer Header
+ for _, v := range rp.header["Trailer"] {
+ for _, key := range strings.Split(v, ",") {
+ key = CanonicalHeaderKey(textproto.TrimString(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ // Bogus. (copy of http1 rules)
+ // Ignore.
+ default:
+ if trailer == nil {
+ trailer = make(Header)
+ }
+ trailer[key] = nil
+ }
+ }
+ }
+ delete(rp.header, "Trailer")
+
+ var url_ *url.URL
+ var requestURI string
+ if rp.method == "CONNECT" {
+ url_ = &url.URL{Host: rp.authority}
+ requestURI = rp.authority // mimic HTTP/1 server behavior
+ } else {
+ var err error
+ url_, err = url.ParseRequestURI(rp.path)
+ if err != nil {
+ return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol))
+ }
+ requestURI = rp.path
+ }
+
+ body := &http2requestBody{
+ conn: sc,
+ stream: st,
+ needsContinue: needsContinue,
+ }
+ req := &Request{
+ Method: rp.method,
+ URL: url_,
+ RemoteAddr: sc.remoteAddrStr,
+ Header: rp.header,
+ RequestURI: requestURI,
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ ProtoMinor: 0,
+ TLS: tlsState,
+ Host: rp.authority,
+ Body: body,
+ Trailer: trailer,
+ }
+ req = req.WithContext(st.ctx)
+
+ rw := sc.newResponseWriter(st, req)
+ return rw, req, nil
+}
+
+func (sc *http2serverConn) newResponseWriter(st *http2stream, req *Request) *http2responseWriter {
+ rws := http2responseWriterStatePool.Get().(*http2responseWriterState)
+ bwSave := rws.bw
+ *rws = http2responseWriterState{} // zero all the fields
+ rws.conn = sc
+ rws.bw = bwSave
+ rws.bw.Reset(http2chunkWriter{rws})
+ rws.stream = st
+ rws.req = req
+ return &http2responseWriter{rws: rws}
+}
+
+type http2unstartedHandler struct {
+ streamID uint32
+ rw *http2responseWriter
+ req *Request
+ handler func(ResponseWriter, *Request)
+}
+
+// scheduleHandler starts a handler goroutine,
+// or schedules one to start as soon as an existing handler finishes.
+func (sc *http2serverConn) scheduleHandler(streamID uint32, rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) error {
+ sc.serveG.check()
+ maxHandlers := sc.advMaxStreams
+ if sc.curHandlers < maxHandlers {
+ sc.curHandlers++
+ go sc.runHandler(rw, req, handler)
+ return nil
+ }
+ if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
+ return sc.countError("too_many_early_resets", http2ConnectionError(http2ErrCodeEnhanceYourCalm))
+ }
+ sc.unstartedHandlers = append(sc.unstartedHandlers, http2unstartedHandler{
+ streamID: streamID,
+ rw: rw,
+ req: req,
+ handler: handler,
+ })
+ return nil
+}
+
+func (sc *http2serverConn) handlerDone() {
+ sc.serveG.check()
+ sc.curHandlers--
+ i := 0
+ maxHandlers := sc.advMaxStreams
+ for ; i < len(sc.unstartedHandlers); i++ {
+ u := sc.unstartedHandlers[i]
+ if sc.streams[u.streamID] == nil {
+ // This stream was reset before its goroutine had a chance to start.
+ continue
+ }
+ if sc.curHandlers >= maxHandlers {
+ break
+ }
+ sc.curHandlers++
+ go sc.runHandler(u.rw, u.req, u.handler)
+ sc.unstartedHandlers[i] = http2unstartedHandler{} // don't retain references
+ }
+ sc.unstartedHandlers = sc.unstartedHandlers[i:]
+ if len(sc.unstartedHandlers) == 0 {
+ sc.unstartedHandlers = nil
+ }
+}
+
+// Run on its own goroutine.
+func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
+ defer sc.sendServeMsg(http2handlerDoneMsg)
+ didPanic := true
+ defer func() {
+ rw.rws.stream.cancelCtx()
+ if req.MultipartForm != nil {
+ req.MultipartForm.RemoveAll()
+ }
+ if didPanic {
+ e := recover()
+ sc.writeFrameFromHandler(http2FrameWriteRequest{
+ write: http2handlerPanicRST{rw.rws.stream.id},
+ stream: rw.rws.stream,
+ })
+ // Same as net/http:
+ if e != nil && e != ErrAbortHandler {
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
+ }
+ return
+ }
+ rw.handlerDone()
+ }()
+ handler(rw, req)
+ didPanic = false
+}
+
+func http2handleHeaderListTooLong(w ResponseWriter, r *Request) {
+ // 10.5.1 Limits on Header Block Size:
+ // .. "A server that receives a larger header block than it is
+ // willing to handle can send an HTTP 431 (Request Header Fields Too
+ // Large) status code"
+ const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
+ w.WriteHeader(statusRequestHeaderFieldsTooLarge)
+ io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
+}
+
+// called from handler goroutines.
+// h may be nil.
+func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error {
+ sc.serveG.checkNotOn() // NOT on
+ var errc chan error
+ if headerData.h != nil {
+ // If there's a header map (which we don't own), we have to block
+ // waiting for this frame to be written, so that an http.Flush
+ // mid-handler writes out the correct value of the keys before a
+ // handler later potentially mutates them.
+ errc = http2errChanPool.Get().(chan error)
+ }
+ if err := sc.writeFrameFromHandler(http2FrameWriteRequest{
+ write: headerData,
+ stream: st,
+ done: errc,
+ }); err != nil {
+ return err
+ }
+ if errc != nil {
+ select {
+ case err := <-errc:
+ http2errChanPool.Put(errc)
+ return err
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ }
+ }
+ return nil
+}
+
+// called from handler goroutines.
+func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) {
+ sc.writeFrameFromHandler(http2FrameWriteRequest{
+ write: http2write100ContinueHeadersFrame{st.id},
+ stream: st,
+ })
+}
+
+// A bodyReadMsg tells the server loop that the http.Handler read n
+// bytes of the DATA from the client on the given stream.
+type http2bodyReadMsg struct {
+ st *http2stream
+ n int
+}
+
+// called from handler goroutines.
+// Notes that the handler for the given stream ID read n bytes of its body
+// and schedules flow control tokens to be sent.
+func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) {
+ sc.serveG.checkNotOn() // NOT on
+ if n > 0 {
+ select {
+ case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
+ case <-sc.doneServing:
+ }
+ }
+}
+
+func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) {
+ sc.serveG.check()
+ sc.sendWindowUpdate(nil, n) // conn-level
+ if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed {
+ // Don't send this WINDOW_UPDATE if the stream is closed
+ // remotely.
+ sc.sendWindowUpdate(st, n)
+ }
+}
+
+// st may be nil for conn-level
+func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
+ sc.sendWindowUpdate(st, int(n))
+}
+
+// st may be nil for conn-level
+func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) {
+ sc.serveG.check()
+ var streamID uint32
+ var send int32
+ if st == nil {
+ send = sc.inflow.add(n)
+ } else {
+ streamID = st.id
+ send = st.inflow.add(n)
+ }
+ if send == 0 {
+ return
+ }
+ sc.writeFrame(http2FrameWriteRequest{
+ write: http2writeWindowUpdate{streamID: streamID, n: uint32(send)},
+ stream: st,
+ })
+}
+
+// requestBody is the Handler's Request.Body type.
+// Read and Close may be called concurrently.
+type http2requestBody struct {
+ _ http2incomparable
+ stream *http2stream
+ conn *http2serverConn
+ closeOnce sync.Once // for use by Close only
+ sawEOF bool // for use by Read only
+ pipe *http2pipe // non-nil if we have an HTTP entity message body
+ needsContinue bool // need to send a 100-continue
+}
+
+func (b *http2requestBody) Close() error {
+ b.closeOnce.Do(func() {
+ if b.pipe != nil {
+ b.pipe.BreakWithError(http2errClosedBody)
+ }
+ })
+ return nil
+}
+
+func (b *http2requestBody) Read(p []byte) (n int, err error) {
+ if b.needsContinue {
+ b.needsContinue = false
+ b.conn.write100ContinueHeaders(b.stream)
+ }
+ if b.pipe == nil || b.sawEOF {
+ return 0, io.EOF
+ }
+ n, err = b.pipe.Read(p)
+ if err == io.EOF {
+ b.sawEOF = true
+ }
+ if b.conn == nil && http2inTests {
+ return
+ }
+ b.conn.noteBodyReadFromHandler(b.stream, n, err)
+ return
+}
+
+// responseWriter is the http.ResponseWriter implementation. It's
+// intentionally small (1 pointer wide) to minimize garbage. The
+// responseWriterState pointer inside is zeroed at the end of a
+// request (in handlerDone) and calls on the responseWriter thereafter
+// simply crash (caller's mistake), but the much larger responseWriterState
+// and buffers are reused between multiple requests.
+type http2responseWriter struct {
+ rws *http2responseWriterState
+}
+
+// Optional http.ResponseWriter interfaces implemented.
+var (
+ _ CloseNotifier = (*http2responseWriter)(nil)
+ _ Flusher = (*http2responseWriter)(nil)
+ _ http2stringWriter = (*http2responseWriter)(nil)
+)
+
+type http2responseWriterState struct {
+ // immutable within a request:
+ stream *http2stream
+ req *Request
+ conn *http2serverConn
+
+ // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
+ bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}
+
+ // mutated by http.Handler goroutine:
+ handlerHeader Header // nil until called
+ snapHeader Header // snapshot of handlerHeader at WriteHeader time
+ trailers []string // set in writeChunk
+ status int // status code passed to WriteHeader
+ wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
+ sentHeader bool // have we sent the header frame?
+ handlerDone bool // handler has finished
+ dirty bool // a Write failed; don't reuse this responseWriterState
+
+ sentContentLen int64 // non-zero if handler set a Content-Length header
+ wroteBytes int64
+
+ closeNotifierMu sync.Mutex // guards closeNotifierCh
+ closeNotifierCh chan bool // nil until first used
+}
+
+type http2chunkWriter struct{ rws *http2responseWriterState }
+
+func (cw http2chunkWriter) Write(p []byte) (n int, err error) {
+ n, err = cw.rws.writeChunk(p)
+ if err == http2errStreamClosed {
+ // If writing failed because the stream has been closed,
+ // return the reason it was closed.
+ err = cw.rws.stream.closeErr
+ }
+ return n, err
+}
+
+func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
+
+func (rws *http2responseWriterState) hasNonemptyTrailers() bool {
+ for _, trailer := range rws.trailers {
+ if _, ok := rws.handlerHeader[trailer]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+// declareTrailer is called for each Trailer header when the
+// response header is written. It notes that a header will need to be
+// written in the trailers at the end of the response.
+func (rws *http2responseWriterState) declareTrailer(k string) {
+ k = CanonicalHeaderKey(k)
+ if !httpguts.ValidTrailerHeader(k) {
+ // Forbidden by RFC 7230, section 4.1.2.
+ rws.conn.logf("ignoring invalid trailer %q", k)
+ return
+ }
+ if !http2strSliceContains(rws.trailers, k) {
+ rws.trailers = append(rws.trailers, k)
+ }
+}
+
+// writeChunk writes chunks from the bufio.Writer. But because
+// bufio.Writer may bypass its chunking, sometimes p may be
+// arbitrarily large.
+//
+// writeChunk is also responsible (on the first chunk) for sending the
+// HEADER response.
+func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
+ if !rws.wroteHeader {
+ rws.writeHeader(200)
+ }
+
+ if rws.handlerDone {
+ rws.promoteUndeclaredTrailers()
+ }
+
+ isHeadResp := rws.req.Method == "HEAD"
+ if !rws.sentHeader {
+ rws.sentHeader = true
+ var ctype, clen string
+ if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
+ rws.snapHeader.Del("Content-Length")
+ if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
+ rws.sentContentLen = int64(cl)
+ } else {
+ clen = ""
+ }
+ }
+ _, hasContentLength := rws.snapHeader["Content-Length"]
+ if !hasContentLength && clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
+ clen = strconv.Itoa(len(p))
+ }
+ _, hasContentType := rws.snapHeader["Content-Type"]
+ // If the Content-Encoding is non-blank, we shouldn't
+ // sniff the body. See Issue golang.org/issue/31753.
+ ce := rws.snapHeader.Get("Content-Encoding")
+ hasCE := len(ce) > 0
+ if !hasCE && !hasContentType && http2bodyAllowedForStatus(rws.status) && len(p) > 0 {
+ ctype = DetectContentType(p)
+ }
+ var date string
+ if _, ok := rws.snapHeader["Date"]; !ok {
+ // TODO(bradfitz): be faster here, like net/http? measure.
+ date = time.Now().UTC().Format(TimeFormat)
+ }
+
+ for _, v := range rws.snapHeader["Trailer"] {
+ http2foreachHeaderElement(v, rws.declareTrailer)
+ }
+
+ // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2),
+ // but respect "Connection" == "close" to mean sending a GOAWAY and tearing
+ // down the TCP connection when idle, like we do for HTTP/1.
+ // TODO: remove more Connection-specific header fields here, in addition
+ // to "Connection".
+ if _, ok := rws.snapHeader["Connection"]; ok {
+ v := rws.snapHeader.Get("Connection")
+ delete(rws.snapHeader, "Connection")
+ if v == "close" {
+ rws.conn.startGracefulShutdown()
+ }
+ }
+
+ endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
+ err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ streamID: rws.stream.id,
+ httpResCode: rws.status,
+ h: rws.snapHeader,
+ endStream: endStream,
+ contentType: ctype,
+ contentLength: clen,
+ date: date,
+ })
+ if err != nil {
+ rws.dirty = true
+ return 0, err
+ }
+ if endStream {
+ return 0, nil
+ }
+ }
+ if isHeadResp {
+ return len(p), nil
+ }
+ if len(p) == 0 && !rws.handlerDone {
+ return 0, nil
+ }
+
+ // only send trailers if they have actually been defined by the
+ // server handler.
+ hasNonemptyTrailers := rws.hasNonemptyTrailers()
+ endStream := rws.handlerDone && !hasNonemptyTrailers
+ if len(p) > 0 || endStream {
+ // only send a 0 byte DATA frame if we're ending the stream.
+ if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
+ rws.dirty = true
+ return 0, err
+ }
+ }
+
+ if rws.handlerDone && hasNonemptyTrailers {
+ err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ streamID: rws.stream.id,
+ h: rws.handlerHeader,
+ trailers: rws.trailers,
+ endStream: true,
+ })
+ if err != nil {
+ rws.dirty = true
+ }
+ return len(p), err
+ }
+ return len(p), nil
+}
+
+// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
+// that, if present, signals that the map entry is actually for
+// the response trailers, and not the response headers. The prefix
+// is stripped after the ServeHTTP call finishes and the values are
+// sent in the trailers.
+//
+// This mechanism is intended only for trailers that are not known
+// prior to the headers being written. If the set of trailers is fixed
+// or known before the header is written, the normal Go trailers mechanism
+// is preferred:
+//
+// https://golang.org/pkg/net/http/#ResponseWriter
+// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
+const http2TrailerPrefix = "Trailer:"
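+
+// For illustration, a handler might use this prefix to set a trailer whose
+// value isn't known until the body has been written; the handler name and
+// the X-Checksum field below are hypothetical (crypto/sha256, encoding/hex,
+// io, and strings imports assumed):
+//
+//	func checksumHandler(w ResponseWriter, r *Request) {
+//		sum := sha256.New()
+//		io.Copy(io.MultiWriter(w, sum), strings.NewReader("hello, world\n"))
+//		w.Header().Set("Trailer:X-Checksum", hex.EncodeToString(sum.Sum(nil)))
+//	}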
+
+// promoteUndeclaredTrailers permits http.Handlers to set trailers
+// after the header has already been flushed. Because the Go
+// ResponseWriter interface has no way to set Trailers (only the
+// Header), and because we didn't want to expand the ResponseWriter
+// interface, and because nobody used trailers, and because RFC 7230
+// says you SHOULD (but not must) predeclare any trailers in the
+// header, the official ResponseWriter rules said trailers in Go must
+// be predeclared, and then we reuse the same ResponseWriter.Header()
+// map to mean both Headers and Trailers. When it's time to write the
+// Trailers, we pick out the fields of Headers that were declared as
+// trailers. That worked for a while, until we found the first major
+// user of Trailers in the wild: gRPC (using them only over http2),
+// and gRPC libraries permit setting trailers mid-stream without
+// predeclaring them. So: change of plans. We still permit the old
+// way, but we also permit this hack: if a Header() key begins with
+// "Trailer:", the suffix of that key is a Trailer. Because ':' is an
+// invalid token byte anyway, there is no ambiguity. (And it's already
+// filtered out.) It's mildly hacky, but not terrible.
+//
+// This method runs after the Handler is done and promotes any Header
+// fields to be trailers.
+func (rws *http2responseWriterState) promoteUndeclaredTrailers() {
+ for k, vv := range rws.handlerHeader {
+ if !strings.HasPrefix(k, http2TrailerPrefix) {
+ continue
+ }
+ trailerKey := strings.TrimPrefix(k, http2TrailerPrefix)
+ rws.declareTrailer(trailerKey)
+ rws.handlerHeader[CanonicalHeaderKey(trailerKey)] = vv
+ }
+
+ if len(rws.trailers) > 1 {
+ sorter := http2sorterPool.Get().(*http2sorter)
+ sorter.SortStrings(rws.trailers)
+ http2sorterPool.Put(sorter)
+ }
+}
+
+func (w *http2responseWriter) SetReadDeadline(deadline time.Time) error {
+ st := w.rws.stream
+ if !deadline.IsZero() && deadline.Before(time.Now()) {
+ // If we're setting a deadline in the past, reset the stream immediately
+ // so reads after SetReadDeadline returns will fail.
+ st.onReadTimeout()
+ return nil
+ }
+ w.rws.conn.sendServeMsg(func(sc *http2serverConn) {
+ if st.readDeadline != nil {
+ if !st.readDeadline.Stop() {
+ // Deadline already exceeded, or stream has been closed.
+ return
+ }
+ }
+ if deadline.IsZero() {
+ st.readDeadline = nil
+ } else if st.readDeadline == nil {
+ st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
+ } else {
+ st.readDeadline.Reset(deadline.Sub(time.Now()))
+ }
+ })
+ return nil
+}
+
+func (w *http2responseWriter) SetWriteDeadline(deadline time.Time) error {
+ st := w.rws.stream
+ if !deadline.IsZero() && deadline.Before(time.Now()) {
+ // If we're setting a deadline in the past, reset the stream immediately
+ // so writes after SetWriteDeadline returns will fail.
+ st.onWriteTimeout()
+ return nil
+ }
+ w.rws.conn.sendServeMsg(func(sc *http2serverConn) {
+ if st.writeDeadline != nil {
+ if !st.writeDeadline.Stop() {
+ // Deadline already exceeded, or stream has been closed.
+ return
+ }
+ }
+ if deadline.IsZero() {
+ st.writeDeadline = nil
+ } else if st.writeDeadline == nil {
+ st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
+ } else {
+ st.writeDeadline.Reset(deadline.Sub(time.Now()))
+ }
+ })
+ return nil
+}
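+
+// Handlers don't see these methods on ResponseWriter directly; in recent
+// Go releases they are typically reached through NewResponseController.
+// A minimal sketch (the handler and durations are hypothetical, and the
+// returned errors are ignored for brevity):
+//
+//	func slowHandler(w ResponseWriter, r *Request) {
+//		rc := NewResponseController(w)
+//		rc.SetReadDeadline(time.Now().Add(10 * time.Second))  // bounds reads of r.Body
+//		rc.SetWriteDeadline(time.Now().Add(30 * time.Second)) // bounds writes of the response
+//		io.Copy(w, r.Body)
+//	}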
+
+func (w *http2responseWriter) Flush() {
+ w.FlushError()
+}
+
+func (w *http2responseWriter) FlushError() error {
+ rws := w.rws
+ if rws == nil {
+ panic("Header called after Handler finished")
+ }
+ var err error
+ if rws.bw.Buffered() > 0 {
+ err = rws.bw.Flush()
+ } else {
+ // The bufio.Writer won't call chunkWriter.Write
+ // (writeChunk with zero bytes), so we have to do it
+ // ourselves to force the HTTP response header and/or
+ // final DATA frame (with END_STREAM) to be sent.
+ _, err = http2chunkWriter{rws}.Write(nil)
+ if err == nil {
+ select {
+ case <-rws.stream.cw:
+ err = rws.stream.closeErr
+ default:
+ }
+ }
+ }
+ return err
+}
+
+func (w *http2responseWriter) CloseNotify() <-chan bool {
+ rws := w.rws
+ if rws == nil {
+ panic("CloseNotify called after Handler finished")
+ }
+ rws.closeNotifierMu.Lock()
+ ch := rws.closeNotifierCh
+ if ch == nil {
+ ch = make(chan bool, 1)
+ rws.closeNotifierCh = ch
+ cw := rws.stream.cw
+ go func() {
+ cw.Wait() // wait for close
+ ch <- true
+ }()
+ }
+ rws.closeNotifierMu.Unlock()
+ return ch
+}
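+
+// Note that CloseNotifier predates contexts and is deprecated in the
+// public net/http API; the stream's request context is canceled when the
+// stream or connection goes away, so a handler can select on it instead.
+// A minimal sketch (workCh is hypothetical):
+//
+//	select {
+//	case <-r.Context().Done():
+//		return // peer went away or the request was canceled
+//	case result := <-workCh:
+//		fmt.Fprintln(w, result)
+//	}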
+
+func (w *http2responseWriter) Header() Header {
+ rws := w.rws
+ if rws == nil {
+ panic("Header called after Handler finished")
+ }
+ if rws.handlerHeader == nil {
+ rws.handlerHeader = make(Header)
+ }
+ return rws.handlerHeader
+}
+
+// checkWriteHeaderCode is a copy of net/http's checkWriteHeaderCode.
+func http2checkWriteHeaderCode(code int) {
+ // Issue 22880: require valid WriteHeader status codes.
+ // For now we only enforce that it's three digits.
+ // In the future we might block things over 599 (600 and above aren't defined
+ // at http://httpwg.org/specs/rfc7231.html#status.codes).
+ // But for now any three digits.
+ //
+ // We used to send "HTTP/1.1 000 0" on the wire in responses but there's
+ // no equivalent bogus thing we can realistically send in HTTP/2,
+ // so we'll consistently panic instead and help people find their bugs
+ // early. (We can't return an error from WriteHeader even if we wanted to.)
+ if code < 100 || code > 999 {
+ panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+ }
+}
+
+func (w *http2responseWriter) WriteHeader(code int) {
+ rws := w.rws
+ if rws == nil {
+ panic("WriteHeader called after Handler finished")
+ }
+ rws.writeHeader(code)
+}
+
+func (rws *http2responseWriterState) writeHeader(code int) {
+ if rws.wroteHeader {
+ return
+ }
+
+ http2checkWriteHeaderCode(code)
+
+ // Handle informational headers
+ if code >= 100 && code <= 199 {
+ // Per RFC 8297 we must not clear the current header map
+ h := rws.handlerHeader
+
+ _, cl := h["Content-Length"]
+ _, te := h["Transfer-Encoding"]
+ if cl || te {
+ h = h.Clone()
+ h.Del("Content-Length")
+ h.Del("Transfer-Encoding")
+ }
+
+ if rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ streamID: rws.stream.id,
+ httpResCode: code,
+ h: h,
+ endStream: rws.handlerDone && !rws.hasTrailers(),
+ }) != nil {
+ rws.dirty = true
+ }
+
+ return
+ }
+
+ rws.wroteHeader = true
+ rws.status = code
+ if len(rws.handlerHeader) > 0 {
+ rws.snapHeader = http2cloneHeader(rws.handlerHeader)
+ }
+}
+
+func http2cloneHeader(h Header) Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+// The Life Of A Write is like this:
+//
+// * Handler calls w.Write or w.WriteString ->
+// * -> rws.bw (*bufio.Writer) ->
+// * (Handler might call Flush)
+// * -> chunkWriter{rws}
+// * -> responseWriterState.writeChunk(p []byte) (most of the magic; see comment there)
+func (w *http2responseWriter) Write(p []byte) (n int, err error) {
+ return w.write(len(p), p, "")
+}
+
+func (w *http2responseWriter) WriteString(s string) (n int, err error) {
+ return w.write(len(s), nil, s)
+}
+
+// At most one of dataB and dataS is non-empty.
+func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
+ rws := w.rws
+ if rws == nil {
+ panic("Write called after Handler finished")
+ }
+ if !rws.wroteHeader {
+ w.WriteHeader(200)
+ }
+ if !http2bodyAllowedForStatus(rws.status) {
+ return 0, ErrBodyNotAllowed
+ }
+ rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set
+ if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
+ // TODO: send a RST_STREAM
+ return 0, errors.New("http2: handler wrote more than declared Content-Length")
+ }
+
+ if dataB != nil {
+ return rws.bw.Write(dataB)
+ } else {
+ return rws.bw.WriteString(dataS)
+ }
+}
+
+func (w *http2responseWriter) handlerDone() {
+ rws := w.rws
+ dirty := rws.dirty
+ rws.handlerDone = true
+ w.Flush()
+ w.rws = nil
+ if !dirty {
+ // Only recycle the pool if all prior Write calls to
+ // the serverConn goroutine completed successfully. If
+ // they returned earlier due to resets from the peer
+ // there might still be write goroutines outstanding
+ // from the serverConn referencing the rws memory. See
+ // issue 20704.
+ http2responseWriterStatePool.Put(rws)
+ }
+}
+
+// Push errors.
+var (
+ http2ErrRecursivePush = errors.New("http2: recursive push not allowed")
+ http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
+)
+
+var _ Pusher = (*http2responseWriter)(nil)
+
+func (w *http2responseWriter) Push(target string, opts *PushOptions) error {
+ st := w.rws.stream
+ sc := st.sc
+ sc.serveG.checkNotOn()
+
+ // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream."
+ // http://tools.ietf.org/html/rfc7540#section-6.6
+ if st.isPushed() {
+ return http2ErrRecursivePush
+ }
+
+ if opts == nil {
+ opts = new(PushOptions)
+ }
+
+ // Default options.
+ if opts.Method == "" {
+ opts.Method = "GET"
+ }
+ if opts.Header == nil {
+ opts.Header = Header{}
+ }
+ wantScheme := "http"
+ if w.rws.req.TLS != nil {
+ wantScheme = "https"
+ }
+
+ // Validate the request.
+ u, err := url.Parse(target)
+ if err != nil {
+ return err
+ }
+ if u.Scheme == "" {
+ if !strings.HasPrefix(target, "/") {
+ return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
+ }
+ u.Scheme = wantScheme
+ u.Host = w.rws.req.Host
+ } else {
+ if u.Scheme != wantScheme {
+ return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
+ }
+ if u.Host == "" {
+ return errors.New("URL must have a host")
+ }
+ }
+ for k := range opts.Header {
+ if strings.HasPrefix(k, ":") {
+ return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
+ }
+ // These headers are meaningful only if the request has a body,
+ // but PUSH_PROMISE requests cannot have a body.
+ // http://tools.ietf.org/html/rfc7540#section-8.2
+ // Also disallow Host, since the promised URL must be absolute.
+ if http2asciiEqualFold(k, "content-length") ||
+ http2asciiEqualFold(k, "content-encoding") ||
+ http2asciiEqualFold(k, "trailer") ||
+ http2asciiEqualFold(k, "te") ||
+ http2asciiEqualFold(k, "expect") ||
+ http2asciiEqualFold(k, "host") {
+ return fmt.Errorf("promised request headers cannot include %q", k)
+ }
+ }
+ if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil {
+ return err
+ }
+
+ // The RFC effectively limits promised requests to GET and HEAD:
+ // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]"
+ // http://tools.ietf.org/html/rfc7540#section-8.2
+ if opts.Method != "GET" && opts.Method != "HEAD" {
+ return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
+ }
+
+ msg := &http2startPushRequest{
+ parent: st,
+ method: opts.Method,
+ url: u,
+ header: http2cloneHeader(opts.Header),
+ done: http2errChanPool.Get().(chan error),
+ }
+
+ select {
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ case sc.serveMsgCh <- msg:
+ }
+
+ select {
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ case err := <-msg.done:
+ http2errChanPool.Put(msg.done)
+ return err
+ }
+}
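+
+// A minimal handler-side sketch of driving the Push method above through the
+// exported Pusher interface (http.Pusher, http.PushOptions and
+// http.ErrNotSupported are the standard exported names; the asset path and
+// response body are hypothetical):
+//
+//	func handler(w http.ResponseWriter, r *http.Request) {
+//		if p, ok := w.(http.Pusher); ok {
+//			// Push is best effort: it returns http.ErrNotSupported
+//			// over HTTP/1.x or when the peer has disabled push.
+//			_ = p.Push("/static/app.css", nil)
+//		}
+//		io.WriteString(w, "<html>...</html>")
+//	}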
+
+type http2startPushRequest struct {
+ parent *http2stream
+ method string
+ url *url.URL
+ header Header
+ done chan error
+}
+
+func (sc *http2serverConn) startPush(msg *http2startPushRequest) {
+ sc.serveG.check()
+
+ // http://tools.ietf.org/html/rfc7540#section-6.6.
+ // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that
+ // is in either the "open" or "half-closed (remote)" state.
+ if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote {
+ // responseWriter.Push checks that the stream is peer-initiated.
+ msg.done <- http2errStreamClosed
+ return
+ }
+
+ // http://tools.ietf.org/html/rfc7540#section-6.6.
+ if !sc.pushEnabled {
+ msg.done <- ErrNotSupported
+ return
+ }
+
+ // PUSH_PROMISE frames must be sent in increasing order by stream ID, so
+ // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE
+ // is written. Once the ID is allocated, we start the request handler.
+ allocatePromisedID := func() (uint32, error) {
+ sc.serveG.check()
+
+ // Check this again, just in case. Technically, we might have received
+ // an updated SETTINGS by the time we got around to writing this frame.
+ if !sc.pushEnabled {
+ return 0, ErrNotSupported
+ }
+ // http://tools.ietf.org/html/rfc7540#section-6.5.2.
+ if sc.curPushedStreams+1 > sc.clientMaxStreams {
+ return 0, http2ErrPushLimitReached
+ }
+
+ // http://tools.ietf.org/html/rfc7540#section-5.1.1.
+ // Streams initiated by the server MUST use even-numbered identifiers.
+ // A server that is unable to establish a new stream identifier can send a GOAWAY
+ // frame so that the client is forced to open a new connection for new streams.
+ if sc.maxPushPromiseID+2 >= 1<<31 {
+ sc.startGracefulShutdownInternal()
+ return 0, http2ErrPushLimitReached
+ }
+ sc.maxPushPromiseID += 2
+ promisedID := sc.maxPushPromiseID
+
+ // http://tools.ietf.org/html/rfc7540#section-8.2.
+ // Strictly speaking, the new stream should start in "reserved (local)", then
+ // transition to "half closed (remote)" after sending the initial HEADERS, but
+ // we start in "half closed (remote)" for simplicity.
+ // See further comments at the definition of stateHalfClosedRemote.
+ promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote)
+ rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{
+ method: msg.method,
+ scheme: msg.url.Scheme,
+ authority: msg.url.Host,
+ path: msg.url.RequestURI(),
+ header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
+ })
+ if err != nil {
+ // Should not happen, since we've already validated msg.url.
+ panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
+ }
+
+ sc.curHandlers++
+ go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+ return promisedID, nil
+ }
+
+ sc.writeFrame(http2FrameWriteRequest{
+ write: &http2writePushPromise{
+ streamID: msg.parent.id,
+ method: msg.method,
+ url: msg.url,
+ h: msg.header,
+ allocatePromisedID: allocatePromisedID,
+ },
+ stream: msg.parent,
+ done: msg.done,
+ })
+}
+
+// foreachHeaderElement splits v according to the "#rule" construction
+// in RFC 7230 section 7 and calls fn for each non-empty element.
+func http2foreachHeaderElement(v string, fn func(string)) {
+ v = textproto.TrimString(v)
+ if v == "" {
+ return
+ }
+ if !strings.Contains(v, ",") {
+ fn(v)
+ return
+ }
+ for _, f := range strings.Split(v, ",") {
+ if f = textproto.TrimString(f); f != "" {
+ fn(f)
+ }
+ }
+}
+
+// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2
+var http2connHeaders = []string{
+ "Connection",
+ "Keep-Alive",
+ "Proxy-Connection",
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request,
+// per RFC 7540 Section 8.1.2.2.
+// The returned error is reported to users.
+func http2checkValidHTTP2RequestHeaders(h Header) error {
+ for _, k := range http2connHeaders {
+ if _, ok := h[k]; ok {
+ return fmt.Errorf("request header %q is not valid in HTTP/2", k)
+ }
+ }
+ te := h["Te"]
+ if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
+ return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
+ }
+ return nil
+}
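+
+// For illustration, the check above rejects a request carrying any of the
+// connection-specific headers listed in connHeaders (for example
+// "Connection: keep-alive" or "Transfer-Encoding: chunked"), and allows the
+// "TE" header only when its sole value is "trailers" (or empty).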
+
+func http2new400Handler(err error) HandlerFunc {
+ return func(w ResponseWriter, r *Request) {
+ Error(w, err.Error(), StatusBadRequest)
+ }
+}
+
+// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
+// disabled. See comments on h1ServerShutdownChan above for why
+// the code is written this way.
+func http2h1ServerKeepAlivesDisabled(hs *Server) bool {
+ var x interface{} = hs
+ type I interface {
+ doKeepAlives() bool
+ }
+ if hs, ok := x.(I); ok {
+ return !hs.doKeepAlives()
+ }
+ return false
+}
+
+func (sc *http2serverConn) countError(name string, err error) error {
+ if sc == nil || sc.srv == nil {
+ return err
+ }
+ f := sc.srv.CountError
+ if f == nil {
+ return err
+ }
+ var typ string
+ var code http2ErrCode
+ switch e := err.(type) {
+ case http2ConnectionError:
+ typ = "conn"
+ code = http2ErrCode(e)
+ case http2StreamError:
+ typ = "stream"
+ code = http2ErrCode(e.Code)
+ default:
+ return err
+ }
+ codeStr := http2errCodeName[code]
+ if codeStr == "" {
+ codeStr = strconv.Itoa(int(code))
+ }
+ f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
+ return err
+}
+
+const (
+ // transportDefaultConnFlow is how many connection-level flow control
+ // tokens we give the server at start-up, past the default 64k.
+ http2transportDefaultConnFlow = 1 << 30
+
+ // transportDefaultStreamFlow is how many stream-level flow
+ // control tokens we announce to the peer, and how many bytes
+ // we buffer per stream.
+ http2transportDefaultStreamFlow = 4 << 20
+
+ http2defaultUserAgent = "Go-http-client/2.0"
+
+ // initialMaxConcurrentStreams is a connection's maxConcurrentStreams until
+ // it has received the server's initial SETTINGS frame, which corresponds
+ // with the spec's minimum recommended value.
+ http2initialMaxConcurrentStreams = 100
+
+ // defaultMaxConcurrentStreams is a connection's default maxConcurrentStreams
+ // if the server doesn't include one in its initial SETTINGS frame.
+ http2defaultMaxConcurrentStreams = 1000
+)
+
+// Transport is an HTTP/2 Transport.
+//
+// A Transport internally caches connections to servers. It is safe
+// for concurrent use by multiple goroutines.
+type http2Transport struct {
+ // DialTLSContext specifies an optional dial function with context for
+ // creating TLS connections for requests.
+ //
+ // If DialTLSContext and DialTLS are both nil, tls.Dial is used.
+ //
+ // If the returned net.Conn has a ConnectionState method like tls.Conn,
+ // it will be used to set http.Response.TLS.
+ DialTLSContext func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error)
+
+ // DialTLS specifies an optional dial function for creating
+ // TLS connections for requests.
+ //
+ // If DialTLSContext and DialTLS are both nil, tls.Dial is used.
+ //
+ // Deprecated: Use DialTLSContext instead, which allows the transport
+ // to cancel dials as soon as they are no longer needed.
+ // If both are set, DialTLSContext takes priority.
+ DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client. If nil, the default configuration is used.
+ TLSClientConfig *tls.Config
+
+ // ConnPool optionally specifies an alternate connection pool to use.
+ // If nil, the default is used.
+ ConnPool http2ClientConnPool
+
+ // DisableCompression, if true, prevents the Transport from
+ // requesting compression with an "Accept-Encoding: gzip"
+ // request header when the Request contains no existing
+ // Accept-Encoding value. If the Transport requests gzip on
+ // its own and gets a gzipped response, it's transparently
+ // decoded in the Response.Body. However, if the user
+ // explicitly requested gzip it is not automatically
+ // uncompressed.
+ DisableCompression bool
+
+ // AllowHTTP, if true, permits HTTP/2 requests using the insecure,
+ // plain-text "http" scheme. Note that this does not enable h2c support.
+ AllowHTTP bool
+
+ // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
+ // send in the initial settings frame. It is how many bytes
+ // of response headers are allowed. Unlike the http2 spec, zero here
+ // means to use a default limit (currently 10MB). If you actually
+ // want to advertise an unlimited value to the peer, Transport
+ // interprets the highest possible value here (0xffffffff or 1<<32-1)
+ // to mean no limit.
+ MaxHeaderListSize uint32
+
+ // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the
+ // initial settings frame. It is the size in bytes of the largest frame
+ // payload that the sender is willing to receive. If 0, no setting is
+ // sent, and the value is provided by the peer, which should be 16384
+ // according to the spec:
+ // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2.
+ // Values are bounded in the range 16k to 16M.
+ MaxReadFrameSize uint32
+
+ // MaxDecoderHeaderTableSize optionally specifies the http2
+ // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
+ // informs the remote endpoint of the maximum size of the header compression
+ // table used to decode header blocks, in octets. If zero, the default value
+ // of 4096 is used.
+ MaxDecoderHeaderTableSize uint32
+
+ // MaxEncoderHeaderTableSize optionally specifies an upper limit for the
+ // header compression table used for encoding request headers. Received
+ // SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
+ // the default value of 4096 is used.
+ MaxEncoderHeaderTableSize uint32
+
+ // StrictMaxConcurrentStreams controls whether the server's
+ // SETTINGS_MAX_CONCURRENT_STREAMS should be respected
+ // globally. If false, new TCP connections are created to the
+ // server as needed to keep each under the per-connection
+ // SETTINGS_MAX_CONCURRENT_STREAMS limit. If true, the
+ // server's SETTINGS_MAX_CONCURRENT_STREAMS is interpreted as
+ // a global limit and callers of RoundTrip block when needed,
+ // waiting for their turn.
+ StrictMaxConcurrentStreams bool
+
+ // ReadIdleTimeout is the timeout after which a health check using ping
+ // frame will be carried out if no frame is received on the connection.
+ // Note that a ping response is considered a received frame, so if
+ // there is no other traffic on the connection, the health check will
+ // be performed every ReadIdleTimeout interval.
+ // If zero, no health check is performed.
+ ReadIdleTimeout time.Duration
+
+ // PingTimeout is the timeout after which the connection will be closed
+ // if a response to Ping is not received.
+ // Defaults to 15s.
+ PingTimeout time.Duration
+
+ // WriteByteTimeout is the timeout after which the connection will be
+ // closed if no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ WriteByteTimeout time.Duration
+
+ // CountError, if non-nil, is called on HTTP/2 transport errors.
+ // It's intended to increment a metric for monitoring, such
+ // as an expvar or Prometheus metric.
+ // The errType consists of only ASCII word characters.
+ CountError func(errType string)
+
+ // t1, if non-nil, is the standard library Transport using
+ // this transport. Its settings are used (but not its
+ // RoundTrip method, etc).
+ t1 *Transport
+
+ connPoolOnce sync.Once
+ connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
+}
+
+func (t *http2Transport) maxHeaderListSize() uint32 {
+ if t.MaxHeaderListSize == 0 {
+ return 10 << 20
+ }
+ if t.MaxHeaderListSize == 0xffffffff {
+ return 0
+ }
+ return t.MaxHeaderListSize
+}
+
+func (t *http2Transport) maxFrameReadSize() uint32 {
+ if t.MaxReadFrameSize == 0 {
+ return 0 // use the default provided by the peer
+ }
+ if t.MaxReadFrameSize < http2minMaxFrameSize {
+ return http2minMaxFrameSize
+ }
+ if t.MaxReadFrameSize > http2maxFrameSize {
+ return http2maxFrameSize
+ }
+ return t.MaxReadFrameSize
+}
+
+func (t *http2Transport) disableCompression() bool {
+ return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
+}
+
+func (t *http2Transport) pingTimeout() time.Duration {
+ if t.PingTimeout == 0 {
+ return 15 * time.Second
+ }
+ return t.PingTimeout
+}
+
+// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
+// It returns an error if t1 has already been HTTP/2-enabled.
+//
+// Use ConfigureTransports instead to configure the HTTP/2 Transport.
+func http2ConfigureTransport(t1 *Transport) error {
+ _, err := http2ConfigureTransports(t1)
+ return err
+}
+
+// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2.
+// It returns a new HTTP/2 Transport for further configuration.
+// It returns an error if t1 has already been HTTP/2-enabled.
+func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) {
+ return http2configureTransports(t1)
+}
+
+func http2configureTransports(t1 *Transport) (*http2Transport, error) {
+ connPool := new(http2clientConnPool)
+ t2 := &http2Transport{
+ ConnPool: http2noDialClientConnPool{connPool},
+ t1: t1,
+ }
+ connPool.t = t2
+ if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
+ return nil, err
+ }
+ if t1.TLSClientConfig == nil {
+ t1.TLSClientConfig = new(tls.Config)
+ }
+ if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
+ t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
+ }
+ if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
+ t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
+ }
+ upgradeFn := func(authority string, c *tls.Conn) RoundTripper {
+ addr := http2authorityAddr("https", authority)
+ if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
+ go c.Close()
+ return http2erringRoundTripper{err}
+ } else if !used {
+ // Turns out we don't need this c.
+ // For example, two goroutines may have made requests to the same host
+ // at the same time, both kicking off TCP dials (since the protocol
+ // was unknown at that point).
+ go c.Close()
+ }
+ return t2
+ }
+ if m := t1.TLSNextProto; len(m) == 0 {
+ t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
+ "h2": upgradeFn,
+ }
+ } else {
+ m["h2"] = upgradeFn
+ }
+ return t2, nil
+}
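+
+// A minimal caller-side sketch of the configuration above via the exported
+// golang.org/x/net/http2 package, from which this bundle is generated
+// (ConfigureTransports is the exported name of the function above; the
+// ReadIdleTimeout value is only an illustration):
+//
+//	t1 := &http.Transport{}
+//	t2, err := http2.ConfigureTransports(t1)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	t2.ReadIdleTimeout = 30 * time.Second // enable health-check pings
+//	client := &http.Client{Transport: t1}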
+
+func (t *http2Transport) connPool() http2ClientConnPool {
+ t.connPoolOnce.Do(t.initConnPool)
+ return t.connPoolOrDef
+}
+
+func (t *http2Transport) initConnPool() {
+ if t.ConnPool != nil {
+ t.connPoolOrDef = t.ConnPool
+ } else {
+ t.connPoolOrDef = &http2clientConnPool{t: t}
+ }
+}
+
+// ClientConn is the state of a single HTTP/2 client connection to an
+// HTTP/2 server.
+type http2ClientConn struct {
+ t *http2Transport
+ tconn net.Conn // usually *tls.Conn, except specialized impls
+ tconnClosed bool
+ tlsState *tls.ConnectionState // nil only for specialized impls
+ reused uint32 // whether conn is being reused; atomic
+ singleUse bool // whether being used for a single http.Request
+ getConnCalled bool // used by clientConnPool
+
+ // readLoop goroutine fields:
+ readerDone chan struct{} // closed on error
+ readerErr error // set before readerDone is closed
+
+ idleTimeout time.Duration // or 0 for never
+ idleTimer *time.Timer
+
+ mu sync.Mutex // guards following
+ cond *sync.Cond // hold mu; broadcast on flow/closed changes
+ flow http2outflow // our conn-level flow control quota (cs.outflow is per stream)
+ inflow http2inflow // peer's conn-level flow control
+ doNotReuse bool // whether conn is marked to not be reused for any future requests
+ closing bool
+ closed bool
+ seenSettings bool // true if we've seen a settings frame, false otherwise
+ wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
+ goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
+ goAwayDebug string // goAway frame's debug data, retained as a string
+ streams map[uint32]*http2clientStream // client-initiated
+ streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
+ nextStreamID uint32
+ pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
+ pings map[[8]byte]chan struct{} // in flight ping data to notification channel
+ br *bufio.Reader
+ lastActive time.Time
+ lastIdle time.Time // time last idle
+ // Settings from peer: (also guarded by wmu)
+ maxFrameSize uint32
+ maxConcurrentStreams uint32
+ peerMaxHeaderListSize uint64
+ peerMaxHeaderTableSize uint32
+ initialWindowSize uint32
+
+ // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
+ // Write to reqHeaderMu to lock it, read from it to unlock.
+ // Lock reqHeaderMu BEFORE mu or wmu.
+ reqHeaderMu chan struct{}
+
+ // wmu is held while writing.
+ // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes.
+ // Only acquire both at the same time when changing peer settings.
+ wmu sync.Mutex
+ bw *bufio.Writer
+ fr *http2Framer
+ werr error // first write error that has occurred
+ hbuf bytes.Buffer // HPACK encoder writes into this
+ henc *hpack.Encoder
+}
+
+// clientStream is the state for a single HTTP/2 stream. One of these
+// is created for each Transport.RoundTrip call.
+type http2clientStream struct {
+ cc *http2ClientConn
+
+ // Fields of Request that we may access even after the response body is closed.
+ ctx context.Context
+ reqCancel <-chan struct{}
+
+ trace *httptrace.ClientTrace // or nil
+ ID uint32
+ bufPipe http2pipe // buffered pipe with the flow-controlled response payload
+ requestedGzip bool
+ isHead bool
+
+ abortOnce sync.Once
+ abort chan struct{} // closed to signal stream should end immediately
+ abortErr error // set if abort is closed
+
+ peerClosed chan struct{} // closed when the peer sends an END_STREAM flag
+ donec chan struct{} // closed after the stream is in the closed state
+ on100 chan struct{} // buffered; written to if a 100 is received
+
+ respHeaderRecv chan struct{} // closed when headers are received
+ res *Response // set if respHeaderRecv is closed
+
+ flow http2outflow // guarded by cc.mu
+ inflow http2inflow // guarded by cc.mu
+ bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
+ readErr error // sticky read error; owned by transportResponseBody.Read
+
+ reqBody io.ReadCloser
+ reqBodyContentLength int64 // -1 means unknown
+ reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
+
+ // owned by writeRequest:
+ sentEndStream bool // sent an END_STREAM flag to the peer
+ sentHeaders bool
+
+ // owned by clientConnReadLoop:
+ firstByte bool // got the first response byte
+ pastHeaders bool // got first MetaHeadersFrame (actual headers)
+ pastTrailers bool // got optional second MetaHeadersFrame (trailers)
+ num1xx uint8 // number of 1xx responses seen
+ readClosed bool // peer sent an END_STREAM flag
+ readAborted bool // read loop reset the stream
+
+ trailer Header // accumulated trailers
+ resTrailer *Header // client's Response.Trailer
+}
+
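+// http2got1xxFuncForTests, if non-nil, is used by tests in place of the
+// request's httptrace Got1xxResponse hook; see get1xxTraceFunc below.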
+var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error
+
+// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func,
+// if any. It returns nil if not set or if the Go version is too old.
+func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error {
+ if fn := http2got1xxFuncForTests; fn != nil {
+ return fn
+ }
+ return http2traceGot1xxResponseFunc(cs.trace)
+}
+
+func (cs *http2clientStream) abortStream(err error) {
+ cs.cc.mu.Lock()
+ defer cs.cc.mu.Unlock()
+ cs.abortStreamLocked(err)
+}
+
+func (cs *http2clientStream) abortStreamLocked(err error) {
+ cs.abortOnce.Do(func() {
+ cs.abortErr = err
+ close(cs.abort)
+ })
+ if cs.reqBody != nil {
+ cs.closeReqBodyLocked()
+ }
+ // TODO(dneil): Clean up tests where cs.cc.cond is nil.
+ if cs.cc.cond != nil {
+ // Wake up writeRequestBody if it is waiting on flow control.
+ cs.cc.cond.Broadcast()
+ }
+}
+
+func (cs *http2clientStream) abortRequestBodyWrite() {
+ cc := cs.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if cs.reqBody != nil && cs.reqBodyClosed == nil {
+ cs.closeReqBodyLocked()
+ cc.cond.Broadcast()
+ }
+}
+
+func (cs *http2clientStream) closeReqBodyLocked() {
+ if cs.reqBodyClosed != nil {
+ return
+ }
+ cs.reqBodyClosed = make(chan struct{})
+ reqBodyClosed := cs.reqBodyClosed
+ go func() {
+ cs.reqBody.Close()
+ close(reqBodyClosed)
+ }()
+}
+
+type http2stickyErrWriter struct {
+ conn net.Conn
+ timeout time.Duration
+ err *error
+}
+
+func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) {
+ if *sew.err != nil {
+ return 0, *sew.err
+ }
+ for {
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
+ }
+ nn, err := sew.conn.Write(p[n:])
+ n += nn
+ if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
+ // Keep extending the deadline so long as we're making progress.
+ continue
+ }
+ if sew.timeout != 0 {
+ sew.conn.SetWriteDeadline(time.Time{})
+ }
+ *sew.err = err
+ return n, err
+ }
+}
+
+// noCachedConnError is the concrete type of ErrNoCachedConn, which
+// needs to be detected by net/http regardless of whether it's its
+// bundled version (in h2_bundle.go with a rewritten type name) or
+// from a user's x/net/http2. As such, it has a unique method name
+// (IsHTTP2NoCachedConnError) that net/http sniffs for via func
+// isNoCachedConnError.
+type http2noCachedConnError struct{}
+
+func (http2noCachedConnError) IsHTTP2NoCachedConnError() {}
+
+func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" }
+
+// isNoCachedConnError reports whether err is of type noCachedConnError
+// or its equivalent renamed type in net/http's h2_bundle.go. Both types
+// may coexist in the same running program.
+func http2isNoCachedConnError(err error) bool {
+ _, ok := err.(interface{ IsHTTP2NoCachedConnError() })
+ return ok
+}
+
+var http2ErrNoCachedConn error = http2noCachedConnError{}
+
+// RoundTripOpt are options for the Transport.RoundTripOpt method.
+type http2RoundTripOpt struct {
+ // OnlyCachedConn controls whether RoundTripOpt may
+ // create a new TCP connection. If set true and
+ // no cached connection is available, RoundTripOpt
+ // will return ErrNoCachedConn.
+ OnlyCachedConn bool
+}
+
+func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
+ return t.RoundTripOpt(req, http2RoundTripOpt{})
+}
+
+// authorityAddr converts a given authority (a host/IP, or host:port / ip:port)
+// to a host:port. The port 443 (or 80 for the "http" scheme) is added if needed.
+func http2authorityAddr(scheme string, authority string) (addr string) {
+ host, port, err := net.SplitHostPort(authority)
+ if err != nil { // authority didn't have a port
+ host = authority
+ port = ""
+ }
+ if port == "" { // authority's port was empty
+ port = "443"
+ if scheme == "http" {
+ port = "80"
+ }
+ }
+ if a, err := idna.ToASCII(host); err == nil {
+ host = a
+ }
+ // IPv6 address literal, without a port:
+ if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
+ return host + ":" + port
+ }
+ return net.JoinHostPort(host, port)
+}
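+
+// For illustration, the mapping above works out to:
+//
+//	authorityAddr("https", "example.com")      == "example.com:443"
+//	authorityAddr("http", "example.com")       == "example.com:80"
+//	authorityAddr("https", "example.com:8443") == "example.com:8443"
+//	authorityAddr("https", "[::1]")            == "[::1]:443"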
+
+var http2retryBackoffHook func(time.Duration) *time.Timer
+
+func http2backoffNewTimer(d time.Duration) *time.Timer {
+ if http2retryBackoffHook != nil {
+ return http2retryBackoffHook(d)
+ }
+ return time.NewTimer(d)
+}
+
+// RoundTripOpt is like RoundTrip, but takes options.
+func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
+ if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
+ return nil, errors.New("http2: unsupported scheme")
+ }
+
+ addr := http2authorityAddr(req.URL.Scheme, req.URL.Host)
+ for retry := 0; ; retry++ {
+ cc, err := t.connPool().GetClientConn(req, addr)
+ if err != nil {
+ t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
+ return nil, err
+ }
+ reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+ http2traceGotConn(req, cc, reused)
+ res, err := cc.RoundTrip(req)
+ if err != nil && retry <= 6 {
+ roundTripErr := err
+ if req, err = http2shouldRetryRequest(req, err); err == nil {
+ // After the first retry, do exponential backoff with 10% jitter.
+ if retry == 0 {
+ t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
+ continue
+ }
+ backoff := float64(uint(1) << (uint(retry) - 1))
+ backoff += backoff * (0.1 * mathrand.Float64())
+ d := time.Second * time.Duration(backoff)
+ timer := http2backoffNewTimer(d)
+ select {
+ case <-timer.C:
+ t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
+ continue
+ case <-req.Context().Done():
+ timer.Stop()
+ err = req.Context().Err()
+ }
+ }
+ }
+ if err != nil {
+ t.vlogf("RoundTrip failure: %v", err)
+ return nil, err
+ }
+ return res, nil
+ }
+}
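+
+// For illustration of the retry loop above: the first retry happens
+// immediately; later retries wait roughly 2^(retry-1) seconds (1s, 2s, 4s, ...)
+// plus up to 10% jitter, and retrying stops once retry exceeds 6 or the
+// request's context is done.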
+
+// CloseIdleConnections closes any connections that were established by
+// previous requests but are now sitting idle.
+// It does not interrupt any connections currently in use.
+func (t *http2Transport) CloseIdleConnections() {
+ if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok {
+ cp.closeIdleConnections()
+ }
+}
+
+var (
+ http2errClientConnClosed = errors.New("http2: client conn is closed")
+ http2errClientConnUnusable = errors.New("http2: client conn not usable")
+ http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
+)
+
+// shouldRetryRequest is called by RoundTrip when a request fails to get
+// response headers. It is always called with a non-nil error.
+// It returns either a request to retry (either the same request, or a
+// modified clone), or an error if the request can't be replayed.
+func http2shouldRetryRequest(req *Request, err error) (*Request, error) {
+ if !http2canRetryError(err) {
+ return nil, err
+ }
+ // If the Body is nil (or http.NoBody), it's safe to reuse
+ // this request and its Body.
+ if req.Body == nil || req.Body == NoBody {
+ return req, nil
+ }
+
+ // If the request body can be reset back to its original
+ // state via the optional req.GetBody, do that.
+ if req.GetBody != nil {
+ body, err := req.GetBody()
+ if err != nil {
+ return nil, err
+ }
+ newReq := *req
+ newReq.Body = body
+ return &newReq, nil
+ }
+
+ // The Request.Body can't be reset back to the beginning, but we
+ // don't seem to have started to read from it yet, so reuse
+ // the request directly.
+ if err == http2errClientConnUnusable {
+ return req, nil
+ }
+
+ return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
+}
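+
+// A minimal caller-side sketch of making a request body replayable so the
+// GetBody path above can be taken. NewRequest fills in GetBody automatically
+// for *bytes.Buffer, *bytes.Reader and *strings.Reader bodies; the URL is
+// hypothetical and client is an assumed *http.Client:
+//
+//	body := strings.NewReader(`{"k":"v"}`)
+//	req, err := http.NewRequest("POST", "https://example.com/api", body)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	// req.GetBody is non-nil here, so a failed attempt can be replayed.
+//	res, err := client.Do(req)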
+
+func http2canRetryError(err error) bool {
+ if err == http2errClientConnUnusable || err == http2errClientConnGotGoAway {
+ return true
+ }
+ if se, ok := err.(http2StreamError); ok {
+ if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer {
+ // See golang/go#47635, golang/go#42777
+ return true
+ }
+ return se.Code == http2ErrCodeRefusedStream
+ }
+ return false
+}
+
+func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ tconn, err := t.dialTLS(ctx, "tcp", addr, t.newTLSConfig(host))
+ if err != nil {
+ return nil, err
+ }
+ return t.newClientConn(tconn, singleUse)
+}
+
+func (t *http2Transport) newTLSConfig(host string) *tls.Config {
+ cfg := new(tls.Config)
+ if t.TLSClientConfig != nil {
+ *cfg = *t.TLSClientConfig.Clone()
+ }
+ if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) {
+ cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...)
+ }
+ if cfg.ServerName == "" {
+ cfg.ServerName = host
+ }
+ return cfg
+}
+
+func (t *http2Transport) dialTLS(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) {
+ if t.DialTLSContext != nil {
+ return t.DialTLSContext(ctx, network, addr, tlsCfg)
+ } else if t.DialTLS != nil {
+ return t.DialTLS(network, addr, tlsCfg)
+ }
+
+ tlsCn, err := t.dialTLSWithContext(ctx, network, addr, tlsCfg)
+ if err != nil {
+ return nil, err
+ }
+ state := tlsCn.ConnectionState()
+ if p := state.NegotiatedProtocol; p != http2NextProtoTLS {
+ return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS)
+ }
+ if !state.NegotiatedProtocolIsMutual {
+ return nil, errors.New("http2: could not negotiate protocol mutually")
+ }
+ return tlsCn, nil
+}
+
+// disableKeepAlives reports whether connections should be closed as
+// soon as possible after handling the first request.
+func (t *http2Transport) disableKeepAlives() bool {
+ return t.t1 != nil && t.t1.DisableKeepAlives
+}
+
+func (t *http2Transport) expectContinueTimeout() time.Duration {
+ if t.t1 == nil {
+ return 0
+ }
+ return t.t1.ExpectContinueTimeout
+}
+
+func (t *http2Transport) maxDecoderHeaderTableSize() uint32 {
+ if v := t.MaxDecoderHeaderTableSize; v > 0 {
+ return v
+ }
+ return http2initialHeaderTableSize
+}
+
+func (t *http2Transport) maxEncoderHeaderTableSize() uint32 {
+ if v := t.MaxEncoderHeaderTableSize; v > 0 {
+ return v
+ }
+ return http2initialHeaderTableSize
+}
+
+func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
+ return t.newClientConn(c, t.disableKeepAlives())
+}
+
+func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
+ cc := &http2ClientConn{
+ t: t,
+ tconn: c,
+ readerDone: make(chan struct{}),
+ nextStreamID: 1,
+ maxFrameSize: 16 << 10, // spec default
+ initialWindowSize: 65535, // spec default
+ maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
+ peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
+ streams: make(map[uint32]*http2clientStream),
+ singleUse: singleUse,
+ wantSettingsAck: true,
+ pings: make(map[[8]byte]chan struct{}),
+ reqHeaderMu: make(chan struct{}, 1),
+ }
+ if d := t.idleConnTimeout(); d != 0 {
+ cc.idleTimeout = d
+ cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
+ }
+ if http2VerboseLogs {
+ t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
+ }
+
+ cc.cond = sync.NewCond(&cc.mu)
+ cc.flow.add(int32(http2initialWindowSize))
+
+ // TODO: adjust this writer size to account for frame size +
+ // MTU + crypto/tls record padding.
+ cc.bw = bufio.NewWriter(http2stickyErrWriter{
+ conn: c,
+ timeout: t.WriteByteTimeout,
+ err: &cc.werr,
+ })
+ cc.br = bufio.NewReader(c)
+ cc.fr = http2NewFramer(cc.bw, cc.br)
+ if t.maxFrameReadSize() != 0 {
+ cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
+ }
+ if t.CountError != nil {
+ cc.fr.countError = t.CountError
+ }
+ maxHeaderTableSize := t.maxDecoderHeaderTableSize()
+ cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
+ cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
+
+ cc.henc = hpack.NewEncoder(&cc.hbuf)
+ cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
+ cc.peerMaxHeaderTableSize = http2initialHeaderTableSize
+
+ if t.AllowHTTP {
+ cc.nextStreamID = 3
+ }
+
+ if cs, ok := c.(http2connectionStater); ok {
+ state := cs.ConnectionState()
+ cc.tlsState = &state
+ }
+
+ initialSettings := []http2Setting{
+ {ID: http2SettingEnablePush, Val: 0},
+ {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow},
+ }
+ if max := t.maxFrameReadSize(); max != 0 {
+ initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxFrameSize, Val: max})
+ }
+ if max := t.maxHeaderListSize(); max != 0 {
+ initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max})
+ }
+ if maxHeaderTableSize != http2initialHeaderTableSize {
+ initialSettings = append(initialSettings, http2Setting{ID: http2SettingHeaderTableSize, Val: maxHeaderTableSize})
+ }
+
+ cc.bw.Write(http2clientPreface)
+ cc.fr.WriteSettings(initialSettings...)
+ cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow)
+ cc.inflow.init(http2transportDefaultConnFlow + http2initialWindowSize)
+ cc.bw.Flush()
+ if cc.werr != nil {
+ cc.Close()
+ return nil, cc.werr
+ }
+
+ go cc.readLoop()
+ return cc, nil
+}
+
+func (cc *http2ClientConn) healthCheck() {
+ pingTimeout := cc.t.pingTimeout()
+ // We don't need to periodically ping in the health check, because the readLoop of ClientConn will
+ // trigger the healthCheck again if no frame is received.
+ ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ defer cancel()
+ cc.vlogf("http2: Transport sending health check")
+ err := cc.Ping(ctx)
+ if err != nil {
+ cc.vlogf("http2: Transport health check failure: %v", err)
+ cc.closeForLostPing()
+ } else {
+ cc.vlogf("http2: Transport health check success")
+ }
+}
+
+// SetDoNotReuse marks cc as not reusable for future HTTP requests.
+func (cc *http2ClientConn) SetDoNotReuse() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.doNotReuse = true
+}
+
+func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+
+ old := cc.goAway
+ cc.goAway = f
+
+ // Merge the previous and current GoAway error frames.
+ if cc.goAwayDebug == "" {
+ cc.goAwayDebug = string(f.DebugData())
+ }
+ if old != nil && old.ErrCode != http2ErrCodeNo {
+ cc.goAway.ErrCode = old.ErrCode
+ }
+ last := f.LastStreamID
+ for streamID, cs := range cc.streams {
+ if streamID > last {
+ cs.abortStreamLocked(http2errClientConnGotGoAway)
+ }
+ }
+}
+
+// CanTakeNewRequest reports whether the connection can take a new request,
+// meaning it has not been closed and has not received or sent a GOAWAY.
+//
+// If the caller is going to immediately make a new request on this
+// connection, use ReserveNewRequest instead.
+func (cc *http2ClientConn) CanTakeNewRequest() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.canTakeNewRequestLocked()
+}
+
+// ReserveNewRequest is like CanTakeNewRequest but also reserves a
+// concurrent stream in cc. The reservation is decremented on the
+// next call to RoundTrip.
+func (cc *http2ClientConn) ReserveNewRequest() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if st := cc.idleStateLocked(); !st.canTakeNewRequest {
+ return false
+ }
+ cc.streamsReserved++
+ return true
+}
+
+// ClientConnState describes the state of a ClientConn.
+type http2ClientConnState struct {
+ // Closed is whether the connection is closed.
+ Closed bool
+
+ // Closing is whether the connection is in the process of
+ // closing. It may be closing due to shutdown, being a
+ // single-use connection, being marked as DoNotReuse, or
+ // having received a GOAWAY frame.
+ Closing bool
+
+ // StreamsActive is how many streams are active.
+ StreamsActive int
+
+ // StreamsReserved is how many streams have been reserved via
+ // ClientConn.ReserveNewRequest.
+ StreamsReserved int
+
+ // StreamsPending is how many requests have been sent in excess
+ // of the peer's advertised MaxConcurrentStreams setting and
+ // are waiting for other streams to complete.
+ StreamsPending int
+
+ // MaxConcurrentStreams is how many concurrent streams the
+ // peer advertised as acceptable. Zero means no SETTINGS
+ // frame has been received yet.
+ MaxConcurrentStreams uint32
+
+ // LastIdle, if non-zero, is when the connection last
+ // transitioned to idle state.
+ LastIdle time.Time
+}
+
+// State returns a snapshot of cc's state.
+func (cc *http2ClientConn) State() http2ClientConnState {
+ cc.wmu.Lock()
+ maxConcurrent := cc.maxConcurrentStreams
+ if !cc.seenSettings {
+ maxConcurrent = 0
+ }
+ cc.wmu.Unlock()
+
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return http2ClientConnState{
+ Closed: cc.closed,
+ Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil,
+ StreamsActive: len(cc.streams),
+ StreamsReserved: cc.streamsReserved,
+ StreamsPending: cc.pendingRequests,
+ LastIdle: cc.lastIdle,
+ MaxConcurrentStreams: maxConcurrent,
+ }
+}
+
+// clientConnIdleState describes the suitability of a client
+// connection to initiate a new RoundTrip request.
+type http2clientConnIdleState struct {
+ canTakeNewRequest bool
+}
+
+func (cc *http2ClientConn) idleState() http2clientConnIdleState {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.idleStateLocked()
+}
+
+func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
+ if cc.singleUse && cc.nextStreamID > 1 {
+ return
+ }
+ var maxConcurrentOkay bool
+ if cc.t.StrictMaxConcurrentStreams {
+ // We'll tell the caller we can take a new request to
+ // prevent the caller from dialing a new TCP
+ // connection, but then we'll block later before
+ // writing it.
+ maxConcurrentOkay = true
+ } else {
+ maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
+ }
+
+ st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
+ !cc.doNotReuse &&
+ int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
+ !cc.tooIdleLocked()
+ return
+}
+
+func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+ st := cc.idleStateLocked()
+ return st.canTakeNewRequest
+}
+
+// tooIdleLocked reports whether this connection has been sitting idle
+// for too much wall time.
+func (cc *http2ClientConn) tooIdleLocked() bool {
+ // The Round(0) strips the monotonic clock reading so the
+ // times are compared based on their wall time. We don't want
+ // to reuse a connection that's been sitting idle during
+ // VM/laptop suspend if monotonic time was also frozen.
+ return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout
+}
+
+// onIdleTimeout is called from a time.AfterFunc goroutine. It will
+// only be called when we're idle, but because we're coming from a new
+// goroutine, there could be a new request coming in at the same time,
+// so this simply calls the synchronized closeIfIdle to shut down this
+// connection. The timer could just call closeIfIdle, but this is more
+// clear.
+func (cc *http2ClientConn) onIdleTimeout() {
+ cc.closeIfIdle()
+}
+
+func (cc *http2ClientConn) closeConn() {
+ t := time.AfterFunc(250*time.Millisecond, cc.forceCloseConn)
+ defer t.Stop()
+ cc.tconn.Close()
+}
+
+// A tls.Conn.Close can hang for a long time if the peer is unresponsive.
+// Try to shut it down more aggressively.
+func (cc *http2ClientConn) forceCloseConn() {
+ tc, ok := cc.tconn.(*tls.Conn)
+ if !ok {
+ return
+ }
+ if nc := http2tlsUnderlyingConn(tc); nc != nil {
+ nc.Close()
+ }
+}
+
+func (cc *http2ClientConn) closeIfIdle() {
+ cc.mu.Lock()
+ if len(cc.streams) > 0 || cc.streamsReserved > 0 {
+ cc.mu.Unlock()
+ return
+ }
+ cc.closed = true
+ nextID := cc.nextStreamID
+ // TODO: do clients send GOAWAY too? maybe? Just Close:
+ cc.mu.Unlock()
+
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2)
+ }
+ cc.closeConn()
+}
+
+func (cc *http2ClientConn) isDoNotReuseAndIdle() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.doNotReuse && len(cc.streams) == 0
+}
+
+var http2shutdownEnterWaitStateHook = func() {}
+
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
+ if err := cc.sendGoAway(); err != nil {
+ return err
+ }
+ // Wait for all in-flight streams to complete or connection to close
+ done := make(chan struct{})
+ cancelled := false // guarded by cc.mu
+ go func() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if len(cc.streams) == 0 || cc.closed {
+ cc.closed = true
+ close(done)
+ break
+ }
+ if cancelled {
+ break
+ }
+ cc.cond.Wait()
+ }
+ }()
+ http2shutdownEnterWaitStateHook()
+ select {
+ case <-done:
+ cc.closeConn()
+ return nil
+ case <-ctx.Done():
+ cc.mu.Lock()
+ // Free the goroutine above
+ cancelled = true
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ return ctx.Err()
+ }
+}
+
+func (cc *http2ClientConn) sendGoAway() error {
+ cc.mu.Lock()
+ closing := cc.closing
+ cc.closing = true
+ maxStreamID := cc.nextStreamID
+ cc.mu.Unlock()
+ if closing {
+ // GOAWAY sent already
+ return nil
+ }
+
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ // Send a graceful shutdown frame to server
+ if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
+ return err
+ }
+ if err := cc.bw.Flush(); err != nil {
+ return err
+ }
+ // Prevent new requests
+ return nil
+}
+
+// closes the client connection immediately. In-flight requests are interrupted.
+// err is sent to streams.
+func (cc *http2ClientConn) closeForError(err error) {
+ cc.mu.Lock()
+ cc.closed = true
+ for _, cs := range cc.streams {
+ cs.abortStreamLocked(err)
+ }
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ cc.closeConn()
+}
+
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *http2ClientConn) Close() error {
+ err := errors.New("http2: client connection force closed via ClientConn.Close")
+ cc.closeForError(err)
+ return nil
+}
+
+// closes the client connection immediately. In-flight requests are interrupted.
+func (cc *http2ClientConn) closeForLostPing() {
+ err := errors.New("http2: client connection lost")
+ if f := cc.t.CountError; f != nil {
+ f("conn_close_lost_ping")
+ }
+ cc.closeForError(err)
+}
+
+// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
+// exported. At least they'll be DeepEqual for h1-vs-h2 comparison tests.
+var http2errRequestCanceled = errors.New("net/http: request canceled")
+
+func http2commaSeparatedTrailers(req *Request) (string, error) {
+ keys := make([]string, 0, len(req.Trailer))
+ for k := range req.Trailer {
+ k = http2canonicalHeader(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return "", fmt.Errorf("invalid Trailer key %q", k)
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ return strings.Join(keys, ","), nil
+ }
+ return "", nil
+}
+
+func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
+ if cc.t.t1 != nil {
+ return cc.t.t1.ResponseHeaderTimeout
+ }
+ // No way to do this (yet?) with just an http2.Transport. Probably
+ // no need; Request.Cancel (or the request's context) is the new way. We only need to support
+ // this for compatibility with the old http.Transport fields when
+ // we're doing transparent http2.
+ return 0
+}
+
+// checkConnHeaders checks whether req has any invalid connection-level headers.
+// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields.
+// Certain headers are special-cased as okay but not transmitted later.
+func http2checkConnHeaders(req *Request) error {
+ if v := req.Header.Get("Upgrade"); v != "" {
+ return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"])
+ }
+ if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
+ }
+ if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !http2asciiEqualFold(vv[0], "close") && !http2asciiEqualFold(vv[0], "keep-alive")) {
+ return fmt.Errorf("http2: invalid Connection request header: %q", vv)
+ }
+ return nil
+}
+
+// actualContentLength returns a sanitized version of
+// req.ContentLength, where 0 actually means zero (not unknown) and -1
+// means unknown.
+func http2actualContentLength(req *Request) int64 {
+ if req.Body == nil || req.Body == NoBody {
+ return 0
+ }
+ if req.ContentLength != 0 {
+ return req.ContentLength
+ }
+ return -1
+}
+
+func (cc *http2ClientConn) decrStreamReservations() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.decrStreamReservationsLocked()
+}
+
+func (cc *http2ClientConn) decrStreamReservationsLocked() {
+ if cc.streamsReserved > 0 {
+ cc.streamsReserved--
+ }
+}
+
+func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
+ ctx := req.Context()
+ cs := &http2clientStream{
+ cc: cc,
+ ctx: ctx,
+ reqCancel: req.Cancel,
+ isHead: req.Method == "HEAD",
+ reqBody: req.Body,
+ reqBodyContentLength: http2actualContentLength(req),
+ trace: httptrace.ContextClientTrace(ctx),
+ peerClosed: make(chan struct{}),
+ abort: make(chan struct{}),
+ respHeaderRecv: make(chan struct{}),
+ donec: make(chan struct{}),
+ }
+ go cs.doRequest(req)
+
+ waitDone := func() error {
+ select {
+ case <-cs.donec:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ }
+ }
+
+ handleResponseHeaders := func() (*Response, error) {
+ res := cs.res
+ if res.StatusCode > 299 {
+ // On error or status code 3xx, 4xx, 5xx, etc abort any
+ // ongoing write, assuming that the server doesn't care
+ // about our request body. If the server replied with 1xx or
+ // 2xx, however, then assume the server DOES potentially
+ // want our body (e.g. full-duplex streaming:
+ // golang.org/issue/13444). If it turns out the server
+ // doesn't, they'll RST_STREAM us soon enough. This is a
+ // heuristic to avoid adding knobs to Transport. Hopefully
+ // we can keep it.
+ cs.abortRequestBodyWrite()
+ }
+ res.Request = req
+ res.TLS = cc.tlsState
+ if res.Body == http2noBody && http2actualContentLength(req) == 0 {
+ // If there isn't a request or response body still being
+ // written, then wait for the stream to be closed before
+ // RoundTrip returns.
+ if err := waitDone(); err != nil {
+ return nil, err
+ }
+ }
+ return res, nil
+ }
+
+ cancelRequest := func(cs *http2clientStream, err error) error {
+ cs.cc.mu.Lock()
+ bodyClosed := cs.reqBodyClosed
+ cs.cc.mu.Unlock()
+ // Wait for the request body to be closed.
+ //
+ // If nothing closed the body before now, abortStreamLocked
+ // will have started a goroutine to close it.
+ //
+ // Closing the body before returning avoids a race condition
+ // with net/http checking its readTrackingBody to see if the
+ // body was read from or closed. See golang/go#60041.
+ //
+ // The body is closed in a separate goroutine without the
+ // connection mutex held, but dropping the mutex before waiting
+ // will keep us from holding it indefinitely if the body
+ // close is slow for some reason.
+ if bodyClosed != nil {
+ <-bodyClosed
+ }
+ return err
+ }
+
+ for {
+ select {
+ case <-cs.respHeaderRecv:
+ return handleResponseHeaders()
+ case <-cs.abort:
+ select {
+ case <-cs.respHeaderRecv:
+ // If both cs.respHeaderRecv and cs.abort are signaling,
+ // pick respHeaderRecv. The server probably wrote the
+ // response and immediately reset the stream.
+ // golang.org/issue/49645
+ return handleResponseHeaders()
+ default:
+ waitDone()
+ return nil, cs.abortErr
+ }
+ case <-ctx.Done():
+ err := ctx.Err()
+ cs.abortStream(err)
+ return nil, cancelRequest(cs, err)
+ case <-cs.reqCancel:
+ cs.abortStream(http2errRequestCanceled)
+ return nil, cancelRequest(cs, http2errRequestCanceled)
+ }
+ }
+}
+
+// doRequest runs for the duration of the request lifetime.
+//
+// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
+func (cs *http2clientStream) doRequest(req *Request) {
+ err := cs.writeRequest(req)
+ cs.cleanupWriteRequest(err)
+}
+
+// writeRequest sends a request.
+//
+// It returns nil after the request is written, the response read,
+// and the request stream is half-closed by the peer.
+//
+// It returns non-nil if the request ends otherwise.
+// If the returned error is StreamError, the error Code may be used in resetting the stream.
+func (cs *http2clientStream) writeRequest(req *Request) (err error) {
+ cc := cs.cc
+ ctx := cs.ctx
+
+ if err := http2checkConnHeaders(req); err != nil {
+ return err
+ }
+
+ // Acquire the new-request lock by writing to reqHeaderMu.
+ // This lock guards the critical section covering allocating a new stream ID
+ // (requires mu) and creating the stream (requires wmu).
+ if cc.reqHeaderMu == nil {
+ panic("RoundTrip on uninitialized ClientConn") // for tests
+ }
+ select {
+ case cc.reqHeaderMu <- struct{}{}:
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+
+ cc.mu.Lock()
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+ cc.decrStreamReservationsLocked()
+ if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil {
+ cc.mu.Unlock()
+ <-cc.reqHeaderMu
+ return err
+ }
+ cc.addStreamLocked(cs) // assigns stream ID
+ if http2isConnectionCloseRequest(req) {
+ cc.doNotReuse = true
+ }
+ cc.mu.Unlock()
+
+ // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
+ if !cc.t.disableCompression() &&
+ req.Header.Get("Accept-Encoding") == "" &&
+ req.Header.Get("Range") == "" &&
+ !cs.isHead {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: https://zlib.net/zlib_faq.html#faq39
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // http://trac.nginx.org/nginx/ticket/358
+ // https://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See https://golang.org/issue/8923
+ cs.requestedGzip = true
+ }
+
+ continueTimeout := cc.t.expectContinueTimeout()
+ if continueTimeout != 0 {
+ if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") {
+ continueTimeout = 0
+ } else {
+ cs.on100 = make(chan struct{}, 1)
+ }
+ }
+
+ // Past this point (where we send request headers), it is possible for
+ // RoundTrip to return successfully. Since the RoundTrip contract permits
+ // the caller to "mutate or reuse" the Request after closing the Response's Body,
+ // we must take care when referencing the Request from here on.
+ err = cs.encodeAndWriteHeaders(req)
+ <-cc.reqHeaderMu
+ if err != nil {
+ return err
+ }
+
+ hasBody := cs.reqBodyContentLength != 0
+ if !hasBody {
+ cs.sentEndStream = true
+ } else {
+ if continueTimeout != 0 {
+ http2traceWait100Continue(cs.trace)
+ timer := time.NewTimer(continueTimeout)
+ select {
+ case <-timer.C:
+ err = nil
+ case <-cs.on100:
+ err = nil
+ case <-cs.abort:
+ err = cs.abortErr
+ case <-ctx.Done():
+ err = ctx.Err()
+ case <-cs.reqCancel:
+ err = http2errRequestCanceled
+ }
+ timer.Stop()
+ if err != nil {
+ http2traceWroteRequest(cs.trace, err)
+ return err
+ }
+ }
+
+ if err = cs.writeRequestBody(req); err != nil {
+ if err != http2errStopReqBodyWrite {
+ http2traceWroteRequest(cs.trace, err)
+ return err
+ }
+ } else {
+ cs.sentEndStream = true
+ }
+ }
+
+ http2traceWroteRequest(cs.trace, err)
+
+ var respHeaderTimer <-chan time.Time
+ var respHeaderRecv chan struct{}
+ if d := cc.responseHeaderTimeout(); d != 0 {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ respHeaderTimer = timer.C
+ respHeaderRecv = cs.respHeaderRecv
+ }
+ // Wait until the peer half-closes its end of the stream,
+ // or until the request is aborted (via context, error, or otherwise),
+ // whichever comes first.
+ for {
+ select {
+ case <-cs.peerClosed:
+ return nil
+ case <-respHeaderTimer:
+ return http2errTimeout
+ case <-respHeaderRecv:
+ respHeaderRecv = nil
+ respHeaderTimer = nil // keep waiting for END_STREAM
+ case <-cs.abort:
+ return cs.abortErr
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ }
+ }
+}
+
+func (cs *http2clientStream) encodeAndWriteHeaders(req *Request) error {
+ cc := cs.cc
+ ctx := cs.ctx
+
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+
+ // If the request was canceled while waiting for cc.mu, just quit.
+ select {
+ case <-cs.abort:
+ return cs.abortErr
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ default:
+ }
+
+ // Encode headers.
+ //
+ // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
+ // sent by writeRequestBody below, along with any Trailers,
+ // again in form HEADERS{1}, CONTINUATION{0,})
+ trailers, err := http2commaSeparatedTrailers(req)
+ if err != nil {
+ return err
+ }
+ hasTrailers := trailers != ""
+ contentLen := http2actualContentLength(req)
+ hasBody := contentLen != 0
+ hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
+ if err != nil {
+ return err
+ }
+
+ // Write the request.
+ endStream := !hasBody && !hasTrailers
+ cs.sentHeaders = true
+ err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
+ http2traceWroteHeaders(cs.trace)
+ return err
+}
+
+// cleanupWriteRequest performs post-request tasks.
+//
+// If err (the result of writeRequest) is non-nil and the stream is not closed,
+// cleanupWriteRequest will send a reset to the peer.
+func (cs *http2clientStream) cleanupWriteRequest(err error) {
+ cc := cs.cc
+
+ if cs.ID == 0 {
+ // We were canceled before creating the stream, so return our reservation.
+ cc.decrStreamReservations()
+ }
+
+ // TODO: write h12Compare test showing whether
+ // Request.Body is closed by the Transport,
+ // and in multiple cases: server replies <=299 and >299
+ // while still writing request body
+ cc.mu.Lock()
+ mustCloseBody := false
+ if cs.reqBody != nil && cs.reqBodyClosed == nil {
+ mustCloseBody = true
+ cs.reqBodyClosed = make(chan struct{})
+ }
+ bodyClosed := cs.reqBodyClosed
+ cc.mu.Unlock()
+ if mustCloseBody {
+ cs.reqBody.Close()
+ close(bodyClosed)
+ }
+ if bodyClosed != nil {
+ <-bodyClosed
+ }
+
+ if err != nil && cs.sentEndStream {
+ // If the connection is closed immediately after the response is read,
+ // we may be aborted before finishing up here. If the stream was closed
+ // cleanly on both sides, there is no error.
+ select {
+ case <-cs.peerClosed:
+ err = nil
+ default:
+ }
+ }
+ if err != nil {
+ cs.abortStream(err) // possibly redundant, but harmless
+ if cs.sentHeaders {
+ if se, ok := err.(http2StreamError); ok {
+ if se.Cause != http2errFromPeer {
+ cc.writeStreamReset(cs.ID, se.Code, err)
+ }
+ } else {
+ cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err)
+ }
+ }
+ cs.bufPipe.CloseWithError(err) // no-op if already closed
+ } else {
+ if cs.sentHeaders && !cs.sentEndStream {
+ cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil)
+ }
+ cs.bufPipe.CloseWithError(http2errRequestCanceled)
+ }
+ if cs.ID != 0 {
+ cc.forgetStreamID(cs.ID)
+ }
+
+ cc.wmu.Lock()
+ werr := cc.werr
+ cc.wmu.Unlock()
+ if werr != nil {
+ cc.Close()
+ }
+
+ close(cs.donec)
+}
+
+// awaitOpenSlotForStreamLocked waits until len(streams) < maxConcurrentStreams.
+// Must hold cc.mu.
+func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error {
+ for {
+ cc.lastActive = time.Now()
+ if cc.closed || !cc.canTakeNewRequestLocked() {
+ return http2errClientConnUnusable
+ }
+ cc.lastIdle = time.Time{}
+ if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
+ return nil
+ }
+ cc.pendingRequests++
+ cc.cond.Wait()
+ cc.pendingRequests--
+ select {
+ case <-cs.abort:
+ return cs.abortErr
+ default:
+ }
+ }
+}
+
+// requires cc.wmu be held
+func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFrameSize int, hdrs []byte) error {
+ first := true // first frame written (HEADERS is first, then CONTINUATION)
+ for len(hdrs) > 0 && cc.werr == nil {
+ chunk := hdrs
+ if len(chunk) > maxFrameSize {
+ chunk = chunk[:maxFrameSize]
+ }
+ hdrs = hdrs[len(chunk):]
+ endHeaders := len(hdrs) == 0
+ if first {
+ cc.fr.WriteHeaders(http2HeadersFrameParam{
+ StreamID: streamID,
+ BlockFragment: chunk,
+ EndStream: endStream,
+ EndHeaders: endHeaders,
+ })
+ first = false
+ } else {
+ cc.fr.WriteContinuation(streamID, endHeaders, chunk)
+ }
+ }
+ cc.bw.Flush()
+ return cc.werr
+}
+
+// internal error values; they don't escape to callers
+var (
+ // abort request body write; don't send cancel
+ http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
+
+ // abort request body write, but send a stream reset (CANCEL) to the peer.
+ http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
+
+ http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length")
+)
+
+// frameScratchBufferLen returns the length of a buffer to use for
+// outgoing request bodies to read/write to/from.
+//
+// It returns max(1, min(peer's advertised max frame size,
+// Request.ContentLength+1, 512KB)).
+func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int {
+ const max = 512 << 10
+ n := int64(maxFrameSize)
+ if n > max {
+ n = max
+ }
+ if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n {
+ // Add an extra byte past the declared content-length to
+ // give the caller's Request.Body io.Reader a chance to
+ // give us more bytes than they declared, so we can catch it
+ // early.
+ n = cl + 1
+ }
+ if n < 1 {
+ return 1
+ }
+ return int(n) // doesn't truncate; max is 512K
+}
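+
+// Illustration with hypothetical inputs (not in the original source): the
+// clamp above works out to, for example:
+//
+//	maxFrameSize=16384, reqBodyContentLength=-1 (unknown) -> 16384
+//	maxFrameSize=1<<20, reqBodyContentLength=-1           -> 524288 (512KB cap)
+//	maxFrameSize=16384, reqBodyContentLength=10           -> 11 (cl+1, to catch overruns early)
+//	maxFrameSize=16384, reqBodyContentLength=0            -> 1 (floor)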
+
+var http2bufPool sync.Pool // of *[]byte
+
+func (cs *http2clientStream) writeRequestBody(req *Request) (err error) {
+ cc := cs.cc
+ body := cs.reqBody
+ sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
+
+ hasTrailers := req.Trailer != nil
+ remainLen := cs.reqBodyContentLength
+ hasContentLen := remainLen != -1
+
+ cc.mu.Lock()
+ maxFrameSize := int(cc.maxFrameSize)
+ cc.mu.Unlock()
+
+ // Scratch buffer for reading into & writing from.
+ scratchLen := cs.frameScratchBufferLen(maxFrameSize)
+ var buf []byte
+ if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen {
+ defer http2bufPool.Put(bp)
+ buf = *bp
+ } else {
+ buf = make([]byte, scratchLen)
+ defer http2bufPool.Put(&buf)
+ }
+
+ var sawEOF bool
+ for !sawEOF {
+ n, err := body.Read(buf)
+ if hasContentLen {
+ remainLen -= int64(n)
+ if remainLen == 0 && err == nil {
+ // The request body's Content-Length was predeclared and
+ // we just finished reading it all, but the underlying io.Reader
+ // returned the final chunk with a nil error (which is one of
+ // the two valid things a Reader can do at EOF). Because we'd prefer
+ // to send the END_STREAM bit early, double-check that we're actually
+ // at EOF. Subsequent reads should return (0, EOF) at this point.
+ // If either value is different, we return an error in one of two ways below.
+ var scratch [1]byte
+ var n1 int
+ n1, err = body.Read(scratch[:])
+ remainLen -= int64(n1)
+ }
+ if remainLen < 0 {
+ err = http2errReqBodyTooLong
+ return err
+ }
+ }
+ if err != nil {
+ cc.mu.Lock()
+ bodyClosed := cs.reqBodyClosed != nil
+ cc.mu.Unlock()
+ switch {
+ case bodyClosed:
+ return http2errStopReqBodyWrite
+ case err == io.EOF:
+ sawEOF = true
+ err = nil
+ default:
+ return err
+ }
+ }
+
+ remain := buf[:n]
+ for len(remain) > 0 && err == nil {
+ var allowed int32
+ allowed, err = cs.awaitFlowControl(len(remain))
+ if err != nil {
+ return err
+ }
+ cc.wmu.Lock()
+ data := remain[:allowed]
+ remain = remain[allowed:]
+ sentEnd = sawEOF && len(remain) == 0 && !hasTrailers
+ err = cc.fr.WriteData(cs.ID, sentEnd, data)
+ if err == nil {
+ // TODO(bradfitz): this flush is for latency, not bandwidth.
+ // Most requests won't need this. Make this opt-in or
+ // opt-out? Use some heuristic on the body type? Nagle-like
+ // timers? Based on 'n'? Only last chunk of this for loop,
+ // unless flow control tokens are low? For now, always.
+ // If we change this, see comment below.
+ err = cc.bw.Flush()
+ }
+ cc.wmu.Unlock()
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if sentEnd {
+ // Already sent END_STREAM (which implies we have no
+ // trailers) and flushed, because currently all
+ // WriteData frames above get a flush. So we're done.
+ return nil
+ }
+
+ // Since the RoundTrip contract permits the caller to "mutate or reuse"
+ // a request after the Response's Body is closed, verify that this hasn't
+ // happened before accessing the trailers.
+ cc.mu.Lock()
+ trailer := req.Trailer
+ err = cs.abortErr
+ cc.mu.Unlock()
+ if err != nil {
+ return err
+ }
+
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ var trls []byte
+ if len(trailer) > 0 {
+ trls, err = cc.encodeTrailers(trailer)
+ if err != nil {
+ return err
+ }
+ }
+
+ // Two ways to send END_STREAM: either with trailers, or
+ // with an empty DATA frame.
+ if len(trls) > 0 {
+ err = cc.writeHeaders(cs.ID, true, maxFrameSize, trls)
+ } else {
+ err = cc.fr.WriteData(cs.ID, true, nil)
+ }
+ if ferr := cc.bw.Flush(); ferr != nil && err == nil {
+ err = ferr
+ }
+ return err
+}
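+
+// Illustration with a hypothetical body (not in the original source): the
+// extra one-byte read in the Content-Length branch above exists because an
+// io.Reader may return (n, nil) on the call that consumes its final bytes,
+// deferring io.EOF to the next call. strings.Reader behaves this way:
+//
+//	body := strings.NewReader("abc") // declared Content-Length: 3
+//	buf := make([]byte, 64)
+//	n, err := body.Read(buf) // n == 3, err == nil
+//	var one [1]byte
+//	n1, err2 := body.Read(one[:]) // n1 == 0, err2 == io.EOF
+//	fmt.Println(n, err, n1, err2)
+//
+// Seeing (0, io.EOF) lets the transport set END_STREAM on the final DATA frame
+// without waiting for another Read; a non-zero n1 drives remainLen negative
+// and produces http2errReqBodyTooLong.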
+
+// awaitFlowControl waits for [1, min(maxBytes, cc.maxFrameSize)] flow
+// control tokens from the server.
+// It returns either the non-zero number of tokens taken or an error
+// if the stream is dead.
+func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
+ cc := cs.cc
+ ctx := cs.ctx
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if cc.closed {
+ return 0, http2errClientConnClosed
+ }
+ if cs.reqBodyClosed != nil {
+ return 0, http2errStopReqBodyWrite
+ }
+ select {
+ case <-cs.abort:
+ return 0, cs.abortErr
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ case <-cs.reqCancel:
+ return 0, http2errRequestCanceled
+ default:
+ }
+ if a := cs.flow.available(); a > 0 {
+ take := a
+ if int(take) > maxBytes {
+ take = int32(maxBytes) // can't truncate int; take is int32
+ }
+ if take > int32(cc.maxFrameSize) {
+ take = int32(cc.maxFrameSize)
+ }
+ cs.flow.take(take)
+ return take, nil
+ }
+ cc.cond.Wait()
+ }
+}
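+
+// Illustration with hypothetical values (not in the original source): the
+// grant above is capped by both the caller's request and the peer's frame
+// size. With available=100000, maxBytes=30000, and cc.maxFrameSize=16384,
+// the caller receives 16384 tokens and loops in writeRequestBody for the rest.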
+
+var http2errNilRequestURL = errors.New("http2: Request.URI is nil")
+
+// requires cc.wmu be held.
+func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
+ cc.hbuf.Reset()
+ if req.URL == nil {
+ return nil, http2errNilRequestURL
+ }
+
+ host := req.Host
+ if host == "" {
+ host = req.URL.Host
+ }
+ host, err := httpguts.PunycodeHostPort(host)
+ if err != nil {
+ return nil, err
+ }
+ if !httpguts.ValidHostHeader(host) {
+ return nil, errors.New("http2: invalid Host header")
+ }
+
+ var path string
+ if req.Method != "CONNECT" {
+ path = req.URL.RequestURI()
+ if !http2validPseudoPath(path) {
+ orig := path
+ path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
+ if !http2validPseudoPath(path) {
+ if req.URL.Opaque != "" {
+ return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
+ } else {
+ return nil, fmt.Errorf("invalid request :path %q", orig)
+ }
+ }
+ }
+ }
+
+ // Check for any invalid headers and return an error before we
+ // potentially pollute our hpack state. (We want to be able to
+ // continue to reuse the hpack encoder for future requests)
+ for k, vv := range req.Header {
+ if !httpguts.ValidHeaderFieldName(k) {
+ return nil, fmt.Errorf("invalid HTTP header name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error, because it may be sensitive.
+ return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
+ }
+ }
+ }
+
+ enumerateHeaders := func(f func(name, value string)) {
+ // 8.1.2.3 Request Pseudo-Header Fields
+ // The :path pseudo-header field includes the path and query parts of the
+ // target URI (the path-absolute production and optionally a '?' character
+ // followed by the query production, see Sections 3.3 and 3.4 of
+ // [RFC3986]).
+ f(":authority", host)
+ m := req.Method
+ if m == "" {
+ m = MethodGet
+ }
+ f(":method", m)
+ if req.Method != "CONNECT" {
+ f(":path", path)
+ f(":scheme", req.URL.Scheme)
+ }
+ if trailers != "" {
+ f("trailer", trailers)
+ }
+
+ var didUA bool
+ for k, vv := range req.Header {
+ if http2asciiEqualFold(k, "host") || http2asciiEqualFold(k, "content-length") {
+ // Host is :authority, already sent.
+ // Content-Length is automatic, set below.
+ continue
+ } else if http2asciiEqualFold(k, "connection") ||
+ http2asciiEqualFold(k, "proxy-connection") ||
+ http2asciiEqualFold(k, "transfer-encoding") ||
+ http2asciiEqualFold(k, "upgrade") ||
+ http2asciiEqualFold(k, "keep-alive") {
+ // Per 8.1.2.2 Connection-Specific Header
+ // Fields, don't send connection-specific
+ // fields. We have already checked if any
+ // are error-worthy so just ignore the rest.
+ continue
+ } else if http2asciiEqualFold(k, "user-agent") {
+ // Match Go's http1 behavior: at most one
+ // User-Agent. If set to nil or empty string,
+ // then omit it. Otherwise if not mentioned,
+ // include the default (below).
+ didUA = true
+ if len(vv) < 1 {
+ continue
+ }
+ vv = vv[:1]
+ if vv[0] == "" {
+ continue
+ }
+ } else if http2asciiEqualFold(k, "cookie") {
+ // Per 8.1.2.5 To allow for better compression efficiency, the
+ // Cookie header field MAY be split into separate header fields,
+ // each with one or more cookie-pairs.
+ for _, v := range vv {
+ for {
+ p := strings.IndexByte(v, ';')
+ if p < 0 {
+ break
+ }
+ f("cookie", v[:p])
+ p++
+ // strip space after semicolon if any.
+ for p+1 <= len(v) && v[p] == ' ' {
+ p++
+ }
+ v = v[p:]
+ }
+ if len(v) > 0 {
+ f("cookie", v)
+ }
+ }
+ continue
+ }
+
+ for _, v := range vv {
+ f(k, v)
+ }
+ }
+ if http2shouldSendReqContentLength(req.Method, contentLength) {
+ f("content-length", strconv.FormatInt(contentLength, 10))
+ }
+ if addGzipHeader {
+ f("accept-encoding", "gzip")
+ }
+ if !didUA {
+ f("user-agent", http2defaultUserAgent)
+ }
+ }
+
+ // Do a first pass over the headers counting bytes to ensure
+ // we don't exceed cc.peerMaxHeaderListSize. This is done as a
+ // separate pass before encoding the headers to prevent
+ // modifying the hpack state.
+ hlSize := uint64(0)
+ enumerateHeaders(func(name, value string) {
+ hf := hpack.HeaderField{Name: name, Value: value}
+ hlSize += uint64(hf.Size())
+ })
+
+ if hlSize > cc.peerMaxHeaderListSize {
+ return nil, http2errRequestHeaderListSize
+ }
+
+ trace := httptrace.ContextClientTrace(req.Context())
+ traceHeaders := http2traceHasWroteHeaderField(trace)
+
+ // Header list size is ok. Write the headers.
+ enumerateHeaders(func(name, value string) {
+ name, ascii := http2lowerHeader(name)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ return
+ }
+ cc.writeHeader(name, value)
+ if traceHeaders {
+ http2traceWroteHeaderField(trace, name, value)
+ }
+ })
+
+ return cc.hbuf.Bytes(), nil
+}
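+
+// Illustration with a hypothetical header (not in the original source): the
+// cookie branch in enumerateHeaders above splits a combined Cookie header into
+// one "cookie" field per cookie-pair, as RFC 7540, Section 8.1.2.5 permits for
+// better HPACK compression. Given:
+//
+//	Cookie: a=1; b=2; c=3
+//
+// the encoder emits:
+//
+//	cookie: a=1
+//	cookie: b=2
+//	cookie: c=3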
+
+// shouldSendReqContentLength reports whether the http2.Transport should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func http2shouldSendReqContentLength(method string, contentLength int64) bool {
+ if contentLength > 0 {
+ return true
+ }
+ if contentLength < 0 {
+ return false
+ }
+ // For zero bodies, whether we send a content-length depends on the method.
+ // It also matters little for HTTP/2 either way, since END_STREAM marks the end of the body.
+ switch method {
+ case "POST", "PUT", "PATCH":
+ return true
+ default:
+ return false
+ }
+}
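+
+// Illustration with hypothetical calls (not in the original source):
+//
+//	http2shouldSendReqContentLength("POST", 0)  // true: explicit zero-length body
+//	http2shouldSendReqContentLength("GET", 0)   // false: empty GET omits content-length
+//	http2shouldSendReqContentLength("GET", 5)   // true: any positive length is sent
+//	http2shouldSendReqContentLength("POST", -1) // false: unknown length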
+
+// requires cc.wmu be held.
+func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) {
+ cc.hbuf.Reset()
+
+ hlSize := uint64(0)
+ for k, vv := range trailer {
+ for _, v := range vv {
+ hf := hpack.HeaderField{Name: k, Value: v}
+ hlSize += uint64(hf.Size())
+ }
+ }
+ if hlSize > cc.peerMaxHeaderListSize {
+ return nil, http2errRequestHeaderListSize
+ }
+
+ for k, vv := range trailer {
+ lowKey, ascii := http2lowerHeader(k)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ continue
+ }
+ // Transfer-Encoding, etc.. have already been filtered at the
+ // start of RoundTrip
+ for _, v := range vv {
+ cc.writeHeader(lowKey, v)
+ }
+ }
+ return cc.hbuf.Bytes(), nil
+}
+
+func (cc *http2ClientConn) writeHeader(name, value string) {
+ if http2VerboseLogs {
+ log.Printf("http2: Transport encoding header %q = %q", name, value)
+ }
+ cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
+}
+
+type http2resAndError struct {
+ _ http2incomparable
+ res *Response
+ err error
+}
+
+// requires cc.mu be held.
+func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) {
+ cs.flow.add(int32(cc.initialWindowSize))
+ cs.flow.setConnFlow(&cc.flow)
+ cs.inflow.init(http2transportDefaultStreamFlow)
+ cs.ID = cc.nextStreamID
+ cc.nextStreamID += 2
+ cc.streams[cs.ID] = cs
+ if cs.ID == 0 {
+ panic("assigned stream ID 0")
+ }
+}
+
+func (cc *http2ClientConn) forgetStreamID(id uint32) {
+ cc.mu.Lock()
+ slen := len(cc.streams)
+ delete(cc.streams, id)
+ if len(cc.streams) != slen-1 {
+ panic("forgetting unknown stream id")
+ }
+ cc.lastActive = time.Now()
+ if len(cc.streams) == 0 && cc.idleTimer != nil {
+ cc.idleTimer.Reset(cc.idleTimeout)
+ cc.lastIdle = time.Now()
+ }
+ // Wake up writeRequestBody via clientStream.awaitFlowControl and
+ // wake up RoundTrip if there is a pending request.
+ cc.cond.Broadcast()
+
+ closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
+ if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2)
+ }
+ cc.closed = true
+ defer cc.closeConn()
+ }
+
+ cc.mu.Unlock()
+}
+
+// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
+type http2clientConnReadLoop struct {
+ _ http2incomparable
+ cc *http2ClientConn
+}
+
+// readLoop runs in its own goroutine and reads and dispatches frames.
+func (cc *http2ClientConn) readLoop() {
+ rl := &http2clientConnReadLoop{cc: cc}
+ defer rl.cleanup()
+ cc.readerErr = rl.run()
+ if ce, ok := cc.readerErr.(http2ConnectionError); ok {
+ cc.wmu.Lock()
+ cc.fr.WriteGoAway(0, http2ErrCode(ce), nil)
+ cc.wmu.Unlock()
+ }
+}
+
+// GoAwayError is returned by the Transport when the server closes the
+// TCP connection after sending a GOAWAY frame.
+type http2GoAwayError struct {
+ LastStreamID uint32
+ ErrCode http2ErrCode
+ DebugData string
+}
+
+func (e http2GoAwayError) Error() string {
+ return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q",
+ e.LastStreamID, e.ErrCode, e.DebugData)
+}
+
+func http2isEOFOrNetReadError(err error) bool {
+ if err == io.EOF {
+ return true
+ }
+ ne, ok := err.(*net.OpError)
+ return ok && ne.Op == "read"
+}
+
+func (rl *http2clientConnReadLoop) cleanup() {
+ cc := rl.cc
+ cc.t.connPool().MarkDead(cc)
+ defer cc.closeConn()
+ defer close(cc.readerDone)
+
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+
+ // Close any response bodies if the server closes prematurely.
+ // TODO: also do this if we've written the headers but not
+ // gotten a response yet.
+ err := cc.readerErr
+ cc.mu.Lock()
+ if cc.goAway != nil && http2isEOFOrNetReadError(err) {
+ err = http2GoAwayError{
+ LastStreamID: cc.goAway.LastStreamID,
+ ErrCode: cc.goAway.ErrCode,
+ DebugData: cc.goAwayDebug,
+ }
+ } else if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ cc.closed = true
+
+ for _, cs := range cc.streams {
+ select {
+ case <-cs.peerClosed:
+ // The server closed the stream before closing the conn,
+ // so no need to interrupt it.
+ default:
+ cs.abortStreamLocked(err)
+ }
+ }
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+}
+
+// countReadFrameError calls Transport.CountError with a string
+// representing err.
+func (cc *http2ClientConn) countReadFrameError(err error) {
+ f := cc.t.CountError
+ if f == nil || err == nil {
+ return
+ }
+ if ce, ok := err.(http2ConnectionError); ok {
+ errCode := http2ErrCode(ce)
+ f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken()))
+ return
+ }
+ if errors.Is(err, io.EOF) {
+ f("read_frame_eof")
+ return
+ }
+ if errors.Is(err, io.ErrUnexpectedEOF) {
+ f("read_frame_unexpected_eof")
+ return
+ }
+ if errors.Is(err, http2ErrFrameTooLarge) {
+ f("read_frame_too_large")
+ return
+ }
+ f("read_frame_other")
+}
+
+func (rl *http2clientConnReadLoop) run() error {
+ cc := rl.cc
+ gotSettings := false
+ readIdleTimeout := cc.t.ReadIdleTimeout
+ var t *time.Timer
+ if readIdleTimeout != 0 {
+ t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
+ defer t.Stop()
+ }
+ for {
+ f, err := cc.fr.ReadFrame()
+ if t != nil {
+ t.Reset(readIdleTimeout)
+ }
+ if err != nil {
+ cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
+ }
+ if se, ok := err.(http2StreamError); ok {
+ if cs := rl.streamByID(se.StreamID); cs != nil {
+ if se.Cause == nil {
+ se.Cause = cc.fr.errDetail
+ }
+ rl.endStreamError(cs, se)
+ }
+ continue
+ } else if err != nil {
+ cc.countReadFrameError(err)
+ return err
+ }
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
+ }
+ if !gotSettings {
+ if _, ok := f.(*http2SettingsFrame); !ok {
+ cc.logf("protocol error: received %T before a SETTINGS frame", f)
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ gotSettings = true
+ }
+
+ switch f := f.(type) {
+ case *http2MetaHeadersFrame:
+ err = rl.processHeaders(f)
+ case *http2DataFrame:
+ err = rl.processData(f)
+ case *http2GoAwayFrame:
+ err = rl.processGoAway(f)
+ case *http2RSTStreamFrame:
+ err = rl.processResetStream(f)
+ case *http2SettingsFrame:
+ err = rl.processSettings(f)
+ case *http2PushPromiseFrame:
+ err = rl.processPushPromise(f)
+ case *http2WindowUpdateFrame:
+ err = rl.processWindowUpdate(f)
+ case *http2PingFrame:
+ err = rl.processPing(f)
+ default:
+ cc.logf("Transport: unhandled response frame type %T", f)
+ }
+ if err != nil {
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err)
+ }
+ return err
+ }
+ }
+}
+
+func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error {
+ cs := rl.streamByID(f.StreamID)
+ if cs == nil {
+ // We'd get here if we canceled a request while the
+ // server had its response still in flight. So if this
+ // was just something we canceled, ignore it.
+ return nil
+ }
+ if cs.readClosed {
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ Cause: errors.New("protocol error: headers after END_STREAM"),
+ })
+ return nil
+ }
+ if !cs.firstByte {
+ if cs.trace != nil {
+ // TODO(bradfitz): move first response byte earlier,
+ // when we first read the 9 byte header, not waiting
+ // until all the HEADERS+CONTINUATION frames have been
+ // merged. This works for now.
+ http2traceFirstResponseByte(cs.trace)
+ }
+ cs.firstByte = true
+ }
+ if !cs.pastHeaders {
+ cs.pastHeaders = true
+ } else {
+ return rl.processTrailers(cs, f)
+ }
+
+ res, err := rl.handleResponse(cs, f)
+ if err != nil {
+ if _, ok := err.(http2ConnectionError); ok {
+ return err
+ }
+ // Any other error type is a stream error.
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ Cause: err,
+ })
+ return nil // return nil from process* funcs to keep conn alive
+ }
+ if res == nil {
+ // (nil, nil) special case. See handleResponse docs.
+ return nil
+ }
+ cs.resTrailer = &res.Trailer
+ cs.res = res
+ close(cs.respHeaderRecv)
+ if f.StreamEnded() {
+ rl.endStream(cs)
+ }
+ return nil
+}
+
+// handleResponse may return a nil error or a ConnectionError. Any other error
+// value is treated as a StreamError of type ErrCodeProtocol; the returned
+// error in that case is the detail.
+//
+// As a special case, handleResponse may return (nil, nil) to skip the
+// frame (currently only used for 1xx responses).
+func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) {
+ if f.Truncated {
+ return nil, http2errResponseHeaderListSize
+ }
+
+ status := f.PseudoValue("status")
+ if status == "" {
+ return nil, errors.New("malformed response from server: missing status pseudo header")
+ }
+ statusCode, err := strconv.Atoi(status)
+ if err != nil {
+ return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header")
+ }
+
+ regularFields := f.RegularFields()
+ strs := make([]string, len(regularFields))
+ header := make(Header, len(regularFields))
+ res := &Response{
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ Header: header,
+ StatusCode: statusCode,
+ Status: status + " " + StatusText(statusCode),
+ }
+ for _, hf := range regularFields {
+ key := http2canonicalHeader(hf.Name)
+ if key == "Trailer" {
+ t := res.Trailer
+ if t == nil {
+ t = make(Header)
+ res.Trailer = t
+ }
+ http2foreachHeaderElement(hf.Value, func(v string) {
+ t[http2canonicalHeader(v)] = nil
+ })
+ } else {
+ vv := header[key]
+ if vv == nil && len(strs) > 0 {
+ // More than likely this will be a single-element key.
+ // Most headers aren't multi-valued.
+ // Set the capacity on strs[0] to 1, so any future append
+ // won't extend the slice into the other strings.
+ vv, strs = strs[:1:1], strs[1:]
+ vv[0] = hf.Value
+ header[key] = vv
+ } else {
+ header[key] = append(vv, hf.Value)
+ }
+ }
+ }
+
+ if statusCode >= 100 && statusCode <= 199 {
+ if f.StreamEnded() {
+ return nil, errors.New("1xx informational response with END_STREAM flag")
+ }
+ cs.num1xx++
+ const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
+ if cs.num1xx > max1xxResponses {
+ return nil, errors.New("http2: too many 1xx informational responses")
+ }
+ if fn := cs.get1xxTraceFunc(); fn != nil {
+ if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
+ return nil, err
+ }
+ }
+ if statusCode == 100 {
+ http2traceGot100Continue(cs.trace)
+ select {
+ case cs.on100 <- struct{}{}:
+ default:
+ }
+ }
+ cs.pastHeaders = false // do it all again
+ return nil, nil
+ }
+
+ res.ContentLength = -1
+ if clens := res.Header["Content-Length"]; len(clens) == 1 {
+ if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil {
+ res.ContentLength = int64(cl)
+ } else {
+ // TODO: care? unlike http/1, it won't mess up our framing, so it's
+ // more safe smuggling-wise to ignore.
+ }
+ } else if len(clens) > 1 {
+ // TODO: care? unlike http/1, it won't mess up our framing, so it's
+ // more safe smuggling-wise to ignore.
+ } else if f.StreamEnded() && !cs.isHead {
+ res.ContentLength = 0
+ }
+
+ if cs.isHead {
+ res.Body = http2noBody
+ return res, nil
+ }
+
+ if f.StreamEnded() {
+ if res.ContentLength > 0 {
+ res.Body = http2missingBody{}
+ } else {
+ res.Body = http2noBody
+ }
+ return res, nil
+ }
+
+ cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength})
+ cs.bytesRemain = res.ContentLength
+ res.Body = http2transportResponseBody{cs}
+
+ if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") {
+ res.Header.Del("Content-Encoding")
+ res.Header.Del("Content-Length")
+ res.ContentLength = -1
+ res.Body = &http2gzipReader{body: res.Body}
+ res.Uncompressed = true
+ }
+ return res, nil
+}
+
+func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error {
+ if cs.pastTrailers {
+ // Too many HEADERS frames for this stream.
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ cs.pastTrailers = true
+ if !f.StreamEnded() {
+ // We expect any HEADERS frame carrying trailers to
+ // also have END_STREAM set.
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if len(f.PseudoFields()) > 0 {
+ // No pseudo header fields are defined for trailers.
+ // TODO: ConnectionError might be overly harsh? Check.
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ trailer := make(Header)
+ for _, hf := range f.RegularFields() {
+ key := http2canonicalHeader(hf.Name)
+ trailer[key] = append(trailer[key], hf.Value)
+ }
+ cs.trailer = trailer
+
+ rl.endStream(cs)
+ return nil
+}
+
+// transportResponseBody is the concrete type of Transport.RoundTrip's
+// Response.Body. It is an io.ReadCloser.
+type http2transportResponseBody struct {
+ cs *http2clientStream
+}
+
+func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
+ cs := b.cs
+ cc := cs.cc
+
+ if cs.readErr != nil {
+ return 0, cs.readErr
+ }
+ n, err = b.cs.bufPipe.Read(p)
+ if cs.bytesRemain != -1 {
+ if int64(n) > cs.bytesRemain {
+ n = int(cs.bytesRemain)
+ if err == nil {
+ err = errors.New("net/http: server replied with more than declared Content-Length; truncated")
+ cs.abortStream(err)
+ }
+ cs.readErr = err
+ return int(cs.bytesRemain), err
+ }
+ cs.bytesRemain -= int64(n)
+ if err == io.EOF && cs.bytesRemain > 0 {
+ err = io.ErrUnexpectedEOF
+ cs.readErr = err
+ return n, err
+ }
+ }
+ if n == 0 {
+ // No flow control tokens to send back.
+ return
+ }
+
+ cc.mu.Lock()
+ connAdd := cc.inflow.add(n)
+ var streamAdd int32
+ if err == nil { // No need to refresh if the stream is over or failed.
+ streamAdd = cs.inflow.add(n)
+ }
+ cc.mu.Unlock()
+
+ if connAdd != 0 || streamAdd != 0 {
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if connAdd != 0 {
+ cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd))
+ }
+ if streamAdd != 0 {
+ cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd))
+ }
+ cc.bw.Flush()
+ }
+ return
+}
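+
+// Rough illustration (not in the original source): a successful Read of n
+// bytes above returns the consumed flow-control tokens to the peer, which on
+// the wire looks approximately like:
+//
+//	WINDOW_UPDATE stream=0     increment=n // connection-level window
+//	WINDOW_UPDATE stream=cs.ID increment=n // stream-level window (skipped once the stream is done)
+//
+// The exact increments depend on the inflow accounting, which may batch small
+// refunds before emitting a WINDOW_UPDATE.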
+
+var http2errClosedResponseBody = errors.New("http2: response body closed")
+
+func (b http2transportResponseBody) Close() error {
+ cs := b.cs
+ cc := cs.cc
+
+ cs.bufPipe.BreakWithError(http2errClosedResponseBody)
+ cs.abortStream(http2errClosedResponseBody)
+
+ unread := cs.bufPipe.Len()
+ if unread > 0 {
+ cc.mu.Lock()
+ // Return connection-level flow control.
+ connAdd := cc.inflow.add(unread)
+ cc.mu.Unlock()
+
+ // TODO(dneil): Acquiring this mutex can block indefinitely.
+ // Move flow control return to a goroutine?
+ cc.wmu.Lock()
+ // Return connection-level flow control.
+ if connAdd > 0 {
+ cc.fr.WriteWindowUpdate(0, uint32(connAdd))
+ }
+ cc.bw.Flush()
+ cc.wmu.Unlock()
+ }
+
+ select {
+ case <-cs.donec:
+ case <-cs.ctx.Done():
+ // See golang/go#49366: The net/http package can cancel the
+ // request context after the response body is fully read.
+ // Don't treat this as an error.
+ return nil
+ case <-cs.reqCancel:
+ return http2errRequestCanceled
+ }
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
+ cc := rl.cc
+ cs := rl.streamByID(f.StreamID)
+ data := f.Data()
+ if cs == nil {
+ cc.mu.Lock()
+ neverSent := cc.nextStreamID
+ cc.mu.Unlock()
+ if f.StreamID >= neverSent {
+ // We never asked for this.
+ cc.logf("http2: Transport received unsolicited DATA frame; closing connection")
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ // We probably did ask for this, but canceled. Just ignore it.
+ // TODO: be stricter here? only silently ignore things which
+ // we canceled, but not things which were closed normally
+ // by the peer? Tough without accumulating too much state.
+
+ // But at least return their flow control:
+ if f.Length > 0 {
+ cc.mu.Lock()
+ ok := cc.inflow.take(f.Length)
+ connAdd := cc.inflow.add(int(f.Length))
+ cc.mu.Unlock()
+ if !ok {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ if connAdd > 0 {
+ cc.wmu.Lock()
+ cc.fr.WriteWindowUpdate(0, uint32(connAdd))
+ cc.bw.Flush()
+ cc.wmu.Unlock()
+ }
+ }
+ return nil
+ }
+ if cs.readClosed {
+ cc.logf("protocol error: received DATA after END_STREAM")
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ })
+ return nil
+ }
+ if !cs.firstByte {
+ cc.logf("protocol error: received DATA before a HEADERS frame")
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ })
+ return nil
+ }
+ if f.Length > 0 {
+ if cs.isHead && len(data) > 0 {
+ cc.logf("protocol error: received DATA on a HEAD request")
+ rl.endStreamError(cs, http2StreamError{
+ StreamID: f.StreamID,
+ Code: http2ErrCodeProtocol,
+ })
+ return nil
+ }
+ // Check connection-level flow control.
+ cc.mu.Lock()
+ if !http2takeInflows(&cc.inflow, &cs.inflow, f.Length) {
+ cc.mu.Unlock()
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ // Return any padded flow control now, since we won't
+ // refund it later on body reads.
+ var refund int
+ if pad := int(f.Length) - len(data); pad > 0 {
+ refund += pad
+ }
+
+ didReset := false
+ var err error
+ if len(data) > 0 {
+ if _, err = cs.bufPipe.Write(data); err != nil {
+ // Return len(data) now if the stream is already closed,
+ // since data will never be read.
+ didReset = true
+ refund += len(data)
+ }
+ }
+
+ sendConn := cc.inflow.add(refund)
+ var sendStream int32
+ if !didReset {
+ sendStream = cs.inflow.add(refund)
+ }
+ cc.mu.Unlock()
+
+ if sendConn > 0 || sendStream > 0 {
+ cc.wmu.Lock()
+ if sendConn > 0 {
+ cc.fr.WriteWindowUpdate(0, uint32(sendConn))
+ }
+ if sendStream > 0 {
+ cc.fr.WriteWindowUpdate(cs.ID, uint32(sendStream))
+ }
+ cc.bw.Flush()
+ cc.wmu.Unlock()
+ }
+
+ if err != nil {
+ rl.endStreamError(cs, err)
+ return nil
+ }
+ }
+
+ if f.StreamEnded() {
+ rl.endStream(cs)
+ }
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) {
+ // TODO: check that any declared content-length matches, like
+ // server.go's (*stream).endStream method.
+ if !cs.readClosed {
+ cs.readClosed = true
+ // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a
+ // race condition: The caller can read io.EOF from Response.Body
+ // and close the body before we close cs.peerClosed, causing
+ // cleanupWriteRequest to send a RST_STREAM.
+ rl.cc.mu.Lock()
+ defer rl.cc.mu.Unlock()
+ cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers)
+ close(cs.peerClosed)
+ }
+}
+
+func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) {
+ cs.readAborted = true
+ cs.abortStream(err)
+}
+
+func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream {
+ rl.cc.mu.Lock()
+ defer rl.cc.mu.Unlock()
+ cs := rl.cc.streams[id]
+ if cs != nil && !cs.readAborted {
+ return cs
+ }
+ return nil
+}
+
+func (cs *http2clientStream) copyTrailers() {
+ for k, vv := range cs.trailer {
+ t := cs.resTrailer
+ if *t == nil {
+ *t = make(Header)
+ }
+ (*t)[k] = vv
+ }
+}
+
+func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
+ cc := rl.cc
+ cc.t.connPool().MarkDead(cc)
+ if f.ErrCode != 0 {
+ // TODO: deal with GOAWAY more, particularly the error code.
+ cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
+ if fn := cc.t.CountError; fn != nil {
+ fn("recv_goaway_" + f.ErrCode.stringToken())
+ }
+ }
+ cc.setGoAway(f)
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error {
+ cc := rl.cc
+ // Locking both mu and wmu here allows frame encoding to read settings with only wmu held.
+ // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless.
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+
+ if err := rl.processSettingsNoWrite(f); err != nil {
+ return err
+ }
+ if !f.IsAck() {
+ cc.fr.WriteSettingsAck()
+ cc.bw.Flush()
+ }
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error {
+ cc := rl.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+
+ if f.IsAck() {
+ if cc.wantSettingsAck {
+ cc.wantSettingsAck = false
+ return nil
+ }
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ var seenMaxConcurrentStreams bool
+ err := f.ForeachSetting(func(s http2Setting) error {
+ switch s.ID {
+ case http2SettingMaxFrameSize:
+ cc.maxFrameSize = s.Val
+ case http2SettingMaxConcurrentStreams:
+ cc.maxConcurrentStreams = s.Val
+ seenMaxConcurrentStreams = true
+ case http2SettingMaxHeaderListSize:
+ cc.peerMaxHeaderListSize = uint64(s.Val)
+ case http2SettingInitialWindowSize:
+ // Values above the maximum flow-control
+ // window size of 2^31-1 MUST be treated as a
+ // connection error (Section 5.4.1) of type
+ // FLOW_CONTROL_ERROR.
+ if s.Val > math.MaxInt32 {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+
+ // Adjust flow control of currently-open
+ // frames by the difference of the old initial
+ // window size and this one.
+ delta := int32(s.Val) - int32(cc.initialWindowSize)
+ for _, cs := range cc.streams {
+ cs.flow.add(delta)
+ }
+ cc.cond.Broadcast()
+
+ cc.initialWindowSize = s.Val
+ case http2SettingHeaderTableSize:
+ cc.henc.SetMaxDynamicTableSize(s.Val)
+ cc.peerMaxHeaderTableSize = s.Val
+ default:
+ cc.vlogf("Unhandled Setting: %v", s)
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ if !cc.seenSettings {
+ if !seenMaxConcurrentStreams {
+ // This was the server's initial SETTINGS frame and it
+ // didn't contain a MAX_CONCURRENT_STREAMS field so
+ // increase the number of concurrent streams this
+ // connection can establish to our default.
+ cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams
+ }
+ cc.seenSettings = true
+ }
+
+ return nil
+}
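+
+// Illustration with hypothetical values (not in the original source): the
+// SETTINGS_INITIAL_WINDOW_SIZE case above applies the delta to every open
+// stream, per RFC 7540, Section 6.9.2. With an old initialWindowSize of 65535
+// and a new value of 131070, delta is +65535 and each open stream's send
+// window grows by that amount; a smaller new value shrinks the windows by the
+// negative delta, possibly below zero until WINDOW_UPDATE frames arrive.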
+
+func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error {
+ cc := rl.cc
+ cs := rl.streamByID(f.StreamID)
+ if f.StreamID != 0 && cs == nil {
+ return nil
+ }
+
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+
+ fl := &cc.flow
+ if cs != nil {
+ fl = &cs.flow
+ }
+ if !fl.add(int32(f.Increment)) {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ cc.cond.Broadcast()
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error {
+ cs := rl.streamByID(f.StreamID)
+ if cs == nil {
+ // TODO: return error if server tries to RST_STREAM an idle stream
+ return nil
+ }
+ serr := http2streamError(cs.ID, f.ErrCode)
+ serr.Cause = http2errFromPeer
+ if f.ErrCode == http2ErrCodeProtocol {
+ rl.cc.SetDoNotReuse()
+ }
+ if fn := cs.cc.t.CountError; fn != nil {
+ fn("recv_rststream_" + f.ErrCode.stringToken())
+ }
+ cs.abortStream(serr)
+
+ cs.bufPipe.CloseWithError(serr)
+ return nil
+}
+
+// Ping sends a PING frame to the server and waits for the ack.
+func (cc *http2ClientConn) Ping(ctx context.Context) error {
+ c := make(chan struct{})
+ // Generate a random payload
+ var p [8]byte
+ for {
+ if _, err := rand.Read(p[:]); err != nil {
+ return err
+ }
+ cc.mu.Lock()
+ // check for dup before insert
+ if _, found := cc.pings[p]; !found {
+ cc.pings[p] = c
+ cc.mu.Unlock()
+ break
+ }
+ cc.mu.Unlock()
+ }
+ errc := make(chan error, 1)
+ go func() {
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if err := cc.fr.WritePing(false, p); err != nil {
+ errc <- err
+ return
+ }
+ if err := cc.bw.Flush(); err != nil {
+ errc <- err
+ return
+ }
+ }()
+ select {
+ case <-c:
+ return nil
+ case err := <-errc:
+ return err
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cc.readerDone:
+ // connection closed
+ return cc.readerErr
+ }
+}
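+
+// Illustration of a hypothetical caller (not in the original source): a
+// bounded liveness probe using Ping might look like:
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	if err := cc.Ping(ctx); err != nil {
+//		// No PING ack within 5s (or the connection is dead);
+//		// callers typically stop reusing cc at this point.
+//	}
+//
+// The health check scheduled from the read loop when Transport.ReadIdleTimeout
+// is set relies on the same call.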
+
+func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error {
+ if f.IsAck() {
+ cc := rl.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ // If ack, notify listener if any
+ if c, ok := cc.pings[f.Data]; ok {
+ close(c)
+ delete(cc.pings, f.Data)
+ }
+ return nil
+ }
+ cc := rl.cc
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if err := cc.fr.WritePing(true, f.Data); err != nil {
+ return err
+ }
+ return cc.bw.Flush()
+}
+
+func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error {
+ // We told the peer we don't want them.
+ // Spec says:
+ // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH
+ // setting of the peer endpoint is set to 0. An endpoint that
+ // has set this setting and has received acknowledgement MUST
+ // treat the receipt of a PUSH_PROMISE frame as a connection
+ // error (Section 5.4.1) of type PROTOCOL_ERROR."
+ return http2ConnectionError(http2ErrCodeProtocol)
+}
+
+func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
+ // TODO: map err to more interesting error codes, once the
+ // HTTP community comes up with some. But currently for
+ // RST_STREAM there's no equivalent to GOAWAY frame's debug
+ // data, and the error codes are all pretty vague ("cancel").
+ cc.wmu.Lock()
+ cc.fr.WriteRSTStream(streamID, code)
+ cc.bw.Flush()
+ cc.wmu.Unlock()
+}
+
+var (
+ http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
+ http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit")
+)
+
+func (cc *http2ClientConn) logf(format string, args ...interface{}) {
+ cc.t.logf(format, args...)
+}
+
+func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
+ cc.t.vlogf(format, args...)
+}
+
+func (t *http2Transport) vlogf(format string, args ...interface{}) {
+ if http2VerboseLogs {
+ t.logf(format, args...)
+ }
+}
+
+func (t *http2Transport) logf(format string, args ...interface{}) {
+ log.Printf(format, args...)
+}
+
+var http2noBody io.ReadCloser = http2noBodyReader{}
+
+type http2noBodyReader struct{}
+
+func (http2noBodyReader) Close() error { return nil }
+
+func (http2noBodyReader) Read([]byte) (int, error) { return 0, io.EOF }
+
+type http2missingBody struct{}
+
+func (http2missingBody) Close() error { return nil }
+
+func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF }
+
+func http2strSliceContains(ss []string, s string) bool {
+ for _, v := range ss {
+ if v == s {
+ return true
+ }
+ }
+ return false
+}
+
+type http2erringRoundTripper struct{ err error }
+
+func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err }
+
+func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
+
+// gzipReader wraps a response body so it can lazily
+// call gzip.NewReader on the first call to Read
+type http2gzipReader struct {
+ _ http2incomparable
+ body io.ReadCloser // underlying Response.Body
+ zr *gzip.Reader // lazily-initialized gzip reader
+ zerr error // sticky error
+}
+
+func (gz *http2gzipReader) Read(p []byte) (n int, err error) {
+ if gz.zerr != nil {
+ return 0, gz.zerr
+ }
+ if gz.zr == nil {
+ gz.zr, err = gzip.NewReader(gz.body)
+ if err != nil {
+ gz.zerr = err
+ return 0, err
+ }
+ }
+ return gz.zr.Read(p)
+}
+
+func (gz *http2gzipReader) Close() error {
+ if err := gz.body.Close(); err != nil {
+ return err
+ }
+ gz.zerr = fs.ErrClosed
+ return nil
+}
+
+type http2errorReader struct{ err error }
+
+func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
+
+// isConnectionCloseRequest reports whether req should use its own
+// connection for a single request and then close the connection.
+func http2isConnectionCloseRequest(req *Request) bool {
+ return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close")
+}
+
+// registerHTTPSProtocol calls Transport.RegisterProtocol but
+// converts panics into errors.
+func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) {
+ defer func() {
+ if e := recover(); e != nil {
+ err = fmt.Errorf("%v", e)
+ }
+ }()
+ t.RegisterProtocol("https", rt)
+ return nil
+}
+
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
+// if there's already a cached connection to the host.
+// (The field is exported so it can be accessed via reflect from net/http; tested
+// by TestNoDialH2RoundTripperType)
+type http2noDialH2RoundTripper struct{ *http2Transport }
+
+func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
+ res, err := rt.http2Transport.RoundTrip(req)
+ if http2isNoCachedConnError(err) {
+ return nil, ErrSkipAltProtocol
+ }
+ return res, err
+}
+
+func (t *http2Transport) idleConnTimeout() time.Duration {
+ if t.t1 != nil {
+ return t.t1.IdleConnTimeout
+ }
+ return 0
+}
+
+func http2traceGetConn(req *Request, hostPort string) {
+ trace := httptrace.ContextClientTrace(req.Context())
+ if trace == nil || trace.GetConn == nil {
+ return
+ }
+ trace.GetConn(hostPort)
+}
+
+func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) {
+ trace := httptrace.ContextClientTrace(req.Context())
+ if trace == nil || trace.GotConn == nil {
+ return
+ }
+ ci := httptrace.GotConnInfo{Conn: cc.tconn}
+ ci.Reused = reused
+ cc.mu.Lock()
+ ci.WasIdle = len(cc.streams) == 0 && reused
+ if ci.WasIdle && !cc.lastActive.IsZero() {
+ ci.IdleTime = time.Since(cc.lastActive)
+ }
+ cc.mu.Unlock()
+
+ trace.GotConn(ci)
+}
+
+func http2traceWroteHeaders(trace *httptrace.ClientTrace) {
+ if trace != nil && trace.WroteHeaders != nil {
+ trace.WroteHeaders()
+ }
+}
+
+func http2traceGot100Continue(trace *httptrace.ClientTrace) {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
+ }
+}
+
+func http2traceWait100Continue(trace *httptrace.ClientTrace) {
+ if trace != nil && trace.Wait100Continue != nil {
+ trace.Wait100Continue()
+ }
+}
+
+func http2traceWroteRequest(trace *httptrace.ClientTrace, err error) {
+ if trace != nil && trace.WroteRequest != nil {
+ trace.WroteRequest(httptrace.WroteRequestInfo{Err: err})
+ }
+}
+
+func http2traceFirstResponseByte(trace *httptrace.ClientTrace) {
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ trace.GotFirstResponseByte()
+ }
+}
+
+// writeFramer is implemented by any type that is used to write frames.
+type http2writeFramer interface {
+ writeFrame(http2writeContext) error
+
+ // staysWithinBuffer reports whether this writer promises that
+ // it will only write less than or equal to size bytes, and it
+ // won't Flush the write context.
+ staysWithinBuffer(size int) bool
+}
+
+// writeContext is the interface needed by the various frame writer
+// types below. All the writeFrame methods below are scheduled via the
+// frame writing scheduler (see writeScheduler in writesched.go).
+//
+// This interface is implemented by *serverConn.
+//
+// TODO: decide whether to a) use this in the client code (which didn't
+// end up using this yet, because it has a simpler design, not
+// currently implementing priorities), or b) delete this and
+// make the server code a bit more concrete.
+type http2writeContext interface {
+ Framer() *http2Framer
+ Flush() error
+ CloseConn() error
+ // HeaderEncoder returns an HPACK encoder that writes to the
+ // returned buffer.
+ HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
+}
+
+// writeEndsStream reports whether w writes a frame that will transition
+// the stream to a half-closed local state. This returns false for RST_STREAM,
+// which closes the entire stream (not just the local half).
+func http2writeEndsStream(w http2writeFramer) bool {
+ switch v := w.(type) {
+ case *http2writeData:
+ return v.endStream
+ case *http2writeResHeaders:
+ return v.endStream
+ case nil:
+ // This can only happen if the caller reuses w after it's
+ // been intentionally nil'ed out to prevent use. Keep this
+ // here to catch future refactoring breaking it.
+ panic("writeEndsStream called on nil writeFramer")
+ }
+ return false
+}
+
+type http2flushFrameWriter struct{}
+
+func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error {
+ return ctx.Flush()
+}
+
+func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false }
+
+type http2writeSettings []http2Setting
+
+func (s http2writeSettings) staysWithinBuffer(max int) bool {
+ const settingSize = 6 // uint16 + uint32
+ return http2frameHeaderLen+settingSize*len(s) <= max
+}
+
+func (s http2writeSettings) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteSettings([]http2Setting(s)...)
+}
+
+type http2writeGoAway struct {
+ maxStreamID uint32
+ code http2ErrCode
+}
+
+func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error {
+ err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
+ ctx.Flush() // ignore error: we're hanging up on them anyway
+ return err
+}
+
+func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes
+
+type http2writeData struct {
+ streamID uint32
+ p []byte
+ endStream bool
+}
+
+func (w *http2writeData) String() string {
+ return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
+}
+
+func (w *http2writeData) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
+}
+
+func (w *http2writeData) staysWithinBuffer(max int) bool {
+ return http2frameHeaderLen+len(w.p) <= max
+}
+
+// handlerPanicRST is the message sent from handler goroutines when
+// the handler panics.
+type http2handlerPanicRST struct {
+ StreamID uint32
+}
+
+func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal)
+}
+
+func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
+func (se http2StreamError) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
+}
+
+func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
+type http2writePingAck struct{ pf *http2PingFrame }
+
+func (w http2writePingAck) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WritePing(true, w.pf.Data)
+}
+
+func (w http2writePingAck) staysWithinBuffer(max int) bool {
+ return http2frameHeaderLen+len(w.pf.Data) <= max
+}
+
+type http2writeSettingsAck struct{}
+
+func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteSettingsAck()
+}
+
+func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max }
+
+// splitHeaderBlock splits headerBlock into fragments so that each fragment fits
+// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true
+// for the first/last fragment, respectively.
+func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error {
+ // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
+ // that all peers must support (16KB). Later we could care
+ // more and send larger frames if the peer advertised it, but
+ // there's little point. Most headers are small anyway (so we
+ // generally won't have CONTINUATION frames), and extra frames
+ // only waste 9 bytes anyway.
+ const maxFrameSize = 16384
+
+ first := true
+ for len(headerBlock) > 0 {
+ frag := headerBlock
+ if len(frag) > maxFrameSize {
+ frag = frag[:maxFrameSize]
+ }
+ headerBlock = headerBlock[len(frag):]
+ if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil {
+ return err
+ }
+ first = false
+ }
+ return nil
+}
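+
+// Illustration with hypothetical sizes (not in the original source): a
+// 40000-byte header block is delivered to fn in three fragments:
+//
+//	fn(ctx, block[0:16384],     firstFrag=true,  lastFrag=false) // HEADERS
+//	fn(ctx, block[16384:32768], firstFrag=false, lastFrag=false) // CONTINUATION
+//	fn(ctx, block[32768:40000], firstFrag=false, lastFrag=true)  // CONTINUATION, END_HEADERS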
+
+// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
+// for HTTP response headers or trailers from a server handler.
+type http2writeResHeaders struct {
+ streamID uint32
+ httpResCode int // 0 means no ":status" line
+ h Header // may be nil
+ trailers []string // if non-nil, which keys of h to write. nil means all.
+ endStream bool
+
+ date string
+ contentType string
+ contentLength string
+}
+
+func http2encKV(enc *hpack.Encoder, k, v string) {
+ if http2VerboseLogs {
+ log.Printf("http2: server encoding header %q = %q", k, v)
+ }
+ enc.WriteField(hpack.HeaderField{Name: k, Value: v})
+}
+
+func (w *http2writeResHeaders) staysWithinBuffer(max int) bool {
+ // TODO: this is a common one. It'd be nice to return true
+ // here and get into the fast path if we could be clever and
+ // calculate the size fast enough, or at least a conservative
+ // upper bound that usually fires. (Maybe if w.h and
+ // w.trailers are nil, so we don't need to enumerate it.)
+ // Otherwise I'm afraid that just calculating the length to
+ // answer this question would be slower than the ~2µs benefit.
+ return false
+}
+
+func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+
+ if w.httpResCode != 0 {
+ http2encKV(enc, ":status", http2httpCodeString(w.httpResCode))
+ }
+
+ http2encodeHeaders(enc, w.h, w.trailers)
+
+ if w.contentType != "" {
+ http2encKV(enc, "content-type", w.contentType)
+ }
+ if w.contentLength != "" {
+ http2encKV(enc, "content-length", w.contentLength)
+ }
+ if w.date != "" {
+ http2encKV(enc, "date", w.date)
+ }
+
+ headerBlock := buf.Bytes()
+ if len(headerBlock) == 0 && w.trailers == nil {
+ panic("unexpected empty hpack")
+ }
+
+ return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
+
+func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: frag,
+ EndStream: w.endStream,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
+ }
+}
+
+// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames.
+type http2writePushPromise struct {
+ streamID uint32 // pusher stream
+ method string // for :method
+ url *url.URL // for :scheme, :authority, :path
+ h Header
+
+ // Creates an ID for a pushed stream. This runs on serveG just before
+ // the frame is written. The returned ID is copied to promisedID.
+ allocatePromisedID func() (uint32, error)
+ promisedID uint32
+}
+
+func (w *http2writePushPromise) staysWithinBuffer(max int) bool {
+ // TODO: see writeResHeaders.staysWithinBuffer
+ return false
+}
+
+func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+
+ http2encKV(enc, ":method", w.method)
+ http2encKV(enc, ":scheme", w.url.Scheme)
+ http2encKV(enc, ":authority", w.url.Host)
+ http2encKV(enc, ":path", w.url.RequestURI())
+ http2encodeHeaders(enc, w.h, nil)
+
+ headerBlock := buf.Bytes()
+ if len(headerBlock) == 0 {
+ panic("unexpected empty hpack")
+ }
+
+ return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
+
+func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WritePushPromise(http2PushPromiseParam{
+ StreamID: w.streamID,
+ PromiseID: w.promisedID,
+ BlockFragment: frag,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
+ }
+}
+
+type http2write100ContinueHeadersFrame struct {
+ streamID uint32
+}
+
+func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+ http2encKV(enc, ":status", "100")
+ return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: buf.Bytes(),
+ EndStream: false,
+ EndHeaders: true,
+ })
+}
+
+func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool {
+ // Sloppy but conservative:
+ return 9+2*(len(":status")+len("100")) <= max
+}
+
+type http2writeWindowUpdate struct {
+ streamID uint32 // or 0 for conn-level
+ n uint32
+}
+
+func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
+func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
+}
+
+// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k])
+// is encoded only if k is in keys.
+func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
+ if keys == nil {
+ sorter := http2sorterPool.Get().(*http2sorter)
+ // Using defer here, since the returned keys from the
+ // sorter.Keys method is only valid until the sorter
+ // is returned:
+ defer http2sorterPool.Put(sorter)
+ keys = sorter.Keys(h)
+ }
+ for _, k := range keys {
+ vv := h[k]
+ k, ascii := http2lowerHeader(k)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ continue
+ }
+ if !http2validWireHeaderFieldName(k) {
+ // Skip it as backup paranoia. Per
+ // golang.org/issue/14048, these should
+ // already be rejected at a higher level.
+ continue
+ }
+ isTE := k == "transfer-encoding"
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // TODO: return an error? golang.org/issue/14048
+ // For now just omit it.
+ continue
+ }
+ // TODO: more of "8.1.2.2 Connection-Specific Header Fields"
+ if isTE && v != "trailers" {
+ continue
+ }
+ http2encKV(enc, k, v)
+ }
+ }
+}
+
+// WriteScheduler is the interface implemented by HTTP/2 write schedulers.
+// Methods are never called concurrently.
+type http2WriteScheduler interface {
+ // OpenStream opens a new stream in the write scheduler.
+ // It is illegal to call this with streamID=0 or with a streamID that is
+ // already open -- the call may panic.
+ OpenStream(streamID uint32, options http2OpenStreamOptions)
+
+ // CloseStream closes a stream in the write scheduler. Any frames queued on
+ // this stream should be discarded. It is illegal to call this on a stream
+ // that is not open -- the call may panic.
+ CloseStream(streamID uint32)
+
+ // AdjustStream adjusts the priority of the given stream. This may be called
+ // on a stream that has not yet been opened or has been closed. Note that
+ // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See:
+ // https://tools.ietf.org/html/rfc7540#section-5.1
+ AdjustStream(streamID uint32, priority http2PriorityParam)
+
+ // Push queues a frame in the scheduler. In most cases, this will not be
+ // called with wr.StreamID()!=0 unless that stream is currently open. The one
+ // exception is RST_STREAM frames, which may be sent on idle or closed streams.
+ Push(wr http2FrameWriteRequest)
+
+ // Pop dequeues the next frame to write. Returns false if no frames can
+ // be written. Frames with a given wr.StreamID() are Pop'd in the same
+ // order they are Push'd, except RST_STREAM frames. No frames should be
+ // discarded except by CloseStream.
+ Pop() (wr http2FrameWriteRequest, ok bool)
+}
+
+// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream.
+type http2OpenStreamOptions struct {
+ // PusherID is zero if the stream was initiated by the client. Otherwise,
+ // PusherID names the stream that pushed the newly opened stream.
+ PusherID uint32
+}
+
+// FrameWriteRequest is a request to write a frame.
+type http2FrameWriteRequest struct {
+ // write is the interface value that does the writing, once the
+ // WriteScheduler has selected this frame to write. The write
+ // functions are all defined in write.go.
+ write http2writeFramer
+
+ // stream is the stream on which this frame will be written.
+ // nil for non-stream frames like PING and SETTINGS.
+ // nil for RST_STREAM streams, which use the StreamError.StreamID field instead.
+ stream *http2stream
+
+ // done, if non-nil, must be a buffered channel with space for
+ // 1 message and is sent the return value from write (or an
+ // earlier error) when the frame has been written.
+ done chan error
+}
+
+// StreamID returns the id of the stream this frame will be written to.
+// 0 is used for non-stream frames such as PING and SETTINGS.
+func (wr http2FrameWriteRequest) StreamID() uint32 {
+ if wr.stream == nil {
+ if se, ok := wr.write.(http2StreamError); ok {
+ // (*serverConn).resetStream doesn't set
+ // stream because it doesn't necessarily have
+ // one. So special case this type of write
+ // message.
+ return se.StreamID
+ }
+ return 0
+ }
+ return wr.stream.id
+}
+
+// isControl reports whether wr is a control frame for MaxQueuedControlFrames
+// purposes. That includes non-stream frames and RST_STREAM frames.
+func (wr http2FrameWriteRequest) isControl() bool {
+ return wr.stream == nil
+}
+
+// DataSize returns the number of flow control bytes that must be consumed
+// to write this entire frame. This is 0 for non-DATA frames.
+func (wr http2FrameWriteRequest) DataSize() int {
+ if wd, ok := wr.write.(*http2writeData); ok {
+ return len(wd.p)
+ }
+ return 0
+}
+
+// Consume consumes min(n, available) bytes from this frame, where available
+// is the number of flow control bytes available on the stream. Consume returns
+// 0, 1, or 2 frames, where the integer return value gives the number of frames
+// returned.
+//
+// If flow control prevents consuming any bytes, this returns (_, _, 0). If
+// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this
+// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and
+// 'rest' contains the remaining bytes. The consumed bytes are deducted from the
+// underlying stream's flow control budget.
+func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) {
+ var empty http2FrameWriteRequest
+
+ // Non-DATA frames are always consumed whole.
+ wd, ok := wr.write.(*http2writeData)
+ if !ok || len(wd.p) == 0 {
+ return wr, empty, 1
+ }
+
+ // Might need to split after applying limits.
+ allowed := wr.stream.flow.available()
+ if n < allowed {
+ allowed = n
+ }
+ if wr.stream.sc.maxFrameSize < allowed {
+ allowed = wr.stream.sc.maxFrameSize
+ }
+ if allowed <= 0 {
+ return empty, empty, 0
+ }
+ if len(wd.p) > int(allowed) {
+ wr.stream.flow.take(allowed)
+ consumed := http2FrameWriteRequest{
+ stream: wr.stream,
+ write: &http2writeData{
+ streamID: wd.streamID,
+ p: wd.p[:allowed],
+ // Even if the original had endStream set, there
+ // are bytes remaining because len(wd.p) > allowed,
+ // so we know endStream is false.
+ endStream: false,
+ },
+ // Our caller is blocking on the final DATA frame, not
+ // this intermediate frame, so no need to wait.
+ done: nil,
+ }
+ rest := http2FrameWriteRequest{
+ stream: wr.stream,
+ write: &http2writeData{
+ streamID: wd.streamID,
+ p: wd.p[allowed:],
+ endStream: wd.endStream,
+ },
+ done: wr.done,
+ }
+ return consumed, rest, 2
+ }
+
+ // The frame is consumed whole.
+ // NB: This cast cannot overflow because allowed is <= math.MaxInt32.
+ wr.stream.flow.take(int32(len(wd.p)))
+ return wr, empty, 1
+}
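
The three outcomes documented above (0, 1, or 2 frames) come from splitting the DATA payload at the flow-control allowance. A standalone sketch of just that split, using a hypothetical consume helper with no flow-control bookkeeping:

package main

import "fmt"

// consume mirrors the shape of Consume's result: nothing when no bytes
// are allowed, the whole payload when it fits, or a consumed/rest pair.
func consume(p []byte, allowed int) (consumed, rest []byte, n int) {
	switch {
	case allowed <= 0:
		return nil, nil, 0
	case len(p) <= allowed:
		return p, nil, 1
	default:
		return p[:allowed], p[allowed:], 2
	}
}

func main() {
	c, r, n := consume([]byte("hello world"), 5)
	fmt.Printf("%q %q %d\n", c, r, n) // "hello" " world" 2
}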
+
+// String is for debugging only.
+func (wr http2FrameWriteRequest) String() string {
+ var des string
+ if s, ok := wr.write.(fmt.Stringer); ok {
+ des = s.String()
+ } else {
+ des = fmt.Sprintf("%T", wr.write)
+ }
+ return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des)
+}
+
+// replyToWriter sends err to wr.done and panics if the send must block.
+// This does nothing if wr.done is nil.
+func (wr *http2FrameWriteRequest) replyToWriter(err error) {
+ if wr.done == nil {
+ return
+ }
+ select {
+ case wr.done <- err:
+ default:
+ panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write))
+ }
+ wr.write = nil // prevent use (assume it's tainted after wr.done send)
+}
+
+// writeQueue is used by implementations of WriteScheduler.
+type http2writeQueue struct {
+ s []http2FrameWriteRequest
+ prev, next *http2writeQueue
+}
+
+func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
+
+func (q *http2writeQueue) push(wr http2FrameWriteRequest) {
+ q.s = append(q.s, wr)
+}
+
+func (q *http2writeQueue) shift() http2FrameWriteRequest {
+ if len(q.s) == 0 {
+ panic("invalid use of queue")
+ }
+ wr := q.s[0]
+ // TODO: less copy-happy queue.
+ copy(q.s, q.s[1:])
+ q.s[len(q.s)-1] = http2FrameWriteRequest{}
+ q.s = q.s[:len(q.s)-1]
+ return wr
+}
+
+// consume consumes up to n bytes from q.s[0]. If the frame is
+// entirely consumed, it is removed from the queue. If the frame
+// is partially consumed, the frame is kept with the consumed
+// bytes removed. Returns true iff any bytes were consumed.
+func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) {
+ if len(q.s) == 0 {
+ return http2FrameWriteRequest{}, false
+ }
+ consumed, rest, numresult := q.s[0].Consume(n)
+ switch numresult {
+ case 0:
+ return http2FrameWriteRequest{}, false
+ case 1:
+ q.shift()
+ case 2:
+ q.s[0] = rest
+ }
+ return consumed, true
+}
+
+type http2writeQueuePool []*http2writeQueue
+
+// put inserts an unused writeQueue into the pool.
+func (p *http2writeQueuePool) put(q *http2writeQueue) {
+ for i := range q.s {
+ q.s[i] = http2FrameWriteRequest{}
+ }
+ q.s = q.s[:0]
+ *p = append(*p, q)
+}
+
+// get returns an empty writeQueue.
+func (p *http2writeQueuePool) get() *http2writeQueue {
+ ln := len(*p)
+ if ln == 0 {
+ return new(http2writeQueue)
+ }
+ x := ln - 1
+ q := (*p)[x]
+ (*p)[x] = nil
+ *p = (*p)[:x]
+ return q
+}
+
+// RFC 7540, Section 5.3.5: the default weight is 16.
+const http2priorityDefaultWeight = 15 // 16 = 15 + 1
+
+// PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
+type http2PriorityWriteSchedulerConfig struct {
+ // MaxClosedNodesInTree controls the maximum number of closed streams to
+ // retain in the priority tree. Setting this to zero saves a small amount
+ // of memory at the cost of performance.
+ //
+ // See RFC 7540, Section 5.3.4:
+ // "It is possible for a stream to become closed while prioritization
+ // information ... is in transit. ... This potentially creates suboptimal
+ // prioritization, since the stream could be given a priority that is
+ // different from what is intended. To avoid these problems, an endpoint
+ // SHOULD retain stream prioritization state for a period after streams
+ // become closed. The longer state is retained, the lower the chance that
+ // streams are assigned incorrect or default priority values."
+ MaxClosedNodesInTree int
+
+ // MaxIdleNodesInTree controls the maximum number of idle streams to
+ // retain in the priority tree. Setting this to zero saves a small amount
+ // of memory at the cost of performance.
+ //
+ // See RFC 7540, Section 5.3.4:
+ // Similarly, streams that are in the "idle" state can be assigned
+ // priority or become a parent of other streams. This allows for the
+ // creation of a grouping node in the dependency tree, which enables
+ // more flexible expressions of priority. Idle streams begin with a
+ // default priority (Section 5.3.5).
+ MaxIdleNodesInTree int
+
+ // ThrottleOutOfOrderWrites enables write throttling to help ensure that
+ // data is delivered in priority order. This works around a race where
+ // stream B depends on stream A and both streams are about to call Write
+ // to queue DATA frames. If B wins the race, a naive scheduler would eagerly
+ // write as much data from B as possible, but this is suboptimal because A
+ // is a higher-priority stream. With throttling enabled, we write a small
+ // amount of data from B to minimize the amount of bandwidth that B can
+ // steal from A.
+ ThrottleOutOfOrderWrites bool
+}
+
+// NewPriorityWriteScheduler constructs a WriteScheduler that schedules
+// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3.
+// If cfg is nil, default options are used.
+func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler {
+ if cfg == nil {
+ // For justification of these defaults, see:
+ // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY
+ cfg = &http2PriorityWriteSchedulerConfig{
+ MaxClosedNodesInTree: 10,
+ MaxIdleNodesInTree: 10,
+ ThrottleOutOfOrderWrites: false,
+ }
+ }
+
+ ws := &http2priorityWriteScheduler{
+ nodes: make(map[uint32]*http2priorityNode),
+ maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
+ maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
+ enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
+ }
+ ws.nodes[0] = &ws.root
+ if cfg.ThrottleOutOfOrderWrites {
+ ws.writeThrottleLimit = 1024
+ } else {
+ ws.writeThrottleLimit = math.MaxInt32
+ }
+ return ws
+}
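
For reference, the bundled http2* identifiers mirror the exported API of golang.org/x/net/http2, where a server can opt into this scheduler with a custom config. A hedged sketch of that wiring (cert.pem and key.pem are placeholder file names, and the exact field values are arbitrary):

package main

import (
	"log"
	"net/http"

	"golang.org/x/net/http2"
)

func main() {
	// Ask the HTTP/2 server to build a priority write scheduler with
	// non-default limits for retained closed/idle nodes and throttling.
	h2 := &http2.Server{
		NewWriteScheduler: func() http2.WriteScheduler {
			return http2.NewPriorityWriteScheduler(&http2.PriorityWriteSchedulerConfig{
				MaxClosedNodesInTree:     10,
				MaxIdleNodesInTree:       10,
				ThrottleOutOfOrderWrites: true,
			})
		},
	}
	srv := &http.Server{Addr: ":8443", Handler: http.DefaultServeMux}
	if err := http2.ConfigureServer(srv, h2); err != nil {
		log.Fatal(err)
	}
	// cert.pem and key.pem are placeholders for real certificate files.
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
}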
+
+type http2priorityNodeState int
+
+const (
+ http2priorityNodeOpen http2priorityNodeState = iota
+ http2priorityNodeClosed
+ http2priorityNodeIdle
+)
+
+// priorityNode is a node in an HTTP/2 priority tree.
+// Each node is associated with a single stream ID.
+// See RFC 7540, Section 5.3.
+type http2priorityNode struct {
+ q http2writeQueue // queue of pending frames to write
+ id uint32 // id of the stream, or 0 for the root of the tree
+ weight uint8 // the actual weight is weight+1, so the value is in [1,256]
+ state http2priorityNodeState // open | closed | idle
+ bytes int64 // number of bytes written by this node, or 0 if closed
+ subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
+
+ // These links form the priority tree.
+ parent *http2priorityNode
+ kids *http2priorityNode // start of the kids list
+ prev, next *http2priorityNode // doubly-linked list of siblings
+}
+
+func (n *http2priorityNode) setParent(parent *http2priorityNode) {
+ if n == parent {
+ panic("setParent to self")
+ }
+ if n.parent == parent {
+ return
+ }
+ // Unlink from current parent.
+ if parent := n.parent; parent != nil {
+ if n.prev == nil {
+ parent.kids = n.next
+ } else {
+ n.prev.next = n.next
+ }
+ if n.next != nil {
+ n.next.prev = n.prev
+ }
+ }
+ // Link to new parent.
+ // If parent=nil, remove n from the tree.
+ // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder).
+ n.parent = parent
+ if parent == nil {
+ n.next = nil
+ n.prev = nil
+ } else {
+ n.next = parent.kids
+ n.prev = nil
+ if n.next != nil {
+ n.next.prev = n
+ }
+ parent.kids = n
+ }
+}
+
+func (n *http2priorityNode) addBytes(b int64) {
+ n.bytes += b
+ for ; n != nil; n = n.parent {
+ n.subtreeBytes += b
+ }
+}
+
+// walkReadyInOrder iterates over the tree in priority order, calling f for each node
+// with a non-empty write queue. When f returns true, this function returns true and the
+// walk halts. tmp is used as scratch space for sorting.
+//
+// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
+// if any ancestor p of n is still open (ignoring the root node).
+func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool {
+ if !n.q.empty() && f(n, openParent) {
+ return true
+ }
+ if n.kids == nil {
+ return false
+ }
+
+ // Don't consider the root "open" when updating openParent since
+ // we can't send data frames on the root stream (only control frames).
+ if n.id != 0 {
+ openParent = openParent || (n.state == http2priorityNodeOpen)
+ }
+
+ // Common case: only one kid or all kids have the same weight.
+ // Some clients don't use weights; other clients (like web browsers)
+ // use mostly-linear priority trees.
+ w := n.kids.weight
+ needSort := false
+ for k := n.kids.next; k != nil; k = k.next {
+ if k.weight != w {
+ needSort = true
+ break
+ }
+ }
+ if !needSort {
+ for k := n.kids; k != nil; k = k.next {
+ if k.walkReadyInOrder(openParent, tmp, f) {
+ return true
+ }
+ }
+ return false
+ }
+
+ // Uncommon case: sort the child nodes. We remove the kids from the parent,
+ // then re-insert after sorting so we can reuse tmp for future sort calls.
+ *tmp = (*tmp)[:0]
+ for n.kids != nil {
+ *tmp = append(*tmp, n.kids)
+ n.kids.setParent(nil)
+ }
+ sort.Sort(http2sortPriorityNodeSiblings(*tmp))
+ for i := len(*tmp) - 1; i >= 0; i-- {
+ (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids
+ }
+ for k := n.kids; k != nil; k = k.next {
+ if k.walkReadyInOrder(openParent, tmp, f) {
+ return true
+ }
+ }
+ return false
+}
+
+type http2sortPriorityNodeSiblings []*http2priorityNode
+
+func (z http2sortPriorityNodeSiblings) Len() int { return len(z) }
+
+func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
+
+func (z http2sortPriorityNodeSiblings) Less(i, k int) bool {
+ // Prefer the subtree that has sent fewer bytes relative to its weight.
+ // See sections 5.3.2 and 5.3.4.
+ wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
+ wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
+ if bi == 0 && bk == 0 {
+ return wi >= wk
+ }
+ if bk == 0 {
+ return false
+ }
+ return bi/bk <= wi/wk
+}
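
A quick arithmetic check of the ordering above, using plain float64 values rather than the bundled node type:

package main

import "fmt"

func main() {
	// Sibling A: actual weight 8, 800 subtree bytes written.
	// Sibling B: actual weight 16, also 800 bytes written.
	wa, ba := 8.0, 800.0
	wb, bb := 16.0, 800.0
	fmt.Println("A before B:", ba/bb <= wa/wb) // false: 1.0 > 0.5
	fmt.Println("B before A:", bb/ba <= wb/wa) // true:  1.0 <= 2.0
	// With equal bytes already sent, the heavier sibling (B) sorts first
	// and is offered write bandwidth before A.
}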
+
+type http2priorityWriteScheduler struct {
+ // root is the root of the priority tree, where root.id = 0.
+ // The root queues control frames that are not associated with any stream.
+ root http2priorityNode
+
+ // nodes maps stream ids to priority tree nodes.
+ nodes map[uint32]*http2priorityNode
+
+ // maxID is the maximum stream id in nodes.
+ maxID uint32
+
+ // lists of nodes that have been closed or are idle, but are kept in
+ // the tree for improved prioritization. When the lengths exceed either
+ // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
+ closedNodes, idleNodes []*http2priorityNode
+
+ // From the config.
+ maxClosedNodesInTree int
+ maxIdleNodesInTree int
+ writeThrottleLimit int32
+ enableWriteThrottle bool
+
+ // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
+ tmp []*http2priorityNode
+
+ // pool of empty queues for reuse.
+ queuePool http2writeQueuePool
+}
+
+func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
+ // The stream may be currently idle but cannot be opened or closed.
+ if curr := ws.nodes[streamID]; curr != nil {
+ if curr.state != http2priorityNodeIdle {
+ panic(fmt.Sprintf("stream %d already opened", streamID))
+ }
+ curr.state = http2priorityNodeOpen
+ return
+ }
+
+ // RFC 7540, Section 5.3.5:
+ // "All streams are initially assigned a non-exclusive dependency on stream 0x0.
+ // Pushed streams initially depend on their associated stream. In both cases,
+ // streams are assigned a default weight of 16."
+ parent := ws.nodes[options.PusherID]
+ if parent == nil {
+ parent = &ws.root
+ }
+ n := &http2priorityNode{
+ q: *ws.queuePool.get(),
+ id: streamID,
+ weight: http2priorityDefaultWeight,
+ state: http2priorityNodeOpen,
+ }
+ n.setParent(parent)
+ ws.nodes[streamID] = n
+ if streamID > ws.maxID {
+ ws.maxID = streamID
+ }
+}
+
+func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) {
+ if streamID == 0 {
+ panic("violation of WriteScheduler interface: cannot close stream 0")
+ }
+ if ws.nodes[streamID] == nil {
+ panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
+ }
+ if ws.nodes[streamID].state != http2priorityNodeOpen {
+ panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
+ }
+
+ n := ws.nodes[streamID]
+ n.state = http2priorityNodeClosed
+ n.addBytes(-n.bytes)
+
+ q := n.q
+ ws.queuePool.put(&q)
+ n.q.s = nil
+ if ws.maxClosedNodesInTree > 0 {
+ ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
+ } else {
+ ws.removeNode(n)
+ }
+}
+
+func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
+ if streamID == 0 {
+ panic("adjustPriority on root")
+ }
+
+ // If streamID does not exist, there are two cases:
+ // - A closed stream that has been removed (this will have ID <= maxID)
+ // - An idle stream that is being used for "grouping" (this will have ID > maxID)
+ n := ws.nodes[streamID]
+ if n == nil {
+ if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 {
+ return
+ }
+ ws.maxID = streamID
+ n = &http2priorityNode{
+ q: *ws.queuePool.get(),
+ id: streamID,
+ weight: http2priorityDefaultWeight,
+ state: http2priorityNodeIdle,
+ }
+ n.setParent(&ws.root)
+ ws.nodes[streamID] = n
+ ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n)
+ }
+
+ // Section 5.3.1: A dependency on a stream that is not currently in the tree
+ // results in that stream being given a default priority (Section 5.3.5).
+ parent := ws.nodes[priority.StreamDep]
+ if parent == nil {
+ n.setParent(&ws.root)
+ n.weight = http2priorityDefaultWeight
+ return
+ }
+
+ // Ignore if the client tries to make a node its own parent.
+ if n == parent {
+ return
+ }
+
+ // Section 5.3.3:
+ // "If a stream is made dependent on one of its own dependencies, the
+ // formerly dependent stream is first moved to be dependent on the
+ // reprioritized stream's previous parent. The moved dependency retains
+ // its weight."
+ //
+ // That is: if parent depends on n, move parent to depend on n.parent.
+ for x := parent.parent; x != nil; x = x.parent {
+ if x == n {
+ parent.setParent(n.parent)
+ break
+ }
+ }
+
+ // Section 5.3.3: The exclusive flag causes the stream to become the sole
+ // dependency of its parent stream, causing other dependencies to become
+ // dependent on the exclusive stream.
+ if priority.Exclusive {
+ k := parent.kids
+ for k != nil {
+ next := k.next
+ if k != n {
+ k.setParent(n)
+ }
+ k = next
+ }
+ }
+
+ n.setParent(parent)
+ n.weight = priority.Weight
+}
+
+func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) {
+ var n *http2priorityNode
+ if wr.isControl() {
+ n = &ws.root
+ } else {
+ id := wr.StreamID()
+ n = ws.nodes[id]
+ if n == nil {
+			// id is an idle or closed stream. wr should not be a HEADERS or
+			// DATA frame. In that case, we push wr onto the root rather
+			// than creating a new priorityNode.
+ if wr.DataSize() > 0 {
+ panic("add DATA on non-open stream")
+ }
+ n = &ws.root
+ }
+ }
+ n.q.push(wr)
+}
+
+func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) {
+ ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool {
+ limit := int32(math.MaxInt32)
+ if openParent {
+ limit = ws.writeThrottleLimit
+ }
+ wr, ok = n.q.consume(limit)
+ if !ok {
+ return false
+ }
+ n.addBytes(int64(wr.DataSize()))
+ // If B depends on A and B continuously has data available but A
+ // does not, gradually increase the throttling limit to allow B to
+ // steal more and more bandwidth from A.
+ if openParent {
+ ws.writeThrottleLimit += 1024
+ if ws.writeThrottleLimit < 0 {
+ ws.writeThrottleLimit = math.MaxInt32
+ }
+ } else if ws.enableWriteThrottle {
+ ws.writeThrottleLimit = 1024
+ }
+ return true
+ })
+ return wr, ok
+}
+
+func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) {
+ if maxSize == 0 {
+ return
+ }
+ if len(*list) == maxSize {
+ // Remove the oldest node, then shift left.
+ ws.removeNode((*list)[0])
+ x := (*list)[1:]
+ copy(*list, x)
+ *list = (*list)[:len(x)]
+ }
+ *list = append(*list, n)
+}
+
+func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) {
+ for k := n.kids; k != nil; k = k.next {
+ k.setParent(n.parent)
+ }
+ n.setParent(nil)
+ delete(ws.nodes, n.id)
+}
+
+// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2
+// priorities. Control frames like SETTINGS and PING are written before DATA
+// frames, but if no control frames are queued and multiple streams have queued
+// HEADERS or DATA frames, Pop selects a ready stream arbitrarily.
+func http2NewRandomWriteScheduler() http2WriteScheduler {
+ return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)}
+}
+
+type http2randomWriteScheduler struct {
+ // zero are frames not associated with a specific stream.
+ zero http2writeQueue
+
+ // sq contains the stream-specific queues, keyed by stream ID.
+ // When a stream is idle, closed, or emptied, it's deleted
+ // from the map.
+ sq map[uint32]*http2writeQueue
+
+ // pool of empty queues for reuse.
+ queuePool http2writeQueuePool
+}
+
+func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
+ // no-op: idle streams are not tracked
+}
+
+func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) {
+ q, ok := ws.sq[streamID]
+ if !ok {
+ return
+ }
+ delete(ws.sq, streamID)
+ ws.queuePool.put(q)
+}
+
+func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
+ // no-op: priorities are ignored
+}
+
+func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) {
+ if wr.isControl() {
+ ws.zero.push(wr)
+ return
+ }
+ id := wr.StreamID()
+ q, ok := ws.sq[id]
+ if !ok {
+ q = ws.queuePool.get()
+ ws.sq[id] = q
+ }
+ q.push(wr)
+}
+
+func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) {
+ // Control and RST_STREAM frames first.
+ if !ws.zero.empty() {
+ return ws.zero.shift(), true
+ }
+ // Iterate over all non-idle streams until finding one that can be consumed.
+ for streamID, q := range ws.sq {
+ if wr, ok := q.consume(math.MaxInt32); ok {
+ if q.empty() {
+ delete(ws.sq, streamID)
+ ws.queuePool.put(q)
+ }
+ return wr, true
+ }
+ }
+ return http2FrameWriteRequest{}, false
+}
+
+type http2roundRobinWriteScheduler struct {
+ // control contains control frames (SETTINGS, PING, etc.).
+ control http2writeQueue
+
+ // streams maps stream ID to a queue.
+ streams map[uint32]*http2writeQueue
+
+ // stream queues are stored in a circular linked list.
+ // head is the next stream to write, or nil if there are no streams open.
+ head *http2writeQueue
+
+ // pool of empty queues for reuse.
+ queuePool http2writeQueuePool
+}
+
+// newRoundRobinWriteScheduler constructs a new write scheduler.
+// The round-robin scheduler prioritizes control frames
+// like SETTINGS and PING over DATA frames.
+// When there are no control frames to send, it performs a round-robin
+// selection from the ready streams.
+func http2newRoundRobinWriteScheduler() http2WriteScheduler {
+ ws := &http2roundRobinWriteScheduler{
+ streams: make(map[uint32]*http2writeQueue),
+ }
+ return ws
+}
+
+func (ws *http2roundRobinWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
+ if ws.streams[streamID] != nil {
+ panic(fmt.Errorf("stream %d already opened", streamID))
+ }
+ q := ws.queuePool.get()
+ ws.streams[streamID] = q
+ if ws.head == nil {
+ ws.head = q
+ q.next = q
+ q.prev = q
+ } else {
+ // Queues are stored in a ring.
+ // Insert the new stream before ws.head, putting it at the end of the list.
+ q.prev = ws.head.prev
+ q.next = ws.head
+ q.prev.next = q
+ q.next.prev = q
+ }
+}
+
+func (ws *http2roundRobinWriteScheduler) CloseStream(streamID uint32) {
+ q := ws.streams[streamID]
+ if q == nil {
+ return
+ }
+ if q.next == q {
+ // This was the only open stream.
+ ws.head = nil
+ } else {
+ q.prev.next = q.next
+ q.next.prev = q.prev
+ if ws.head == q {
+ ws.head = q.next
+ }
+ }
+ delete(ws.streams, streamID)
+ ws.queuePool.put(q)
+}
+
+func (ws *http2roundRobinWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {}
+
+func (ws *http2roundRobinWriteScheduler) Push(wr http2FrameWriteRequest) {
+ if wr.isControl() {
+ ws.control.push(wr)
+ return
+ }
+ q := ws.streams[wr.StreamID()]
+ if q == nil {
+ // This is a closed stream.
+ // wr should not be a HEADERS or DATA frame.
+ // We push the request onto the control queue.
+ if wr.DataSize() > 0 {
+ panic("add DATA on non-open stream")
+ }
+ ws.control.push(wr)
+ return
+ }
+ q.push(wr)
+}
+
+func (ws *http2roundRobinWriteScheduler) Pop() (http2FrameWriteRequest, bool) {
+ // Control and RST_STREAM frames first.
+ if !ws.control.empty() {
+ return ws.control.shift(), true
+ }
+ if ws.head == nil {
+ return http2FrameWriteRequest{}, false
+ }
+ q := ws.head
+ for {
+ if wr, ok := q.consume(math.MaxInt32); ok {
+ ws.head = q.next
+ return wr, true
+ }
+ q = q.next
+ if q == ws.head {
+ break
+ }
+ }
+ return http2FrameWriteRequest{}, false
+}
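
The round-robin scheduler above keeps its per-stream queues in a circular doubly-linked list: OpenStream inserts the new queue just before head (the end of the rotation) and Pop advances head. A standalone sketch of that ring insertion, with hypothetical types rather than the bundled ones:

package main

import "fmt"

// ring is a node in a circular doubly-linked list, keyed by stream ID.
type ring struct {
	id         uint32
	prev, next *ring
}

// insertBefore places q at the end of the rotation whose next entry is head,
// returning the (possibly new) head.
func insertBefore(head, q *ring) *ring {
	if head == nil {
		q.prev, q.next = q, q
		return q
	}
	q.prev = head.prev
	q.next = head
	q.prev.next = q
	q.next.prev = q
	return head
}

func main() {
	var head *ring
	for id := uint32(1); id <= 3; id++ {
		head = insertBefore(head, &ring{id: id})
	}
	// Walk one full rotation starting at head: prints 1, 2, 3.
	for q, first := head, true; first || q != head; q, first = q.next, false {
		fmt.Println(q.id)
	}
}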
diff --git a/src/net/http/h2_error.go b/src/net/http/h2_error.go
new file mode 100644
index 0000000..0391d31
--- /dev/null
+++ b/src/net/http/h2_error.go
@@ -0,0 +1,38 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !nethttpomithttp2
+// +build !nethttpomithttp2
+
+package http
+
+import (
+ "reflect"
+)
+
+func (e http2StreamError) As(target any) bool {
+ dst := reflect.ValueOf(target).Elem()
+ dstType := dst.Type()
+ if dstType.Kind() != reflect.Struct {
+ return false
+ }
+ src := reflect.ValueOf(e)
+ srcType := src.Type()
+ numField := srcType.NumField()
+ if dstType.NumField() != numField {
+ return false
+ }
+ for i := 0; i < numField; i++ {
+ sf := srcType.Field(i)
+ df := dstType.Field(i)
+ if sf.Name != df.Name || !sf.Type.ConvertibleTo(df.Type) {
+ return false
+ }
+ }
+ for i := 0; i < numField; i++ {
+ df := dst.Field(i)
+ df.Set(src.Field(i).Convert(df.Type()))
+ }
+ return true
+}
diff --git a/src/net/http/h2_error_test.go b/src/net/http/h2_error_test.go
new file mode 100644
index 0000000..0d85e2f
--- /dev/null
+++ b/src/net/http/h2_error_test.go
@@ -0,0 +1,44 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !nethttpomithttp2
+// +build !nethttpomithttp2
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+)
+
+type externalStreamErrorCode uint32
+
+type externalStreamError struct {
+ StreamID uint32
+ Code externalStreamErrorCode
+ Cause error
+}
+
+func (e externalStreamError) Error() string {
+ return fmt.Sprintf("ID %v, code %v", e.StreamID, e.Code)
+}
+
+func TestStreamError(t *testing.T) {
+ var target externalStreamError
+ streamErr := http2streamError(42, http2ErrCodeProtocol)
+ ok := errors.As(streamErr, &target)
+ if !ok {
+ t.Fatalf("errors.As failed")
+ }
+ if target.StreamID != streamErr.StreamID {
+ t.Errorf("got StreamID %v, expected %v", target.StreamID, streamErr.StreamID)
+ }
+ if target.Cause != streamErr.Cause {
+ t.Errorf("got Cause %v, expected %v", target.Cause, streamErr.Cause)
+ }
+ if uint32(target.Code) != uint32(streamErr.Code) {
+ t.Errorf("got Code %v, expected %v", target.Code, streamErr.Code)
+ }
+}
diff --git a/src/net/http/header.go b/src/net/http/header.go
new file mode 100644
index 0000000..e0b342c
--- /dev/null
+++ b/src/net/http/header.go
@@ -0,0 +1,280 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "io"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// A Header represents the key-value pairs in an HTTP header.
+//
+// The keys should be in canonical form, as returned by
+// CanonicalHeaderKey.
+type Header map[string][]string
+
+// Add adds the key, value pair to the header.
+// It appends to any existing values associated with key.
+// The key is case insensitive; it is canonicalized by
+// CanonicalHeaderKey.
+func (h Header) Add(key, value string) {
+ textproto.MIMEHeader(h).Add(key, value)
+}
+
+// Set sets the header entries associated with key to the
+// single element value. It replaces any existing values
+// associated with key. The key is case insensitive; it is
+// canonicalized by textproto.CanonicalMIMEHeaderKey.
+// To use non-canonical keys, assign to the map directly.
+func (h Header) Set(key, value string) {
+ textproto.MIMEHeader(h).Set(key, value)
+}
+
+// Get gets the first value associated with the given key. If
+// there are no values associated with the key, Get returns "".
+// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
+// used to canonicalize the provided key. Get assumes that all
+// keys are stored in canonical form. To use non-canonical keys,
+// access the map directly.
+func (h Header) Get(key string) string {
+ return textproto.MIMEHeader(h).Get(key)
+}
+
+// Values returns all values associated with the given key.
+// It is case insensitive; textproto.CanonicalMIMEHeaderKey is
+// used to canonicalize the provided key. To use non-canonical
+// keys, access the map directly.
+// The returned slice is not a copy.
+func (h Header) Values(key string) []string {
+ return textproto.MIMEHeader(h).Values(key)
+}
+
+// get is like Get, but key must already be in CanonicalHeaderKey form.
+func (h Header) get(key string) string {
+ if v := h[key]; len(v) > 0 {
+ return v[0]
+ }
+ return ""
+}
+
+// has reports whether h has the provided key defined, even if it's
+// set to 0-length slice.
+func (h Header) has(key string) bool {
+ _, ok := h[key]
+ return ok
+}
+
+// Del deletes the values associated with key.
+// The key is case insensitive; it is canonicalized by
+// CanonicalHeaderKey.
+func (h Header) Del(key string) {
+ textproto.MIMEHeader(h).Del(key)
+}
+
+// Write writes a header in wire format.
+func (h Header) Write(w io.Writer) error {
+ return h.write(w, nil)
+}
+
+func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
+ return h.writeSubset(w, nil, trace)
+}
+
+// Clone returns a copy of h or nil if h is nil.
+func (h Header) Clone() Header {
+ if h == nil {
+ return nil
+ }
+
+ // Find total number of values.
+ nv := 0
+ for _, vv := range h {
+ nv += len(vv)
+ }
+ sv := make([]string, nv) // shared backing array for headers' values
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ if vv == nil {
+ // Preserve nil values. ReverseProxy distinguishes
+ // between nil and zero-length header values.
+ h2[k] = nil
+ continue
+ }
+ n := copy(sv, vv)
+ h2[k] = sv[:n:n]
+ sv = sv[n:]
+ }
+ return h2
+}
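
A short usage sketch of the Clone semantics implemented above: the capacity-capped slices mean appends to the clone never leak back into the original header.

package main

import (
	"fmt"
	"net/http"
)

func main() {
	h := http.Header{"Accept": {"text/html"}}
	h2 := h.Clone()
	h2.Add("Accept", "application/json")
	fmt.Println(h["Accept"])  // [text/html]
	fmt.Println(h2["Accept"]) // [text/html application/json]
}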
+
+var timeFormats = []string{
+ TimeFormat,
+ time.RFC850,
+ time.ANSIC,
+}
+
+// ParseTime parses a time header (such as the Date: header),
+// trying each of the three formats allowed by HTTP/1.1:
+// TimeFormat, time.RFC850, and time.ANSIC.
+func ParseTime(text string) (t time.Time, err error) {
+ for _, layout := range timeFormats {
+ t, err = time.Parse(layout, text)
+ if err == nil {
+ return
+ }
+ }
+ return
+}
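
A small usage sketch of ParseTime with one value in each of the three accepted layouts:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	for _, s := range []string{
		"Sun, 06 Nov 1994 08:49:37 GMT",  // TimeFormat
		"Sunday, 06-Nov-94 08:49:37 GMT", // time.RFC850
		"Sun Nov  6 08:49:37 1994",       // time.ANSIC
	} {
		t, err := http.ParseTime(s)
		fmt.Println(t.UTC(), err) // same instant for all three, err == nil
	}
}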
+
+var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ")
+
+// stringWriter implements WriteString on a Writer.
+type stringWriter struct {
+ w io.Writer
+}
+
+func (w stringWriter) WriteString(s string) (n int, err error) {
+ return w.w.Write([]byte(s))
+}
+
+type keyValues struct {
+ key string
+ values []string
+}
+
+// A headerSorter implements sort.Interface by sorting a []keyValues
+// by key. It's used as a pointer, so it can fit in a sort.Interface
+// interface value without allocation.
+type headerSorter struct {
+ kvs []keyValues
+}
+
+func (s *headerSorter) Len() int { return len(s.kvs) }
+func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] }
+func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key }
+
+var headerSorterPool = sync.Pool{
+ New: func() any { return new(headerSorter) },
+}
+
+// sortedKeyValues returns h's keys sorted in the returned kvs
+// slice. The headerSorter used to sort is also returned, for possible
+// return to headerSorterPool.
+func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) {
+ hs = headerSorterPool.Get().(*headerSorter)
+ if cap(hs.kvs) < len(h) {
+ hs.kvs = make([]keyValues, 0, len(h))
+ }
+ kvs = hs.kvs[:0]
+ for k, vv := range h {
+ if !exclude[k] {
+ kvs = append(kvs, keyValues{k, vv})
+ }
+ }
+ hs.kvs = kvs
+ sort.Sort(hs)
+ return kvs, hs
+}
+
+// WriteSubset writes a header in wire format.
+// If exclude is not nil, keys where exclude[key] == true are not written.
+// Keys are not canonicalized before checking the exclude map.
+func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
+ return h.writeSubset(w, exclude, nil)
+}
+
+func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
+ ws, ok := w.(io.StringWriter)
+ if !ok {
+ ws = stringWriter{w}
+ }
+ kvs, sorter := h.sortedKeyValues(exclude)
+ var formattedVals []string
+ for _, kv := range kvs {
+ if !httpguts.ValidHeaderFieldName(kv.key) {
+ // This could be an error. In the common case of
+ // writing response headers, however, we have no good
+ // way to provide the error back to the server
+ // handler, so just drop invalid headers instead.
+ continue
+ }
+ for _, v := range kv.values {
+ v = headerNewlineToSpace.Replace(v)
+ v = textproto.TrimString(v)
+ for _, s := range []string{kv.key, ": ", v, "\r\n"} {
+ if _, err := ws.WriteString(s); err != nil {
+ headerSorterPool.Put(sorter)
+ return err
+ }
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ formattedVals = append(formattedVals, v)
+ }
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(kv.key, formattedVals)
+ formattedVals = nil
+ }
+ }
+ headerSorterPool.Put(sorter)
+ return nil
+}
+
+// CanonicalHeaderKey returns the canonical format of the
+// header key s. The canonicalization converts the first
+// letter and any letter following a hyphen to upper case;
+// the rest are converted to lowercase. For example, the
+// canonical key for "accept-encoding" is "Accept-Encoding".
+// If s contains a space or invalid header field bytes, it is
+// returned without modifications.
+func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) }
+
+// hasToken reports whether token appears within v, ASCII
+// case-insensitive, with space or comma boundaries.
+// token must be all lowercase.
+// v may contain mixed case.
+func hasToken(v, token string) bool {
+ if len(token) > len(v) || token == "" {
+ return false
+ }
+ if v == token {
+ return true
+ }
+ for sp := 0; sp <= len(v)-len(token); sp++ {
+ // Check that first character is good.
+ // The token is ASCII, so checking only a single byte
+ // is sufficient. We skip this potential starting
+ // position if both the first byte and its potential
+ // ASCII uppercase equivalent (b|0x20) don't match.
+ // False positives ('^' => '~') are caught by EqualFold.
+ if b := v[sp]; b != token[0] && b|0x20 != token[0] {
+ continue
+ }
+ // Check that start pos is on a valid token boundary.
+ if sp > 0 && !isTokenBoundary(v[sp-1]) {
+ continue
+ }
+ // Check that end pos is on a valid token boundary.
+ if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) {
+ continue
+ }
+ if ascii.EqualFold(v[sp:sp+len(token)], token) {
+ return true
+ }
+ }
+ return false
+}
+
+func isTokenBoundary(b byte) bool {
+ return b == ' ' || b == ',' || b == '\t'
+}
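
A brief usage sketch of WriteSubset as defined above; X-Internal is a made-up header name used only to demonstrate the exclude map.

package main

import (
	"log"
	"net/http"
	"os"
)

func main() {
	h := http.Header{
		"Content-Type":   {"text/plain"},
		"Content-Length": {"0"},
		"X-Internal":     {"secret"},
	}
	// Keys are written in sorted order in wire format; excluded keys are skipped.
	if err := h.WriteSubset(os.Stdout, map[string]bool{"X-Internal": true}); err != nil {
		log.Fatal(err)
	}
	// Writes (each line terminated by \r\n):
	//   Content-Length: 0
	//   Content-Type: text/plain
}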
diff --git a/src/net/http/header_test.go b/src/net/http/header_test.go
new file mode 100644
index 0000000..e98cc5c
--- /dev/null
+++ b/src/net/http/header_test.go
@@ -0,0 +1,272 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "internal/race"
+ "reflect"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+var headerWriteTests = []struct {
+ h Header
+ exclude map[string]bool
+ expected string
+}{
+ {Header{}, nil, ""},
+ {
+ Header{
+ "Content-Type": {"text/html; charset=UTF-8"},
+ "Content-Length": {"0"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n",
+ },
+ {
+ Header{
+ "Content-Length": {"0", "1", "2"},
+ },
+ nil,
+ "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0", "1", "2"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true},
+ "Content-Encoding: gzip\r\nExpires: -1\r\n",
+ },
+ {
+ Header{
+ "Expires": {"-1"},
+ "Content-Length": {"0"},
+ "Content-Encoding": {"gzip"},
+ },
+ map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true},
+ "",
+ },
+ {
+ Header{
+ "Nil": nil,
+ "Empty": {},
+ "Blank": {""},
+ "Double-Blank": {"", ""},
+ },
+ nil,
+ "Blank: \r\nDouble-Blank: \r\nDouble-Blank: \r\n",
+ },
+	// Tests header sorting when over the insertion sort threshold size:
+ {
+ Header{
+ "k1": {"1a", "1b"},
+ "k2": {"2a", "2b"},
+ "k3": {"3a", "3b"},
+ "k4": {"4a", "4b"},
+ "k5": {"5a", "5b"},
+ "k6": {"6a", "6b"},
+ "k7": {"7a", "7b"},
+ "k8": {"8a", "8b"},
+ "k9": {"9a", "9b"},
+ },
+ map[string]bool{"k5": true},
+ "k1: 1a\r\nk1: 1b\r\nk2: 2a\r\nk2: 2b\r\nk3: 3a\r\nk3: 3b\r\n" +
+ "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" +
+ "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n",
+ },
+ // Tests invalid characters in headers.
+ {
+ Header{
+ "Content-Type": {"text/html; charset=UTF-8"},
+ "NewlineInValue": {"1\r\nBar: 2"},
+ "NewlineInKey\r\n": {"1"},
+ "Colon:InKey": {"1"},
+ "Evil: 1\r\nSmuggledValue": {"1"},
+ },
+ nil,
+ "Content-Type: text/html; charset=UTF-8\r\n" +
+ "NewlineInValue: 1 Bar: 2\r\n",
+ },
+}
+
+func TestHeaderWrite(t *testing.T) {
+ var buf strings.Builder
+ for i, test := range headerWriteTests {
+ test.h.WriteSubset(&buf, test.exclude)
+ if buf.String() != test.expected {
+ t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected)
+ }
+ buf.Reset()
+ }
+}
+
+var parseTimeTests = []struct {
+ h Header
+ err bool
+}{
+ {Header{"Date": {""}}, true},
+ {Header{"Date": {"invalid"}}, true},
+ {Header{"Date": {"1994-11-06T08:49:37Z00:00"}}, true},
+ {Header{"Date": {"Sun, 06 Nov 1994 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sunday, 06-Nov-94 08:49:37 GMT"}}, false},
+ {Header{"Date": {"Sun Nov 6 08:49:37 1994"}}, false},
+}
+
+func TestParseTime(t *testing.T) {
+ expect := time.Date(1994, 11, 6, 8, 49, 37, 0, time.UTC)
+ for i, test := range parseTimeTests {
+ d, err := ParseTime(test.h.Get("Date"))
+ if err != nil {
+ if !test.err {
+ t.Errorf("#%d:\n got err: %v", i, err)
+ }
+ continue
+ }
+ if test.err {
+ t.Errorf("#%d:\n should err", i)
+ continue
+ }
+ if !expect.Equal(d) {
+ t.Errorf("#%d:\n got: %v\nwant: %v", i, d, expect)
+ }
+ }
+}
+
+type hasTokenTest struct {
+ header string
+ token string
+ want bool
+}
+
+var hasTokenTests = []hasTokenTest{
+ {"", "", false},
+ {"", "foo", false},
+ {"foo", "foo", true},
+ {"foo ", "foo", true},
+ {" foo", "foo", true},
+ {" foo ", "foo", true},
+ {"foo,bar", "foo", true},
+ {"bar,foo", "foo", true},
+ {"bar, foo", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo,baz", "foo", true},
+ {"bar,foo, baz", "foo", true},
+ {"bar, foo, baz", "foo", true},
+ {"FOO", "foo", true},
+ {"FOO ", "foo", true},
+ {" FOO", "foo", true},
+ {" FOO ", "foo", true},
+ {"FOO,BAR", "foo", true},
+ {"BAR,FOO", "foo", true},
+ {"BAR, FOO", "foo", true},
+ {"BAR,FOO, baz", "foo", true},
+ {"BAR, FOO,BAZ", "foo", true},
+ {"BAR,FOO, BAZ", "foo", true},
+ {"BAR, FOO, BAZ", "foo", true},
+ {"foobar", "foo", false},
+ {"barfoo ", "foo", false},
+}
+
+func TestHasToken(t *testing.T) {
+ for _, tt := range hasTokenTests {
+ if hasToken(tt.header, tt.token) != tt.want {
+ t.Errorf("hasToken(%q, %q) = %v; want %v", tt.header, tt.token, !tt.want, tt.want)
+ }
+ }
+}
+
+func TestNilHeaderClone(t *testing.T) {
+ t1 := Header(nil)
+ t2 := t1.Clone()
+ if t2 != nil {
+ t.Errorf("cloned header does not match original: got: %+v; want: %+v", t2, nil)
+ }
+}
+
+var testHeader = Header{
+ "Content-Length": {"123"},
+ "Content-Type": {"text/plain"},
+ "Date": {"some date at some time Z"},
+ "Server": {DefaultUserAgent},
+}
+
+var buf bytes.Buffer
+
+func BenchmarkHeaderWriteSubset(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ }
+}
+
+func TestHeaderWriteSubsetAllocs(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping alloc test in short mode")
+ }
+ if race.Enabled {
+ t.Skip("skipping test under race detector")
+ }
+ if runtime.GOMAXPROCS(0) > 1 {
+ t.Skip("skipping; GOMAXPROCS>1")
+ }
+ n := testing.AllocsPerRun(100, func() {
+ buf.Reset()
+ testHeader.WriteSubset(&buf, nil)
+ })
+ if n > 0 {
+ t.Errorf("allocs = %g; want 0", n)
+ }
+}
+
+// Issue 34878: test that every call to
+// cloneOrMakeHeader never returns a nil Header.
+func TestCloneOrMakeHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ in, want Header
+ }{
+ {"nil", nil, Header{}},
+ {"empty", Header{}, Header{}},
+ {
+ name: "non-empty",
+ in: Header{"foo": {"bar"}},
+ want: Header{"foo": {"bar"}},
+ },
+ {
+ name: "nil value",
+ in: Header{"foo": nil},
+ want: Header{"foo": nil},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := cloneOrMakeHeader(tt.in)
+ if got == nil {
+ t.Fatal("unexpected nil Header")
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Fatalf("Got: %#v\nWant: %#v", got, tt.want)
+ }
+ got.Add("A", "B")
+ got.Get("A")
+ })
+ }
+}
diff --git a/src/net/http/http.go b/src/net/http/http.go
new file mode 100644
index 0000000..9b81654
--- /dev/null
+++ b/src/net/http/http.go
@@ -0,0 +1,165 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:generate bundle -o=h2_bundle.go -prefix=http2 -tags=!nethttpomithttp2 golang.org/x/net/http2
+
+package http
+
+import (
+ "io"
+ "strconv"
+ "strings"
+ "time"
+ "unicode/utf8"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// incomparable is a zero-width, non-comparable type. Adding it to a struct
+// makes that struct also non-comparable, and generally doesn't add
+// any size (as long as it's first).
+type incomparable [0]func()
+
+// maxInt64 is the effective "infinite" value for the Server and
+// Transport's byte-limiting readers.
+const maxInt64 = 1<<63 - 1
+
+// aLongTimeAgo is a non-zero time, far in the past, used for
+// immediate cancellation of network operations.
+var aLongTimeAgo = time.Unix(1, 0)
+
+// omitBundledHTTP2 is set by omithttp2.go when the nethttpomithttp2
+// build tag is set. That means h2_bundle.go isn't compiled in and we
+// shouldn't try to use it.
+var omitBundledHTTP2 bool
+
+// TODO(bradfitz): move common stuff here. The other files have accumulated
+// generic http stuff in random places.
+
+// contextKey is a value for use with context.WithValue. It's used as
+// a pointer so it fits in an interface{} without allocation.
+type contextKey struct {
+ name string
+}
+
+func (k *contextKey) String() string { return "net/http context value " + k.name }
+
+// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
+// return true if the string includes a port.
+func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") }
+
+// removeEmptyPort strips the empty port in ":port" to ""
+// as mandated by RFC 3986 Section 6.2.3.
+func removeEmptyPort(host string) string {
+ if hasPort(host) {
+ return strings.TrimSuffix(host, ":")
+ }
+ return host
+}
+
+func isNotToken(r rune) bool {
+ return !httpguts.IsTokenRune(r)
+}
+
+// stringContainsCTLByte reports whether s contains any ASCII control character.
+func stringContainsCTLByte(s string) bool {
+ for i := 0; i < len(s); i++ {
+ b := s[i]
+ if b < ' ' || b == 0x7f {
+ return true
+ }
+ }
+ return false
+}
+
+func hexEscapeNonASCII(s string) string {
+ newLen := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ newLen += 3
+ } else {
+ newLen++
+ }
+ }
+ if newLen == len(s) {
+ return s
+ }
+ b := make([]byte, 0, newLen)
+ var pos int
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ if pos < i {
+ b = append(b, s[pos:i]...)
+ }
+ b = append(b, '%')
+ b = strconv.AppendInt(b, int64(s[i]), 16)
+ pos = i + 1
+ }
+ }
+ if pos < len(s) {
+ b = append(b, s[pos:]...)
+ }
+ return string(b)
+}
+
+// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
+// and Close always returns nil. It can be used in an outgoing client
+// request to explicitly signal that a request has zero bytes.
+// An alternative, however, is to simply set Request.Body to nil.
+var NoBody = noBody{}
+
+type noBody struct{}
+
+func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
+func (noBody) Close() error { return nil }
+func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
+
+var (
+ // verify that an io.Copy from NoBody won't require a buffer:
+ _ io.WriterTo = NoBody
+ _ io.ReadCloser = NoBody
+)
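
A short usage sketch of NoBody on an outgoing request; the URL is a placeholder host and path.

package main

import (
	"log"
	"net/http"
)

func main() {
	// NoBody makes the zero-byte body explicit; setting Body to nil would also work.
	req, err := http.NewRequest("POST", "https://example.org/ping", http.NoBody)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(req.ContentLength) // 0
}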
+
+// PushOptions describes options for Pusher.Push.
+type PushOptions struct {
+ // Method specifies the HTTP method for the promised request.
+ // If set, it must be "GET" or "HEAD". Empty means "GET".
+ Method string
+
+ // Header specifies additional promised request headers. This cannot
+ // include HTTP/2 pseudo header fields like ":path" and ":scheme",
+ // which will be added automatically.
+ Header Header
+}
+
+// Pusher is the interface implemented by ResponseWriters that support
+// HTTP/2 server push. For more background, see
+// https://tools.ietf.org/html/rfc7540#section-8.2.
+type Pusher interface {
+ // Push initiates an HTTP/2 server push. This constructs a synthetic
+ // request using the given target and options, serializes that request
+ // into a PUSH_PROMISE frame, then dispatches that request using the
+ // server's request handler. If opts is nil, default options are used.
+ //
+ // The target must either be an absolute path (like "/path") or an absolute
+ // URL that contains a valid host and the same scheme as the parent request.
+ // If the target is a path, it will inherit the scheme and host of the
+ // parent request.
+ //
+ // The HTTP/2 spec disallows recursive pushes and cross-authority pushes.
+ // Push may or may not detect these invalid pushes; however, invalid
+ // pushes will be detected and canceled by conforming clients.
+ //
+ // Handlers that wish to push URL X should call Push before sending any
+ // data that may trigger a request for URL X. This avoids a race where the
+ // client issues requests for X before receiving the PUSH_PROMISE for X.
+ //
+ // Push will run in a separate goroutine making the order of arrival
+ // non-deterministic. Any required synchronization needs to be implemented
+ // by the caller.
+ //
+ // Push returns ErrNotSupported if the client has disabled push or if push
+ // is not supported on the underlying connection.
+ Push(target string, opts *PushOptions) error
+}
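
A hedged sketch of using Pusher from a handler: the asset path and the certificate file names are placeholders, and the handler simply falls back to normal serving if the push is not supported or fails.

package main

import (
	"fmt"
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, r *http.Request) {
	// Push the stylesheet before writing any body bytes that would make
	// the client request it on its own.
	if p, ok := w.(http.Pusher); ok {
		if err := p.Push("/static/app.css", nil); err != nil {
			log.Printf("push failed: %v", err)
		}
	}
	fmt.Fprintln(w, `<link rel="stylesheet" href="/static/app.css">`)
}

func main() {
	http.HandleFunc("/", handler)
	// Server push requires HTTP/2, which net/http enables for TLS servers;
	// cert.pem and key.pem are placeholder file names.
	log.Fatal(http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil))
}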
diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go
new file mode 100644
index 0000000..91bb1b2
--- /dev/null
+++ b/src/net/http/http_test.go
@@ -0,0 +1,201 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests of internal functions and things with no better homes.
+
+package http
+
+import (
+ "bytes"
+ "internal/testenv"
+ "io/fs"
+ "net/url"
+ "os"
+ "os/exec"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+func TestForeachHeaderElement(t *testing.T) {
+ tests := []struct {
+ in string
+ want []string
+ }{
+ {"Foo", []string{"Foo"}},
+ {" Foo", []string{"Foo"}},
+ {"Foo ", []string{"Foo"}},
+ {" Foo ", []string{"Foo"}},
+
+ {"foo", []string{"foo"}},
+ {"anY-cAsE", []string{"anY-cAsE"}},
+
+ {"", nil},
+ {",,,, , ,, ,,, ,", nil},
+
+ {" Foo,Bar, Baz,lower,,Quux ", []string{"Foo", "Bar", "Baz", "lower", "Quux"}},
+ }
+ for _, tt := range tests {
+ var got []string
+ foreachHeaderElement(tt.in, func(v string) {
+ got = append(got, v)
+ })
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("foreachHeaderElement(%q) = %q; want %q", tt.in, got, tt.want)
+ }
+ }
+}
+
+// Test that cmd/go doesn't link in the HTTP server.
+//
+// This catches accidental dependencies between the HTTP transport and
+// server code.
+func TestCmdGoNoHTTPServer(t *testing.T) {
+ t.Parallel()
+ goBin := testenv.GoToolPath(t)
+ out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput()
+ if err != nil {
+ t.Fatalf("go tool nm: %v: %s", err, out)
+ }
+ wantSym := map[string]bool{
+ // Verify these exist: (sanity checking this test)
+ "net/http.(*Client).do": true,
+ "net/http.(*Transport).RoundTrip": true,
+
+ // Verify these don't exist:
+ "net/http.http2Server": false,
+ "net/http.(*Server).Serve": false,
+ "net/http.(*ServeMux).ServeHTTP": false,
+ "net/http.DefaultServeMux": false,
+ }
+ for sym, want := range wantSym {
+ got := bytes.Contains(out, []byte(sym))
+ if !want && got {
+ t.Errorf("cmd/go unexpectedly links in HTTP server code; found symbol %q in cmd/go", sym)
+ }
+ if want && !got {
+ t.Errorf("expected to find symbol %q in cmd/go; not found", sym)
+ }
+ }
+}
+
+// Tests that the nethttpomithttp2 build tag doesn't rot too much,
+// even if there's not a regular builder on it.
+func TestOmitHTTP2(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ t.Parallel()
+ goTool := testenv.GoToolPath(t)
+ out, err := exec.Command(goTool, "test", "-short", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
+ if err != nil {
+ t.Fatalf("go test -short failed: %v, %s", err, out)
+ }
+}
+
+// Tests that the nethttpomithttp2 build tag at least type checks
+// in short mode.
+// The TestOmitHTTP2 test above actually runs tests (in long mode).
+func TestOmitHTTP2Vet(t *testing.T) {
+ t.Parallel()
+ goTool := testenv.GoToolPath(t)
+ out, err := exec.Command(goTool, "vet", "-tags=nethttpomithttp2", "net/http").CombinedOutput()
+ if err != nil {
+ t.Fatalf("go vet failed: %v, %s", err, out)
+ }
+}
+
+var valuesCount int
+
+func BenchmarkCopyValues(b *testing.B) {
+ b.ReportAllocs()
+ src := url.Values{
+ "a": {"1", "2", "3", "4", "5"},
+ "b": {"2", "2", "3", "4", "5"},
+ "c": {"3", "2", "3", "4", "5"},
+ "d": {"4", "2", "3", "4", "5"},
+ "e": {"1", "1", "2", "3", "4", "5", "6", "7", "abcdef", "l", "a", "b", "c", "d", "z"},
+ "j": {"1", "2"},
+ "m": nil,
+ }
+ for i := 0; i < b.N; i++ {
+ dst := url.Values{"a": {"b"}, "b": {"2"}, "c": {"3"}, "d": {"4"}, "j": nil, "m": {"x"}}
+ copyValues(dst, src)
+ if valuesCount = len(dst["a"]); valuesCount != 6 {
+ b.Fatalf(`%d items in dst["a"] but expected 6`, valuesCount)
+ }
+ }
+ if valuesCount == 0 {
+ b.Fatal("Benchmark wasn't run")
+ }
+}
+
+var forbiddenStringsFunctions = map[string]bool{
+ // Functions that use Unicode-aware case folding.
+ "EqualFold": true,
+ "Title": true,
+ "ToLower": true,
+ "ToLowerSpecial": true,
+ "ToTitle": true,
+ "ToTitleSpecial": true,
+ "ToUpper": true,
+ "ToUpperSpecial": true,
+
+ // Functions that use Unicode-aware spaces.
+ "Fields": true,
+ "TrimSpace": true,
+}
+
+// TestNoUnicodeStrings checks that nothing in net/http uses the Unicode-aware
+// strings and bytes package functions. HTTP is mostly ASCII based, and doing
+// Unicode-aware case folding or space stripping can introduce vulnerabilities.
+func TestNoUnicodeStrings(t *testing.T) {
+ if !testenv.HasSrc() {
+ t.Skip("source code not available")
+ }
+
+ re := regexp.MustCompile(`(strings|bytes).([A-Za-z]+)`)
+ if err := fs.WalkDir(os.DirFS("."), ".", func(path string, d fs.DirEntry, err error) error {
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if path == "internal/ascii" {
+ return fs.SkipDir
+ }
+ if !strings.HasSuffix(path, ".go") ||
+ strings.HasSuffix(path, "_test.go") ||
+ path == "h2_bundle.go" || d.IsDir() {
+ return nil
+ }
+
+ contents, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for lineNum, line := range strings.Split(string(contents), "\n") {
+ for _, match := range re.FindAllStringSubmatch(line, -1) {
+ if !forbiddenStringsFunctions[match[2]] {
+ continue
+ }
+ t.Errorf("disallowed call to %s at %s:%d", match[0], path, lineNum+1)
+ }
+ }
+
+ return nil
+ }); err != nil {
+ t.Fatal(err)
+ }
+}
+
+const redirectURL = "/thisaredirect细雪withasciilettersのけぶabcdefghijk.html"
+
+func BenchmarkHexEscapeNonASCII(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ hexEscapeNonASCII(redirectURL)
+ }
+}
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go
new file mode 100644
index 0000000..a673843
--- /dev/null
+++ b/src/net/http/httptest/example_test.go
@@ -0,0 +1,99 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleResponseRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html><body>Hello World!</body></html>")
+ }
+
+ req := httptest.NewRequest("GET", "http://example.com/foo", nil)
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ resp := w.Result()
+ body, _ := io.ReadAll(resp.Body)
+
+ fmt.Println(resp.StatusCode)
+ fmt.Println(resp.Header.Get("Content-Type"))
+ fmt.Println(string(body))
+
+ // Output:
+ // 200
+ // text/html; charset=utf-8
+ // <html><body>Hello World!</body></html>
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
+
+func ExampleServer_hTTP2() {
+ ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %s", r.Proto)
+ }))
+ ts.EnableHTTP2 = true
+ ts.StartTLS()
+ defer ts.Close()
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s", greeting)
+
+ // Output: Hello, HTTP/2.0
+}
+
+func ExampleNewTLSServer() {
+ ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go
new file mode 100644
index 0000000..9bedefd
--- /dev/null
+++ b/src/net/http/httptest/httptest.go
@@ -0,0 +1,90 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptest provides utilities for HTTP testing.
+package httptest
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net/http"
+ "strings"
+)
+
+// NewRequest returns a new incoming server Request, suitable
+// for passing to an http.Handler for testing.
+//
+// The target is the RFC 7230 "request-target": it may be either a
+// path or an absolute URL. If target is an absolute URL, the host name
+// from the URL is used. Otherwise, "example.com" is used.
+//
+// The TLS field is set to a non-nil dummy value if target has scheme
+// "https".
+//
+// The Request.Proto is always HTTP/1.1.
+//
+// An empty method means "GET".
+//
+// The provided body may be nil. If the body is of type *bytes.Reader,
+// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is
+// set.
+//
+// NewRequest panics on error for ease of use in testing, where a
+// panic is acceptable.
+//
+// To generate a client HTTP request instead of a server request, see
+// the NewRequest function in the net/http package.
+func NewRequest(method, target string, body io.Reader) *http.Request {
+ if method == "" {
+ method = "GET"
+ }
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n")))
+ if err != nil {
+ panic("invalid NewRequest arguments; " + err.Error())
+ }
+
+ // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here.
+ req.Proto = "HTTP/1.1"
+ req.ProtoMinor = 1
+ req.Close = false
+
+ if body != nil {
+ switch v := body.(type) {
+ case *bytes.Buffer:
+ req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
+ default:
+ req.ContentLength = -1
+ }
+ if rc, ok := body.(io.ReadCloser); ok {
+ req.Body = rc
+ } else {
+ req.Body = io.NopCloser(body)
+ }
+ }
+
+ // 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in
+ // documentation and example source code and should not be
+ // used publicly.
+ req.RemoteAddr = "192.0.2.1:1234"
+
+ if req.Host == "" {
+ req.Host = "example.com"
+ }
+
+ if strings.HasPrefix(target, "https://") {
+ req.TLS = &tls.ConnectionState{
+ Version: tls.VersionTLS12,
+ HandshakeComplete: true,
+ ServerName: req.Host,
+ }
+ }
+
+ return req
+}
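NewRequest pairs naturally with the ResponseRecorder added below in recorder.go when a test drives a handler directly, with no listener or client involved. The following is a minimal usage sketch, not part of the patch itself; the greet handler, the query parameter, and the function names are purely illustrative:

package httptest_test

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// greet is a hypothetical handler used only for this sketch.
func greet(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "Hello, %s", r.URL.Query().Get("name"))
}

func sketchNewRequest() {
	// An empty method defaults to GET; an https target populates req.TLS.
	req := httptest.NewRequest("", "https://example.com/hello?name=Gopher", nil)
	rec := httptest.NewRecorder()

	greet(rec, req)

	res := rec.Result()
	body, _ := io.ReadAll(res.Body)
	fmt.Println(res.StatusCode, req.TLS != nil, string(body))
	// Should print: 200 true Hello, Gopher
}

Because the target uses the https scheme, req.TLS carries a dummy connection state, which is why the second printed value is true.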
diff --git a/src/net/http/httptest/httptest_test.go b/src/net/http/httptest/httptest_test.go
new file mode 100644
index 0000000..071add6
--- /dev/null
+++ b/src/net/http/httptest/httptest_test.go
@@ -0,0 +1,179 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "crypto/tls"
+ "io"
+ "net/http"
+ "net/url"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestNewRequest(t *testing.T) {
+ for _, tt := range [...]struct {
+ name string
+
+ method, uri string
+ body io.Reader
+
+ want *http.Request
+ wantBody string
+ }{
+ {
+ name: "Empty method means GET",
+ method: "",
+ uri: "/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "GET with full URL",
+ method: "GET",
+ uri: "http://foo.com/path/%2f/bar/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "foo.com",
+ URL: &url.URL{
+ Scheme: "http",
+ Path: "/path///bar/",
+ RawPath: "/path/%2f/bar/",
+ Host: "foo.com",
+ },
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "http://foo.com/path/%2f/bar/",
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "GET with full https URL",
+ method: "GET",
+ uri: "https://foo.com/path/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "foo.com",
+ URL: &url.URL{
+ Scheme: "https",
+ Path: "/path/",
+ Host: "foo.com",
+ },
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "https://foo.com/path/",
+ TLS: &tls.ConnectionState{
+ Version: tls.VersionTLS12,
+ HandshakeComplete: true,
+ ServerName: "foo.com",
+ },
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "Post with known length",
+ method: "POST",
+ uri: "/",
+ body: strings.NewReader("foo"),
+ want: &http.Request{
+ Method: "POST",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ContentLength: 3,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "foo",
+ },
+
+ {
+ name: "Post with unknown length",
+ method: "POST",
+ uri: "/",
+ body: struct{ io.Reader }{strings.NewReader("foo")},
+ want: &http.Request{
+ Method: "POST",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ContentLength: -1,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "foo",
+ },
+
+ {
+ name: "OPTIONS *",
+ method: "OPTIONS",
+ uri: "*",
+ want: &http.Request{
+ Method: "OPTIONS",
+ Host: "example.com",
+ URL: &url.URL{Path: "*"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "*",
+ },
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewRequest(tt.method, tt.uri, tt.body)
+ slurp, err := io.ReadAll(got.Body)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ if string(slurp) != tt.wantBody {
+ t.Errorf("Body = %q; want %q", slurp, tt.wantBody)
+ }
+ got.Body = nil // before DeepEqual
+ if !reflect.DeepEqual(got.URL, tt.want.URL) {
+ t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL)
+ }
+ if !reflect.DeepEqual(got.Header, tt.want.Header) {
+ t.Errorf("Request.Header mismatch:\n got: %#v\nwant: %#v", got.Header, tt.want.Header)
+ }
+ if !reflect.DeepEqual(got.TLS, tt.want.TLS) {
+ t.Errorf("Request.TLS mismatch:\n got: %#v\nwant: %#v", got.TLS, tt.want.TLS)
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go
new file mode 100644
index 0000000..1c1d880
--- /dev/null
+++ b/src/net/http/httptest/recorder.go
@@ -0,0 +1,255 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/textproto"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+ // Code is the HTTP response code set by WriteHeader.
+ //
+ // Note that if a Handler never calls WriteHeader or Write,
+ // this might end up being 0, rather than the implicit
+ // http.StatusOK. To get the implicit value, use the Result
+ // method.
+ Code int
+
+ // HeaderMap contains the headers explicitly set by the Handler.
+ // It is an internal detail.
+ //
+ // Deprecated: HeaderMap exists for historical compatibility
+ // and should not be used. To access the headers returned by a handler,
+ // use the Response.Header map as returned by the Result method.
+ HeaderMap http.Header
+
+ // Body is the buffer to which the Handler's Write calls are sent.
+ // If nil, the Writes are silently discarded.
+ Body *bytes.Buffer
+
+ // Flushed is whether the Handler called Flush.
+ Flushed bool
+
+ result *http.Response // cache of Result's return value
+ snapHeader http.Header // snapshot of HeaderMap at first Write
+ wroteHeader bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+ return &ResponseRecorder{
+ HeaderMap: make(http.Header),
+ Body: new(bytes.Buffer),
+ Code: 200,
+ }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit remote address isn't otherwise set on the ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header implements http.ResponseWriter. It returns the response
+// headers to mutate within a handler. To test the headers that were
+// written after a handler completes, use the Result method and see
+// the returned Response value's Header.
+func (rw *ResponseRecorder) Header() http.Header {
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
+}
+
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// b or str is the beginning of the response body.
+// We pass both to avoid unnecessarily generating garbage
+// in rw.WriteString, which was created for performance reasons.
+// A non-nil b wins.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+ if rw.wroteHeader {
+ return
+ }
+ if len(str) > 512 {
+ str = str[:512]
+ }
+
+ m := rw.Header()
+
+ _, hasType := m["Content-Type"]
+ hasTE := m.Get("Transfer-Encoding") != ""
+ if !hasType && !hasTE {
+ if b == nil {
+ b = []byte(str)
+ }
+ m.Set("Content-Type", http.DetectContentType(b))
+ }
+
+ rw.WriteHeader(200)
+}
+
+// Write implements http.ResponseWriter. The data in buf is written to
+// rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ rw.writeHeader(buf, "")
+ if rw.Body != nil {
+ rw.Body.Write(buf)
+ }
+ return len(buf), nil
+}
+
+// WriteString implements io.StringWriter. The data in str is written
+// to rw.Body, if not nil.
+func (rw *ResponseRecorder) WriteString(str string) (int, error) {
+ rw.writeHeader(nil, str)
+ if rw.Body != nil {
+ rw.Body.WriteString(str)
+ }
+ return len(str), nil
+}
+
+func checkWriteHeaderCode(code int) {
+ // Issue 22880: require valid WriteHeader status codes.
+ // For now we only enforce that it's three digits.
+ // In the future we might block things over 599 (600 and above aren't defined
+ // at https://httpwg.org/specs/rfc7231.html#status.codes)
+ // and we might block under 200 (once we have more mature 1xx support).
+	// But for now, any three-digit code is accepted.
+ //
+ // We used to send "HTTP/1.1 000 0" on the wire in responses but there's
+ // no equivalent bogus thing we can realistically send in HTTP/2,
+ // so we'll consistently panic instead and help people find their bugs
+ // early. (We can't return an error from WriteHeader even if we wanted to.)
+ if code < 100 || code > 999 {
+ panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+ }
+}
+
+// WriteHeader implements http.ResponseWriter.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+ if rw.wroteHeader {
+ return
+ }
+
+ checkWriteHeaderCode(code)
+ rw.Code = code
+ rw.wroteHeader = true
+ if rw.HeaderMap == nil {
+ rw.HeaderMap = make(http.Header)
+ }
+ rw.snapHeader = rw.HeaderMap.Clone()
+}
+
+// Flush implements http.Flusher. To test whether Flush was
+// called, see rw.Flushed.
+func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ rw.Flushed = true
+}
+
+// Result returns the response generated by the handler.
+//
+// The returned Response will have at least its StatusCode,
+// Header, Body, and optionally Trailer populated.
+// More fields may be populated in the future, so callers should
+// not DeepEqual the result in tests.
+//
+// The Response.Header is a snapshot of the headers at the time of the
+// first write call, or at the time of this call, if the handler never
+// did a write.
+//
+// The Response.Body is guaranteed to be non-nil, and calls to Body.Read are
+// guaranteed not to return any error other than io.EOF.
+//
+// Result must only be called after the handler has finished running.
+func (rw *ResponseRecorder) Result() *http.Response {
+ if rw.result != nil {
+ return rw.result
+ }
+ if rw.snapHeader == nil {
+ rw.snapHeader = rw.HeaderMap.Clone()
+ }
+ res := &http.Response{
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ StatusCode: rw.Code,
+ Header: rw.snapHeader,
+ }
+ rw.result = res
+ if res.StatusCode == 0 {
+ res.StatusCode = 200
+ }
+ res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
+ if rw.Body != nil {
+ res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+ } else {
+ res.Body = http.NoBody
+ }
+ res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
+
+ if trailers, ok := rw.snapHeader["Trailer"]; ok {
+ res.Trailer = make(http.Header, len(trailers))
+ for _, k := range trailers {
+ for _, k := range strings.Split(k, ",") {
+ k = http.CanonicalHeaderKey(textproto.TrimString(k))
+ if !httpguts.ValidTrailerHeader(k) {
+ // Ignore since forbidden by RFC 7230, section 4.1.2.
+ continue
+ }
+ vv, ok := rw.HeaderMap[k]
+ if !ok {
+ continue
+ }
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ res.Trailer[k] = vv2
+ }
+ }
+ }
+ for k, vv := range rw.HeaderMap {
+ if !strings.HasPrefix(k, http.TrailerPrefix) {
+ continue
+ }
+ if res.Trailer == nil {
+ res.Trailer = make(http.Header)
+ }
+ for _, v := range vv {
+ res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
+ }
+ }
+ return res
+}
+
+// parseContentLength trims whitespace from cl and returns -1 if no value
+// is set or the value is invalid, or the parsed value if it's >= 0.
+//
+// This is a modified version of the same function found in net/http/transfer.go.
+// This one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+ cl = textproto.TrimString(cl)
+ if cl == "" {
+ return -1
+ }
+ n, err := strconv.ParseUint(cl, 10, 63)
+ if err != nil {
+ return -1
+ }
+ return int64(n)
+}
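As the Result documentation above stresses, tests should read a handler's output through Result rather than poking at Code, HeaderMap, or Body directly. Below is a short sketch under that guidance, again not part of the patch; the handler, the header names, and the values are invented for illustration:

package httptest_test

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// sketchRecorderResult records a handler's response and inspects it through
// Result, including a trailer declared before the first write.
func sketchRecorderResult() {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Trailer", "X-Checksum") // declare the trailer before the first write
		io.WriteString(w, "queued")             // implies a 200 status and Content-Type detection
		w.Header().Set("X-Checksum", "abc123")  // set after the body; surfaces in Result().Trailer
	})

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest("POST", "/jobs", nil))

	res := rec.Result()
	body, _ := io.ReadAll(res.Body)
	fmt.Println(res.StatusCode, res.Header.Get("Content-Type"), res.Trailer.Get("X-Checksum"), string(body))
	// Should print: 200 text/plain; charset=utf-8 abc123 queued
}

Declaring X-Checksum in the Trailer header before the first write is what lets Result report it in Response.Trailer rather than dropping it.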
diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go
new file mode 100644
index 0000000..4782ece
--- /dev/null
+++ b/src/net/http/httptest/recorder_test.go
@@ -0,0 +1,371 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasResultStatus := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Result().Status != want {
+ return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
+ }
+ return nil
+ }
+ }
+ hasResultStatusCode := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Result().StatusCode != wantCode {
+ return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
+ }
+ return nil
+ }
+ }
+ hasResultContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ contentBytes, err := io.ReadAll(rec.Result().Body)
+ if err != nil {
+ return err
+ }
+ contents := string(contentBytes)
+ if contents != want {
+ return fmt.Errorf("Result().Body = %s; want %s", contents, want)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+ hasOldHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.HeaderMap.Get(key); got != want {
+ return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().Header.Get(key); got != want {
+ return fmt.Errorf("final header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasNotHeaders := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ for _, k := range keys {
+ v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected header %s with value %q", k, v)
+ }
+ }
+ return nil
+ }
+ }
+ hasTrailer := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().Trailer.Get(key); got != want {
+ return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasNotTrailers := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ trailers := rec.Result().Trailer
+ for _, k := range keys {
+ _, ok := trailers[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected trailer %s", k)
+ }
+ }
+ return nil
+ }
+ }
+ hasContentLength := func(length int64) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().ContentLength; got != length {
+ return fmt.Errorf("ContentLength = %d; want %d", got, length)
+ }
+ return nil
+ }
+ }
+
+ for _, tt := range [...]struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "write string",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "hi first")
+ },
+ check(
+ hasStatus(200),
+ hasContents("hi first"),
+ hasFlush(false),
+ hasHeader("Content-Type", "text/plain; charset=utf-8"),
+ ),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
+ },
+ {
+ "Content-Type detection",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "no Content-Type detection with Transfer-Encoding",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Transfer-Encoding", "some encoding")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "")), // no header
+ },
+ {
+ "no Content-Type detection if set explicitly",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "some/type")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "some/type")),
+ },
+ {
+ "Content-Type detection doesn't crash if HeaderMap is nil",
+ func(w http.ResponseWriter, r *http.Request) {
+ // Act as if the user wrote new(httptest.ResponseRecorder)
+ // rather than using NewRecorder (which initializes
+ // HeaderMap)
+ w.(*ResponseRecorder).HeaderMap = nil
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "Header is not changed after write",
+ func(w http.ResponseWriter, r *http.Request) {
+ hdr := w.Header()
+ hdr.Set("Key", "correct")
+ w.WriteHeader(200)
+ hdr.Set("Key", "incorrect")
+ },
+ check(hasHeader("Key", "correct")),
+ },
+ {
+ "Trailer headers are correctly recorded",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Non-Trailer", "correct")
+ w.Header().Set("Trailer", "Trailer-A, Trailer-B")
+ w.Header().Add("Trailer", "Trailer-C")
+ io.WriteString(w, "<html>")
+ w.Header().Set("Non-Trailer", "incorrect")
+ w.Header().Set("Trailer-A", "valuea")
+ w.Header().Set("Trailer-C", "valuec")
+ w.Header().Set("Trailer-NotDeclared", "should be omitted")
+ w.Header().Set("Trailer:Trailer-D", "with prefix")
+ },
+ check(
+ hasStatus(200),
+ hasHeader("Content-Type", "text/html; charset=utf-8"),
+ hasHeader("Non-Trailer", "correct"),
+ hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-A", "valuea"),
+ hasTrailer("Trailer-C", "valuec"),
+ hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-D", "with prefix"),
+ ),
+ },
+ {
+ "Header set without any write", // Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Foo", "1")
+
+ // Simulate somebody using
+ // new(ResponseRecorder) instead of
+ // using the constructor which sets
+ // this to 200
+ w.(*ResponseRecorder).Code = 0
+ },
+ check(
+ hasOldHeader("X-Foo", "1"),
+ hasStatus(0),
+ hasHeader("X-Foo", "1"),
+ hasResultStatus("200 OK"),
+ hasResultStatusCode(200),
+ ),
+ },
+ {
+ "HeaderMap vs FinalHeaders", // more for Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Set("X-Foo", "1")
+ w.Write([]byte("hi"))
+ h.Set("X-Foo", "2")
+ h.Set("X-Bar", "2")
+ },
+ check(
+ hasOldHeader("X-Foo", "2"),
+ hasOldHeader("X-Bar", "2"),
+ hasHeader("X-Foo", "1"),
+ hasNotHeaders("X-Bar"),
+ ),
+ },
+ {
+ "setting Content-Length header",
+ func(w http.ResponseWriter, r *http.Request) {
+ body := "Some body"
+ contentLength := fmt.Sprintf("%d", len(body))
+ w.Header().Set("Content-Length", contentLength)
+ io.WriteString(w, body)
+ },
+ check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
+ },
+ {
+ "nil ResponseRecorder.Body", // Issue 26642
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(*ResponseRecorder).Body = nil
+ io.WriteString(w, "hi")
+ },
+ check(hasResultContents("")), // check we don't crash reading the body
+
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+// Issue 39017: disallow Content-Length values such as "+3".
+func TestParseContentLength(t *testing.T) {
+ tests := []struct {
+ cl string
+ want int64
+ }{
+ {
+ cl: "3",
+ want: 3,
+ },
+ {
+ cl: "+3",
+ want: -1,
+ },
+ {
+ cl: "-3",
+ want: -1,
+ },
+ {
+ // max int64, for safe conversion before returning
+ cl: "9223372036854775807",
+ want: 9223372036854775807,
+ },
+ {
+ cl: "9223372036854775808",
+ want: -1,
+ },
+ }
+
+ for _, tt := range tests {
+ if got := parseContentLength(tt.cl); got != tt.want {
+ t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
+ }
+ }
+}
+
+// Ensure that the ResponseRecorder panics when given an HTTP status code
+// that is not three digits (XXX). See https://golang.org/issues/45353
+func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
+ badCodes := []int{
+ -100, 0, 99, 1000, 20000,
+ }
+ for _, badCode := range badCodes {
+ badCode := badCode
+ t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("Expected a panic")
+ }
+ }()
+
+ handler := func(rw http.ResponseWriter, _ *http.Request) {
+ rw.WriteHeader(badCode)
+ }
+ r, _ := http.NewRequest("GET", "http://example.org/", nil)
+ rw := NewRecorder()
+ handler(rw, r)
+ })
+ }
+}
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
new file mode 100644
index 0000000..f254a49
--- /dev/null
+++ b/src/net/http/httptest/server.go
@@ -0,0 +1,385 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "net/http/internal/testcert"
+ "os"
+ "strings"
+ "sync"
+ "time"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+ URL string // base URL of form http://ipaddr:port with no trailing slash
+ Listener net.Listener
+
+ // EnableHTTP2 controls whether HTTP/2 is enabled
+ // on the server. It must be set between calling
+ // NewUnstartedServer and calling Server.StartTLS.
+ EnableHTTP2 bool
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
+
+ // Config may be changed after calling NewUnstartedServer and
+ // before Start or StartTLS.
+ Config *http.Server
+
+ // certificate is a parsed version of the TLS config certificate, if present.
+ certificate *x509.Certificate
+
+ // wg counts the number of outstanding HTTP requests on this server.
+ // Close blocks until all requests are finished.
+ wg sync.WaitGroup
+
+ mu sync.Mutex // guards closed and conns
+ closed bool
+ conns map[net.Conn]http.ConnState // except terminal states
+
+ // client is configured for use with the server.
+ // Its transport is automatically closed when Close is called.
+ client *http.Client
+}
+
+func newLocalListener() net.Listener {
+ if serveFlag != "" {
+ l, err := net.Listen("tcp", serveFlag)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
+ }
+ return l
+ }
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+ }
+ }
+ return l
+}
+
+// When debugging a particular http server-based test,
+// this flag lets you run
+//
+// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
+//
+// to start the broken server so you can interact with it manually.
+// We only register this flag if it looks like the caller knows about it
+// and is trying to use it, since we don't want to pollute the flag set and
+// this isn't really part of our API. Don't depend on this.
+var serveFlag string
+
+func init() {
+ if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
+ flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
+ }
+}
+
+func strSliceContainsPrefix(v []string, pre string) bool {
+ for _, s := range v {
+ if strings.HasPrefix(s, pre) {
+ return true
+ }
+ }
+ return false
+}
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.Start()
+ return ts
+}
+
+// NewUnstartedServer returns a new Server but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedServer(handler http.Handler) *Server {
+ return &Server{
+ Listener: newLocalListener(),
+ Config: &http.Server{Handler: handler},
+ }
+}
+
+// Start starts a server from NewUnstartedServer.
+func (s *Server) Start() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ if s.client == nil {
+ s.client = &http.Client{Transport: &http.Transport{}}
+ }
+ s.URL = "http://" + s.Listener.Addr().String()
+ s.wrap()
+ s.goServe()
+ if serveFlag != "" {
+ fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
+ select {}
+ }
+}
+
+// StartTLS starts TLS on a server from NewUnstartedServer.
+func (s *Server) StartTLS() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ if s.client == nil {
+ s.client = &http.Client{Transport: &http.Transport{}}
+ }
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+
+ existingConfig := s.TLS
+ if existingConfig != nil {
+ s.TLS = existingConfig.Clone()
+ } else {
+ s.TLS = new(tls.Config)
+ }
+ if s.TLS.NextProtos == nil {
+ nextProtos := []string{"http/1.1"}
+ if s.EnableHTTP2 {
+ nextProtos = []string{"h2"}
+ }
+ s.TLS.NextProtos = nextProtos
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
+ }
+ s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+ certpool := x509.NewCertPool()
+ certpool.AddCert(s.certificate)
+ s.client.Transport = &http.Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certpool,
+ },
+ ForceAttemptHTTP2: s.EnableHTTP2,
+ }
+ s.Listener = tls.NewListener(s.Listener, s.TLS)
+ s.URL = "https://" + s.Listener.Addr().String()
+ s.wrap()
+ s.goServe()
+}
+
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.StartTLS()
+ return ts
+}
+
+type closeIdleTransport interface {
+ CloseIdleConnections()
+}
+
+// Close shuts down the server and blocks until all outstanding
+// requests on this server have completed.
+func (s *Server) Close() {
+ s.mu.Lock()
+ if !s.closed {
+ s.closed = true
+ s.Listener.Close()
+ s.Config.SetKeepAlivesEnabled(false)
+ for c, st := range s.conns {
+ // Force-close any idle connections (those between
+ // requests) and new connections (those which connected
+ // but never sent a request). StateNew connections are
+ // super rare and have only been seen (in
+ // previously-flaky tests) in the case of
+ // socket-late-binding races from the http Client
+ // dialing this server and then getting an idle
+ // connection before the dial completed. There is thus
+ // a connected connection in StateNew with no
+ // associated Request. We only close StateIdle and
+ // StateNew because they're not doing anything. It's
+ // possible StateNew is about to do something in a few
+ // milliseconds, but a previous CL to check again in a
+ // few milliseconds wasn't liked (early versions of
+ // https://golang.org/cl/15151) so now we just
+ // forcefully close StateNew. The docs for Server.Close say
+ // we wait for "outstanding requests", so we don't close things
+ // in StateActive.
+ if st == http.StateIdle || st == http.StateNew {
+ s.closeConn(c)
+ }
+ }
+ // If this server doesn't shut down in 5 seconds, tell the user why.
+ t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
+ defer t.Stop()
+ }
+ s.mu.Unlock()
+
+ // Not part of httptest.Server's correctness, but assume most
+ // users of httptest.Server will be using the standard
+ // transport, so help them out and close any idle connections for them.
+ if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
+ t.CloseIdleConnections()
+ }
+
+ // Also close the client idle connections.
+ if s.client != nil {
+ if t, ok := s.client.Transport.(closeIdleTransport); ok {
+ t.CloseIdleConnections()
+ }
+ }
+
+ s.wg.Wait()
+}
+
+func (s *Server) logCloseHangDebugInfo() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var buf strings.Builder
+ buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
+ for c, st := range s.conns {
+ fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
+ }
+ log.Print(buf.String())
+}
+
+// CloseClientConnections closes any open HTTP connections to the test Server.
+func (s *Server) CloseClientConnections() {
+ s.mu.Lock()
+ nconn := len(s.conns)
+ ch := make(chan struct{}, nconn)
+ for c := range s.conns {
+ go s.closeConnChan(c, ch)
+ }
+ s.mu.Unlock()
+
+ // Wait for outstanding closes to finish.
+ //
+ // Out of paranoia for making a late change in Go 1.6, we
+ // bound how long this can wait, since golang.org/issue/14291
+ // isn't fully understood yet. At least this should only be used
+ // in tests.
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ for i := 0; i < nconn; i++ {
+ select {
+ case <-ch:
+ case <-timer.C:
+ // Too slow. Give up.
+ return
+ }
+ }
+}
+
+// Certificate returns the certificate used by the server, or nil if
+// the server doesn't use TLS.
+func (s *Server) Certificate() *x509.Certificate {
+ return s.certificate
+}
+
+// Client returns an HTTP client configured for making requests to the server.
+// It is configured to trust the server's TLS test certificate and will
+// close its idle connections on Server.Close.
+func (s *Server) Client() *http.Client {
+ return s.client
+}
+
+func (s *Server) goServe() {
+ s.wg.Add(1)
+ go func() {
+ defer s.wg.Done()
+ s.Config.Serve(s.Listener)
+ }()
+}
+
+// wrap installs the connection state-tracking hook to know which
+// connections are idle.
+func (s *Server) wrap() {
+ oldHook := s.Config.ConnState
+ s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ switch cs {
+ case http.StateNew:
+ if _, exists := s.conns[c]; exists {
+ panic("invalid state transition")
+ }
+ if s.conns == nil {
+ s.conns = make(map[net.Conn]http.ConnState)
+ }
+			// Add c to the set of tracked conns and increment the
+			// waitgroup.
+ s.wg.Add(1)
+ s.conns[c] = cs
+ if s.closed {
+ // Probably just a socket-late-binding dial from
+ // the default transport that lost the race (and
+ // thus this connection is now idle and will
+ // never be used).
+ s.closeConn(c)
+ }
+ case http.StateActive:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateNew && oldState != http.StateIdle {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ case http.StateIdle:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateActive {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ if s.closed {
+ s.closeConn(c)
+ }
+ case http.StateHijacked, http.StateClosed:
+			// Remove c from the set of tracked conns and decrement the
+			// waitgroup, unless c was previously removed.
+ if _, ok := s.conns[c]; ok {
+ delete(s.conns, c)
+ // Keep Close from returning until the user's ConnState hook
+ // (if any) finishes.
+ defer s.wg.Done()
+ }
+ }
+ if oldHook != nil {
+ oldHook(c, cs)
+ }
+ }
+}
+
+// closeConn closes c.
+// s.mu must be held.
+func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
+
+// closeConnChan is like closeConn, but takes an optional channel to receive a value
+// when the goroutine closing c is done.
+func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
+ c.Close()
+ if done != nil {
+ done <- struct{}{}
+ }
+}
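The TLS and Config field comments above describe a two-step pattern: create an unstarted server, adjust its configuration, then call Start or StartTLS. The sketch below is a hedged illustration of that pattern, not part of the patch; the timeout, the MinVersion choice, and the handler are arbitrary and only show where such settings go:

package httptest_test

import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/http/httptest"
	"time"
)

// sketchUnstartedTLS customizes both Config and TLS on an unstarted server
// before StartTLS, then talks to it with ts.Client(), which already trusts
// the server's test certificate.
func sketchUnstartedTLS() {
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "TLS version %x", r.TLS.Version)
	}))
	ts.Config.ReadHeaderTimeout = 5 * time.Second      // tweak the underlying *http.Server
	ts.TLS = &tls.Config{MinVersion: tls.VersionTLS12} // copied into the config built by StartTLS
	ts.StartTLS()
	defer ts.Close()

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		log.Fatal(err)
	}
	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", body)
}

StartTLS clones the caller-supplied tls.Config before adding the test certificate and the ALPN protocol list, so the MinVersion constraint survives into the config the listener actually uses.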
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
new file mode 100644
index 0000000..5313f65
--- /dev/null
+++ b/src/net/http/httptest/server_test.go
@@ -0,0 +1,294 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "bufio"
+ "io"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+)
+
+type newServerFunc func(http.Handler) *Server
+
+var newServers = map[string]newServerFunc{
+ "NewServer": NewServer,
+ "NewTLSServer": NewTLSServer,
+
+ // The manual variants of newServer create a Server manually by only filling
+ // in the exported fields of Server.
+ "NewServerManual": func(h http.Handler) *Server {
+ ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
+ ts.Start()
+ return ts
+ },
+ "NewTLSServerManual": func(h http.Handler) *Server {
+ ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
+ ts.StartTLS()
+ return ts
+ },
+}
+
+func TestServer(t *testing.T) {
+ for _, name := range []string{"NewServer", "NewServerManual"} {
+ t.Run(name, func(t *testing.T) {
+ newServer := newServers[name]
+ t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
+ t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
+ t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
+ t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
+ t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
+ })
+ }
+ for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
+ t.Run(name, func(t *testing.T) {
+ newServer := newServers[name]
+ t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
+ t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
+ })
+ }
+}
+
+func testServer(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}
+
+// Issue 12781
+func testGetAfterClose(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Fatalf("got %q, want hello", string(got))
+ }
+
+ ts.Close()
+
+ res, err = http.Get(ts.URL)
+ if err == nil {
+ body, _ := io.ReadAll(res.Body)
+ t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
+ }
+}
+
+func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ dial := func() net.Conn {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c
+ }
+
+ // Keep one connection in StateNew (connected, but not sending anything)
+ cnew := dial()
+ defer cnew.Close()
+
+ // Keep one connection in StateIdle (idle after a request)
+ cidle := dial()
+ defer cidle.Close()
+ cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ts.Close() // test we don't hang here forever.
+}
+
+// Issue 14290
+func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
+ var s *Server
+ s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s.CloseClientConnections()
+ }))
+ defer s.Close()
+ res, err := http.Get(s.URL)
+ if err == nil {
+ res.Body.Close()
+ t.Fatalf("Unexpected response: %#v", res)
+ }
+}
+
+// Tests that the Server.Client method works and returns an http.Client that can hit
+// NewTLSServer without cert warnings.
+func testServerClient(t *testing.T, newTLSServer newServerFunc) {
+ ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}
+
+// Tests that the Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ if _, ok := client.Transport.(*http.Transport); !ok {
+ t.Errorf("got %T, want *http.Transport", client.Transport)
+ }
+}
+
+// Tests that the TLS Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
+ ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ if _, ok := client.Transport.(*http.Transport); !ok {
+ t.Errorf("got %T, want *http.Transport", client.Transport)
+ }
+}
+
+type onlyCloseListener struct {
+ net.Listener
+}
+
+func (onlyCloseListener) Close() error { return nil }
+
+// Issue 19729: panic in Server.Close for values created directly
+// without a constructor (so the unexported client field is nil).
+func TestServerZeroValueClose(t *testing.T) {
+ ts := &Server{
+ Listener: onlyCloseListener{},
+ Config: &http.Server{},
+ }
+
+ ts.Close() // tests that it doesn't panic
+}
+
+// Issue 51799: test hijacking a connection and then closing it
+// concurrently with closing the server.
+func TestCloseHijackedConnection(t *testing.T) {
+ hijacked := make(chan net.Conn)
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer close(hijacked)
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ t.Fatal("failed to hijack")
+ }
+ c, _, err := hj.Hijack()
+ if err != nil {
+ t.Fatal(err)
+ }
+ hijacked <- c
+ }))
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, err := http.NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Log(err)
+ }
+ // Use a client not associated with the Server.
+ var c http.Client
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Log(err)
+ return
+ }
+ resp.Body.Close()
+ }()
+
+ wg.Add(1)
+ conn := <-hijacked
+ go func(conn net.Conn) {
+ defer wg.Done()
+ // Close the connection and then inform the Server that
+ // we closed it.
+ conn.Close()
+ ts.Config.ConnState(conn, http.StateClosed)
+ }(conn)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ ts.Close()
+ }()
+ wg.Wait()
+}
+
+func TestTLSServerWithHTTP2(t *testing.T) {
+ modes := []struct {
+ name string
+ wantProto string
+ }{
+ {"http1", "HTTP/1.1"},
+ {"http2", "HTTP/2.0"},
+ }
+
+ for _, tt := range modes {
+ t.Run(tt.name, func(t *testing.T) {
+ cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Proto", r.Proto)
+ }))
+
+ switch tt.name {
+ case "http2":
+ cst.EnableHTTP2 = true
+ cst.StartTLS()
+ default:
+ cst.Start()
+ }
+
+ defer cst.Close()
+
+ res, err := cst.Client().Get(cst.URL)
+ if err != nil {
+ t.Fatalf("Failed to make request: %v", err)
+ }
+ if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
+ t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
+ }
+ })
+ }
+}
diff --git a/src/net/http/httptrace/example_test.go b/src/net/http/httptrace/example_test.go
new file mode 100644
index 0000000..07fdc0a
--- /dev/null
+++ b/src/net/http/httptrace/example_test.go
@@ -0,0 +1,29 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptrace_test
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptrace"
+)
+
+func Example() {
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ fmt.Printf("Got Conn: %+v\n", connInfo)
+ },
+ DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
+ fmt.Printf("DNS Info: %+v\n", dnsInfo)
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ _, err := http.DefaultTransport.RoundTrip(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go
new file mode 100644
index 0000000..6af30f7
--- /dev/null
+++ b/src/net/http/httptrace/trace.go
@@ -0,0 +1,255 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptrace provides mechanisms to trace the events within
+// HTTP client requests.
+package httptrace
+
+import (
+ "context"
+ "crypto/tls"
+ "internal/nettrace"
+ "net"
+ "net/textproto"
+ "reflect"
+ "time"
+)
+
+// clientEventContextKey is a unique context key type, preventing collisions with keys from other packages.
+type clientEventContextKey struct{}
+
+// ContextClientTrace returns the ClientTrace associated with the
+// provided context. If none, it returns nil.
+func ContextClientTrace(ctx context.Context) *ClientTrace {
+ trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace)
+ return trace
+}
+
+// WithClientTrace returns a new context based on the provided parent
+// ctx. HTTP client requests made with the returned context will use
+// the provided trace hooks, in addition to any previous hooks
+// registered with ctx. Any hooks defined in the provided trace will
+// be called first.
+func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
+ if trace == nil {
+ panic("nil trace")
+ }
+ old := ContextClientTrace(ctx)
+ trace.compose(old)
+
+ ctx = context.WithValue(ctx, clientEventContextKey{}, trace)
+ if trace.hasNetHooks() {
+ nt := &nettrace.Trace{
+ ConnectStart: trace.ConnectStart,
+ ConnectDone: trace.ConnectDone,
+ }
+ if trace.DNSStart != nil {
+ nt.DNSStart = func(name string) {
+ trace.DNSStart(DNSStartInfo{Host: name})
+ }
+ }
+ if trace.DNSDone != nil {
+ nt.DNSDone = func(netIPs []any, coalesced bool, err error) {
+ addrs := make([]net.IPAddr, len(netIPs))
+ for i, ip := range netIPs {
+ addrs[i] = ip.(net.IPAddr)
+ }
+ trace.DNSDone(DNSDoneInfo{
+ Addrs: addrs,
+ Coalesced: coalesced,
+ Err: err,
+ })
+ }
+ }
+ ctx = context.WithValue(ctx, nettrace.TraceKey{}, nt)
+ }
+ return ctx
+}
+
+// ClientTrace is a set of hooks to run at various stages of an outgoing
+// HTTP request. Any particular hook may be nil. Functions may be
+// called concurrently from different goroutines and some may be called
+// after the request has completed or failed.
+//
+// ClientTrace currently traces a single HTTP request & response
+// during a single round trip and has no hooks that span a series
+// of redirected requests.
+//
+// See https://blog.golang.org/http-tracing for more.
+type ClientTrace struct {
+ // GetConn is called before a connection is created or
+ // retrieved from an idle pool. The hostPort is the
+ // "host:port" of the target or proxy. GetConn is called even
+ // if there's already an idle cached connection available.
+ GetConn func(hostPort string)
+
+ // GotConn is called after a successful connection is
+ // obtained. There is no hook for failure to obtain a
+ // connection; instead, use the error from
+ // Transport.RoundTrip.
+ GotConn func(GotConnInfo)
+
+ // PutIdleConn is called when the connection is returned to
+ // the idle pool. If err is nil, the connection was
+ // successfully returned to the idle pool. If err is non-nil,
+ // it describes why not. PutIdleConn is not called if
+ // connection reuse is disabled via Transport.DisableKeepAlives.
+ // PutIdleConn is called before the caller's Response.Body.Close
+ // call returns.
+ // For HTTP/2, this hook is not currently used.
+ PutIdleConn func(err error)
+
+ // GotFirstResponseByte is called when the first byte of the response
+ // headers is available.
+ GotFirstResponseByte func()
+
+ // Got100Continue is called if the server replies with a "100
+ // Continue" response.
+ Got100Continue func()
+
+ // Got1xxResponse is called for each 1xx informational response header
+ // returned before the final non-1xx response. Got1xxResponse is called
+ // for "100 Continue" responses, even if Got100Continue is also defined.
+ // If it returns an error, the client request is aborted with that error value.
+ Got1xxResponse func(code int, header textproto.MIMEHeader) error
+
+ // DNSStart is called when a DNS lookup begins.
+ DNSStart func(DNSStartInfo)
+
+ // DNSDone is called when a DNS lookup ends.
+ DNSDone func(DNSDoneInfo)
+
+ // ConnectStart is called when a new connection's Dial begins.
+ // If net.Dialer.DualStack (IPv6 "Happy Eyeballs") support is
+ // enabled, this may be called multiple times.
+ ConnectStart func(network, addr string)
+
+ // ConnectDone is called when a new connection's Dial
+ // completes. The provided err indicates whether the
+ // connection completed successfully.
+ // If net.Dialer.DualStack ("Happy Eyeballs") support is
+ // enabled, this may be called multiple times.
+ ConnectDone func(network, addr string, err error)
+
+ // TLSHandshakeStart is called when the TLS handshake is started. When
+ // connecting to an HTTPS site via an HTTP proxy, the handshake happens
+ // after the CONNECT request is processed by the proxy.
+ TLSHandshakeStart func()
+
+ // TLSHandshakeDone is called after the TLS handshake with either the
+ // successful handshake's connection state, or a non-nil error on handshake
+ // failure.
+ TLSHandshakeDone func(tls.ConnectionState, error)
+
+ // WroteHeaderField is called after the Transport has written
+ // each request header. At the time of this call the values
+ // might be buffered and not yet written to the network.
+ WroteHeaderField func(key string, value []string)
+
+ // WroteHeaders is called after the Transport has written
+ // all request headers.
+ WroteHeaders func()
+
+ // Wait100Continue is called if the Request specified
+ // "Expect: 100-continue" and the Transport has written the
+ // request headers but is waiting for "100 Continue" from the
+ // server before writing the request body.
+ Wait100Continue func()
+
+ // WroteRequest is called with the result of writing the
+ // request and any body. It may be called multiple times
+ // in the case of retried requests.
+ WroteRequest func(WroteRequestInfo)
+}
+
+// WroteRequestInfo contains information provided to the WroteRequest
+// hook.
+type WroteRequestInfo struct {
+ // Err is any error encountered while writing the Request.
+ Err error
+}
+
+// compose modifies t so that it respects the previously-registered hooks in old:
+// for each hook, t's function (if non-nil) is called first, followed by old's.
+func (t *ClientTrace) compose(old *ClientTrace) {
+ if old == nil {
+ return
+ }
+ tv := reflect.ValueOf(t).Elem()
+ ov := reflect.ValueOf(old).Elem()
+ structType := tv.Type()
+ for i := 0; i < structType.NumField(); i++ {
+ tf := tv.Field(i)
+ hookType := tf.Type()
+ if hookType.Kind() != reflect.Func {
+ continue
+ }
+ of := ov.Field(i)
+ if of.IsNil() {
+ continue
+ }
+ if tf.IsNil() {
+ tf.Set(of)
+ continue
+ }
+
+		// Make a copy of tf for the new function to call. (Otherwise it
+		// would create a recursive call cycle and overflow the stack.)
+ tfCopy := reflect.ValueOf(tf.Interface())
+
+ // We need to call both tf and of in some order.
+ newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value {
+ tfCopy.Call(args)
+ return of.Call(args)
+ })
+ tv.Field(i).Set(newFunc)
+ }
+}
+
+// DNSStartInfo contains information about a DNS request.
+type DNSStartInfo struct {
+ Host string
+}
+
+// DNSDoneInfo contains information about the results of a DNS lookup.
+type DNSDoneInfo struct {
+ // Addrs are the IPv4 and/or IPv6 addresses found in the DNS
+ // lookup. The contents of the slice should not be mutated.
+ Addrs []net.IPAddr
+
+ // Err is any error that occurred during the DNS lookup.
+ Err error
+
+ // Coalesced is whether the Addrs were shared with another
+ // caller who was doing the same DNS lookup concurrently.
+ Coalesced bool
+}
+
+func (t *ClientTrace) hasNetHooks() bool {
+ if t == nil {
+ return false
+ }
+ return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil
+}
+
+// GotConnInfo is the argument to the ClientTrace.GotConn function and
+// contains information about the obtained connection.
+type GotConnInfo struct {
+ // Conn is the connection that was obtained. It is owned by
+ // the http.Transport and should not be read, written or
+ // closed by users of ClientTrace.
+ Conn net.Conn
+
+ // Reused is whether this connection has been previously
+ // used for another HTTP request.
+ Reused bool
+
+ // WasIdle is whether this connection was obtained from an
+ // idle pool.
+ WasIdle bool
+
+ // IdleTime reports how long the connection was previously
+ // idle, if WasIdle is true.
+ IdleTime time.Duration
+}
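The hook set above lends itself to per-phase latency measurement. The sketch below is illustrative rather than normative; it assumes a reachable example.com URL and uses only hooks documented in ClientTrace, with placeholder variable names:

package httptrace_test

import (
	"crypto/tls"
	"fmt"
	"log"
	"net/http"
	"net/http/httptrace"
	"time"
)

// sketchPhaseTimings times the DNS, connect, TLS handshake, and first-byte
// phases of a single request using ClientTrace hooks.
func sketchPhaseTimings() {
	var dnsStart, connStart, tlsStart, start time.Time

	trace := &httptrace.ClientTrace{
		DNSStart:     func(_ httptrace.DNSStartInfo) { dnsStart = time.Now() },
		DNSDone:      func(_ httptrace.DNSDoneInfo) { fmt.Println("dns:", time.Since(dnsStart)) },
		ConnectStart: func(_, _ string) { connStart = time.Now() },
		ConnectDone: func(_, _ string, err error) {
			fmt.Println("connect:", time.Since(connStart), "err:", err)
		},
		TLSHandshakeStart: func() { tlsStart = time.Now() },
		TLSHandshakeDone: func(_ tls.ConnectionState, err error) {
			fmt.Println("tls:", time.Since(tlsStart), "err:", err)
		},
		GotFirstResponseByte: func() { fmt.Println("first byte:", time.Since(start)) },
	}

	req, err := http.NewRequest("GET", "https://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	start = time.Now()
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	res.Body.Close()
}

Because ConnectStart and ConnectDone may fire more than once when dual-stack dialing is enabled, a production version would key the timestamps by network and address rather than using single variables.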
diff --git a/src/net/http/httptrace/trace_test.go b/src/net/http/httptrace/trace_test.go
new file mode 100644
index 0000000..6efa1f7
--- /dev/null
+++ b/src/net/http/httptrace/trace_test.go
@@ -0,0 +1,89 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptrace
+
+import (
+ "context"
+ "strings"
+ "testing"
+)
+
+func TestWithClientTrace(t *testing.T) {
+ var buf strings.Builder
+ connectStart := func(b byte) func(network, addr string) {
+ return func(network, addr string) {
+ buf.WriteByte(b)
+ }
+ }
+
+ ctx := context.Background()
+ oldtrace := &ClientTrace{
+ ConnectStart: connectStart('O'),
+ }
+ ctx = WithClientTrace(ctx, oldtrace)
+ newtrace := &ClientTrace{
+ ConnectStart: connectStart('N'),
+ }
+ ctx = WithClientTrace(ctx, newtrace)
+ trace := ContextClientTrace(ctx)
+
+ buf.Reset()
+ trace.ConnectStart("net", "addr")
+ if got, want := buf.String(), "NO"; got != want {
+ t.Errorf("got %q; want %q", got, want)
+ }
+}
+
+func TestCompose(t *testing.T) {
+ var buf strings.Builder
+ var testNum int
+
+ connectStart := func(b byte) func(network, addr string) {
+ return func(network, addr string) {
+ if addr != "addr" {
+ t.Errorf(`%d. args for %q case = %q, %q; want addr of "addr"`, testNum, b, network, addr)
+ }
+ buf.WriteByte(b)
+ }
+ }
+
+ tests := [...]struct {
+ trace, old *ClientTrace
+ want string
+ }{
+ 0: {
+ want: "T",
+ trace: &ClientTrace{
+ ConnectStart: connectStart('T'),
+ },
+ },
+ 1: {
+ want: "TO",
+ trace: &ClientTrace{
+ ConnectStart: connectStart('T'),
+ },
+ old: &ClientTrace{ConnectStart: connectStart('O')},
+ },
+ 2: {
+ want: "O",
+ trace: &ClientTrace{},
+ old: &ClientTrace{ConnectStart: connectStart('O')},
+ },
+ }
+ for i, tt := range tests {
+ testNum = i
+ buf.Reset()
+
+ tr := *tt.trace
+ tr.compose(tt.old)
+ if tr.ConnectStart != nil {
+ tr.ConnectStart("net", "addr")
+ }
+ if got := buf.String(); got != tt.want {
+ t.Errorf("%d. got = %q; want %q", i, got, tt.want)
+ }
+ }
+
+}
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go
new file mode 100644
index 0000000..7affe5e
--- /dev/null
+++ b/src/net/http/httputil/dump.go
@@ -0,0 +1,337 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+// drainBody reads all of b to memory and then returns two equivalent
+// ReadClosers yielding the same bytes.
+//
+// It returns an error if the initial slurp of all bytes fails. It does not attempt
+// to make the returned ReadClosers have identical error-matching behavior.
+func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
+ if b == nil || b == http.NoBody {
+ // No copying needed. Preserve the magic sentinel meaning of NoBody.
+ return http.NoBody, http.NoBody, nil
+ }
+ var buf bytes.Buffer
+ if _, err = buf.ReadFrom(b); err != nil {
+ return nil, b, err
+ }
+ if err = b.Close(); err != nil {
+ return nil, b, err
+ }
+ return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil
+}
+
+// dumpConn is a net.Conn that writes to Writer and reads from Reader.
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// outgoingLength is a copy of the unexported
+// (*http.Request).outgoingLength method.
+func outgoingLength(req *http.Request) int64 {
+ if req.Body == nil || req.Body == http.NoBody {
+ return 0
+ }
+ if req.ContentLength != 0 {
+ return req.ContentLength
+ }
+ return -1
+}
+
+// DumpRequestOut is like DumpRequest but for outgoing client requests. It
+// includes any headers that the standard http.Transport adds, such as
+// User-Agent.
+func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
+ save := req.Body
+ dummyBody := false
+ if !body {
+ contentLength := outgoingLength(req)
+ if contentLength != 0 {
+ req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength))
+ dummyBody = true
+ }
+ } else {
+ var err error
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Since we're using the actual Transport code to write the request,
+ // switch to http so the Transport doesn't try to do an SSL
+ // negotiation with our dumpConn and its bytes.Buffer & pipe.
+	// The wire format for https and http is the same, anyway.
+ reqSend := req
+ if req.URL.Scheme == "https" {
+ reqSend = new(http.Request)
+ *reqSend = *req
+ reqSend.URL = new(url.URL)
+ *reqSend.URL = *req.URL
+ reqSend.URL.Scheme = "http"
+ }
+
+ // Use the actual Transport code to record what we would send
+ // on the wire, but not using TCP. Use a Transport with a
+ // custom dialer that returns a fake net.Conn that waits
+	// for the full input (and records it), and then responds
+ // with a dummy response.
+ var buf bytes.Buffer // records the output
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &http.Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ // We need this channel to ensure that the reader
+ // goroutine exits if t.RoundTrip returns an error.
+ // See golang.org/issue/32571.
+ quitReadCh := make(chan struct{})
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ req, err := http.ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ select {
+ case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
+ case <-quitReadCh:
+ // Ensure delegateReader.Read doesn't block forever if we get an error.
+ close(dr.c)
+ }
+ }()
+
+ _, err := t.RoundTrip(reqSend)
+
+ req.Body = save
+ if err != nil {
+ pw.Close()
+ dr.err = err
+ close(quitReadCh)
+ return nil, err
+ }
+ dump := buf.Bytes()
+
+ // If we used a dummy body above, remove it now.
+ // TODO: if the req.ContentLength is large, we allocate memory
+ // unnecessarily just to slice it off here. But this is just
+ // a debug function, so this is acceptable for now. We could
+ // discard the body earlier if this matters.
+ if dummyBody {
+ if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 {
+ dump = dump[:i+4]
+ }
+ }
+ return dump, nil
+}
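A minimal usage sketch: DumpRequestOut is typically called from client-side debugging code, for example a wrapping RoundTripper. The loggingTransport type below is hypothetical, and because DumpRequestOut drains and replaces req.Body the pattern is only appropriate for debugging:

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
)

// loggingTransport dumps each outgoing request, including the headers the
// standard http.Transport would add, before delegating to an underlying
// transport.
type loggingTransport struct {
	next http.RoundTripper // nil means http.DefaultTransport
}

func (t loggingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if dump, err := httputil.DumpRequestOut(req, true); err == nil {
		log.Printf("outgoing request:\n%s", dump)
	}
	next := t.next
	if next == nil {
		next = http.DefaultTransport
	}
	return next.RoundTrip(req)
}

func main() {
	client := &http.Client{Transport: loggingTransport{}}
	resp, err := client.Get("http://example.com/")
	if err != nil {
		log.Fatal(err)
	}
	resp.Body.Close()
	log.Println(resp.Status)
}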
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ err error // only used if r is nil and c is closed.
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var ok bool
+ if r.r, ok = <-r.c; !ok {
+ return 0, r.err
+ }
+ }
+ return r.r.Read(p)
+}
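A standalone sketch of the pattern delegateReader implements, under hypothetical names (chanReader is not the unexported type above): Read blocks until another goroutine supplies the real reader over a channel, which is the hand-off DumpRequestOut relies on for its dummy response:

package main

import (
	"fmt"
	"io"
	"strings"
)

// chanReader blocks in Read until some goroutine supplies the real reader
// over the channel, then delegates to it.
type chanReader struct {
	c chan io.Reader
	r io.Reader
}

func (cr *chanReader) Read(p []byte) (int, error) {
	if cr.r == nil {
		r, ok := <-cr.c
		if !ok {
			return 0, io.ErrUnexpectedEOF
		}
		cr.r = r
	}
	return cr.r.Read(p)
}

func main() {
	cr := &chanReader{c: make(chan io.Reader)}
	go func() { cr.c <- strings.NewReader("delivered later") }()

	b, err := io.ReadAll(cr)
	if err != nil {
		fmt.Println("read error:", err)
		return
	}
	fmt.Println(string(b)) // delivered later
}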
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+var reqWriteExcludeHeaderDump = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// DumpRequest returns the given request in its HTTP/1.x wire
+// representation. It should only be used by servers to debug client
+// requests. The returned representation is an approximation only;
+// some details of the initial request are lost while parsing it into
+// an http.Request. In particular, the order and case of header field
+// names are lost. The order of values in multi-valued headers is kept
+// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their
+// original binary representations.
+//
+// If body is true, DumpRequest also returns the body. To do so, it
+// consumes req.Body and then replaces it with a new io.ReadCloser
+// that yields the same bytes. If DumpRequest returns an error,
+// the state of req is undefined.
+//
+// The documentation for http.Request.Write details which fields
+// of req are included in the dump.
+func DumpRequest(req *http.Request, body bool) ([]byte, error) {
+ var err error
+ save := req.Body
+ if !body || req.Body == nil {
+ req.Body = nil
+ } else {
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ var b bytes.Buffer
+
+ // By default, print out the unmodified req.RequestURI, which
+ // is always set for incoming server requests. But because we
+ // previously used req.URL.RequestURI and the docs weren't
+ // always so clear about when to use DumpRequest vs
+ // DumpRequestOut, fall back to the old way if the caller
+ // provides a non-server Request.
+ reqURI := req.RequestURI
+ if reqURI == "" {
+ reqURI = req.URL.RequestURI()
+ }
+
+ fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
+ reqURI, req.ProtoMajor, req.ProtoMinor)
+
+ absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://")
+ if !absRequestURI {
+ host := req.Host
+ if host == "" && req.URL != nil {
+ host = req.URL.Host
+ }
+ if host != "" {
+ fmt.Fprintf(&b, "Host: %s\r\n", host)
+ }
+ }
+
+ chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
+ if len(req.TransferEncoding) > 0 {
+ fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
+ }
+
+ err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
+ if err != nil {
+ return nil, err
+ }
+
+ io.WriteString(&b, "\r\n")
+
+ if req.Body != nil {
+ var dest io.Writer = &b
+ if chunked {
+ dest = NewChunkedWriter(dest)
+ }
+ _, err = io.Copy(dest, req.Body)
+ if chunked {
+ dest.(io.Closer).Close()
+ io.WriteString(&b, "\r\n")
+ }
+ }
+
+ req.Body = save
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
+
+// errNoBody is a sentinel error value used by failureToReadBody so we
+// can detect that the lack of body was intentional.
+var errNoBody = errors.New("sentinel error value")
+
+// failureToReadBody is an io.ReadCloser that just returns errNoBody on
+// Read. It's swapped in when we don't actually want to consume
+// the body, but need a non-nil one, and want to distinguish the
+// error from reading the dummy body.
+type failureToReadBody struct{}
+
+func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
+func (failureToReadBody) Close() error { return nil }
+
+// emptyBody is an instance of an empty reader.
+var emptyBody = io.NopCloser(strings.NewReader(""))
+
+// DumpResponse is like DumpRequest but dumps a response.
+func DumpResponse(resp *http.Response, body bool) ([]byte, error) {
+ var b bytes.Buffer
+ var err error
+ save := resp.Body
+ savecl := resp.ContentLength
+
+ if !body {
+ // For a content length of zero, make sure the body is an empty
+ // reader instead of returning an error through failureToReadBody{}.
+ if resp.ContentLength == 0 {
+ resp.Body = emptyBody
+ } else {
+ resp.Body = failureToReadBody{}
+ }
+ } else if resp.Body == nil {
+ resp.Body = emptyBody
+ } else {
+ save, resp.Body, err = drainBody(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+ err = resp.Write(&b)
+ if err == errNoBody {
+ err = nil
+ }
+ resp.Body = save
+ resp.ContentLength = savecl
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go
new file mode 100644
index 0000000..c20c054
--- /dev/null
+++ b/src/net/http/httputil/dump_test.go
@@ -0,0 +1,532 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "math/rand"
+ "net/http"
+ "net/url"
+ "runtime"
+ "runtime/pprof"
+ "strings"
+ "testing"
+ "time"
+)
+
+type eofReader struct{}
+
+func (n eofReader) Close() error { return nil }
+
+func (n eofReader) Read([]byte) (int, error) { return 0, io.EOF }
+
+type dumpTest struct {
+ // Either Req or GetReq can be set/nil but not both.
+ Req *http.Request
+ GetReq func() *http.Request
+
+ Body any // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ WantDump string
+ WantDumpOut string
+ MustError bool // if true, the test is expected to return an error
+ NoBody bool // if true, set DumpRequest{,Out} body to false
+}
+
+var dumpTests = []dumpTest{
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDump: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: &http.Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: http.Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantDump: "GET /foo HTTP/1.0\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ {
+ Req: mustNewRequest("GET", "http://example.com/foo", nil),
+
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Test that an https URL doesn't try to do an SSL negotiation
+ // with a bytes.Buffer and hang with all goroutines not
+ // runnable.
+ {
+ Req: mustNewRequest("GET", "https://example.com/foo", nil),
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Request with Body, but Dump requested without it.
+ {
+ Req: &http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ ContentLength: 6,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 6\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+
+ NoBody: true,
+ },
+
+ // Request with Body > 8192 (default buffer size)
+ {
+ Req: &http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ Header: http.Header{
+ "Content-Length": []string{"8193"},
+ },
+
+ ContentLength: 8193,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: bytes.Repeat([]byte("a"), 8193),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 8193\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n" +
+ strings.Repeat("a", 8193),
+ WantDump: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "Content-Length: 8193\r\n\r\n" +
+ strings.Repeat("a", 8193),
+ },
+
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n")
+ },
+ NoBody: true,
+ WantDump: "GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n",
+ },
+
+ // Issue #7215. DumpRequest should return the "Content-Length" when set
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 3\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 3\r\n" +
+ "\r\nkey",
+ },
+ // Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 0\r\n\r\n",
+ },
+
+ // Issue #7215. DumpRequest should not return the "Content-Length" if unset
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n\r\n",
+ },
+
+ // Issue 18506: make drainBody recognize NoBody. Otherwise
+ // this was turning into a chunked request.
+ {
+ Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody),
+ WantDumpOut: "POST /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Issue 34504: a non-nil Body without ContentLength set should be chunked
+ {
+ Req: &http.Request{
+ Method: "PUT",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/test",
+ },
+ ContentLength: 0,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Body: &eofReader{},
+ },
+ NoBody: true,
+ WantDumpOut: "PUT /test HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Issue 54616: request with Connection header doesn't result in duplicate header.
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("GET / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "Connection: close\r\n\r\n")
+ },
+ NoBody: true,
+ WantDump: "GET / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "Connection: close\r\n\r\n",
+ },
+}
+
+func TestDumpRequest(t *testing.T) {
+ // Make a copy of dumpTests and add 10 new cases with an empty URL
+ // to test that no goroutines are leaked. See golang.org/issue/32571.
+ // 10 seems to be a decent number which always triggers the failure.
+ dumpTests := dumpTests[:]
+ for i := 0; i < 10; i++ {
+ dumpTests = append(dumpTests, dumpTest{
+ Req: mustNewRequest("GET", "", nil),
+ MustError: true,
+ })
+ }
+ numg0 := runtime.NumGoroutine()
+ for i, tt := range dumpTests {
+ if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil {
+ t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq)
+ continue
+ }
+
+ freshReq := func(ti dumpTest) *http.Request {
+ req := ti.Req
+ if req == nil {
+ req = ti.GetReq()
+ }
+
+ if req.Header == nil {
+ req.Header = make(http.Header)
+ }
+
+ if ti.Body == nil {
+ return req
+ }
+ switch b := ti.Body.(type) {
+ case []byte:
+ req.Body = io.NopCloser(bytes.NewReader(b))
+ case func() io.ReadCloser:
+ req.Body = b()
+ default:
+ t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body)
+ }
+ return req
+ }
+
+ if tt.WantDump != "" {
+ req := freshReq(tt)
+ dump, err := DumpRequest(req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump)
+ continue
+ }
+ if string(dump) != tt.WantDump {
+ t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump))
+ continue
+ }
+ }
+
+ if tt.MustError {
+ req := freshReq(tt)
+ _, err := DumpRequestOut(req, !tt.NoBody)
+ if err == nil {
+ t.Errorf("DumpRequestOut #%d: expected an error, got nil", i)
+ }
+ continue
+ }
+
+ if tt.WantDumpOut != "" {
+ req := freshReq(tt)
+ dump, err := DumpRequestOut(req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequestOut #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDumpOut {
+ t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump))
+ continue
+ }
+ }
+ }
+
+ // Validate we haven't leaked any goroutines.
+ var dg int
+ dl := deadline(t, 5*time.Second, time.Second)
+ for time.Now().Before(dl) {
+ if dg = runtime.NumGoroutine() - numg0; dg <= 4 {
+ // No unexpected goroutines.
+ return
+ }
+
+ // Allow goroutines to schedule and die off.
+ runtime.Gosched()
+ }
+
+ buf := make([]byte, 4096)
+ buf = buf[:runtime.Stack(buf, true)]
+ t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf)
+}
+
+// deadline returns t.Deadline() minus needed, if a test deadline is
+// configured and that moment is still in the future; otherwise it
+// returns defaultDelay from the current time.
+func deadline(t *testing.T, defaultDelay, needed time.Duration) time.Time {
+ if dl, ok := t.Deadline(); ok {
+ if dl = dl.Add(-needed); dl.After(time.Now()) {
+ // Allow an arbitrarily long delay.
+ return dl
+ }
+ }
+
+ // No deadline configured, or it's closer than needed from now,
+ // so just use the default.
+ return time.Now().Add(defaultDelay)
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+func mustNewRequest(method, url string, body io.Reader) *http.Request {
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err))
+ }
+ return req
+}
+
+func mustReadRequest(s string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s)))
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+
+var dumpResTests = []struct {
+ res *http.Response
+ body bool
+ want string
+}{
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 50,
+ Header: http.Header{
+ "Foo": []string{"Bar"},
+ },
+ Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used
+ },
+ body: false, // to verify we see 50, not empty or 3.
+ want: `HTTP/1.1 200 OK
+Content-Length: 50
+Foo: Bar`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 3,
+ Body: io.NopCloser(strings.NewReader("foo")),
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Content-Length: 3
+
+foo`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: -1,
+ Body: io.NopCloser(strings.NewReader("foo")),
+ TransferEncoding: []string{"chunked"},
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+3
+foo
+0`,
+ },
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0,
+ Header: http.Header{
+ // To verify that headers are not filtered out.
+ "Foo1": []string{"Bar1"},
+ "Foo2": []string{"Bar2"},
+ },
+ Body: nil,
+ },
+ body: false, // to verify we see 0, not empty.
+ want: `HTTP/1.1 200 OK
+Foo1: Bar1
+Foo2: Bar2
+Content-Length: 0`,
+ },
+}
+
+func TestDumpResponse(t *testing.T) {
+ for i, tt := range dumpResTests {
+ gotb, err := DumpResponse(tt.res, tt.body)
+ if err != nil {
+ t.Errorf("%d. DumpResponse = %v", i, err)
+ continue
+ }
+ got := string(gotb)
+ got = strings.TrimSpace(got)
+ got = strings.ReplaceAll(got, "\r", "")
+
+ if got != tt.want {
+ t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want)
+ }
+ }
+}
+
+// Issue 38352: Check for deadlock on canceled requests.
+func TestDumpRequestOutIssue38352(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+ t.Parallel()
+
+ timeout := 10 * time.Second
+ if deadline, ok := t.Deadline(); ok {
+ timeout = time.Until(deadline)
+ timeout -= time.Second * 2 // Leave 2 seconds to report failures.
+ }
+ for i := 0; i < 1000; i++ {
+ delay := time.Duration(rand.Intn(5)) * time.Millisecond
+ ctx, cancel := context.WithTimeout(context.Background(), delay)
+ defer cancel()
+
+ r := bytes.NewBuffer(make([]byte, 10000))
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ out := make(chan error)
+ go func() {
+ _, err = DumpRequestOut(req, true)
+ out <- err
+ }()
+
+ select {
+ case <-out:
+ case <-time.After(timeout):
+ b := &strings.Builder{}
+ fmt.Fprintf(b, "deadlock detected on iteration %d after %s with delay: %v\n", i, timeout, delay)
+ pprof.Lookup("goroutine").WriteTo(b, 1)
+ t.Fatal(b.String())
+ }
+ }
+}
diff --git a/src/net/http/httputil/example_test.go b/src/net/http/httputil/example_test.go
new file mode 100644
index 0000000..6c107f8
--- /dev/null
+++ b/src/net/http/httputil/example_test.go
@@ -0,0 +1,128 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "strings"
+)
+
+func ExampleDumpRequest() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dump, err := httputil.DumpRequest(r, true)
+ if err != nil {
+ http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
+ return
+ }
+
+ fmt.Fprintf(w, "%q", dump)
+ }))
+ defer ts.Close()
+
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+ req.Host = "www.example.org"
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ b, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nContent-Length: 75\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpRequestOut() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ dump, err := httputil.DumpRequestOut(req, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpResponse() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT")
+ fmt.Fprintln(w, body)
+ }))
+ defer ts.Close()
+
+ resp, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ dump, err := httputil.DumpResponse(resp, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n"
+}
+
+func ExampleReverseProxy() {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "this call was relayed by the reverse proxy")
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ frontendProxy := httptest.NewServer(&httputil.ReverseProxy{
+ Rewrite: func(r *httputil.ProxyRequest) {
+ r.SetXForwarded()
+ r.SetURL(rpURL)
+ },
+ })
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ b, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // this call was relayed by the reverse proxy
+}
diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go
new file mode 100644
index 0000000..09ea74d
--- /dev/null
+++ b/src/net/http/httputil/httputil.go
@@ -0,0 +1,41 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httputil provides HTTP utility functions, complementing the
+// more common ones in the net/http package.
+package httputil
+
+import (
+ "io"
+ "net/http/internal"
+)
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ return internal.NewChunkedReader(r)
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream but does
+// not send the final CRLF that appears after trailers; trailers and the last
+// CRLF must be written separately.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using NewChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// header, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return internal.NewChunkedWriter(w)
+}
+
+// ErrLineTooLong is returned when reading malformed chunked data
+// with lines that are too long.
+var ErrLineTooLong = internal.ErrLineTooLong
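A minimal sketch of the two helpers above used together: NewChunkedWriter encodes a payload into the chunked wire format, and NewChunkedReader decodes it again, reporting io.EOF at the final 0-length chunk as documented:

package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http/httputil"
)

func main() {
	// Encode a payload into the chunked wire format.
	var wire bytes.Buffer
	cw := httputil.NewChunkedWriter(&wire)
	io.WriteString(cw, "hello, ")
	io.WriteString(cw, "chunked world")
	cw.Close() // emits the final 0-length chunk

	fmt.Printf("on the wire: %q\n", wire.String())

	// Decode it again; the reader returns io.EOF at the 0-length chunk,
	// so ReadAll yields the original payload.
	payload, err := io.ReadAll(httputil.NewChunkedReader(&wire))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("decoded: %s\n", payload)
}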
diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go
new file mode 100644
index 0000000..84b116d
--- /dev/null
+++ b/src/net/http/httputil/persist.go
@@ -0,0 +1,431 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/textproto"
+ "sync"
+)
+
+var (
+ // Deprecated: No longer used.
+ ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"}
+
+ // Deprecated: No longer used.
+ ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
+
+ // Deprecated: No longer used.
+ ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
+)
+
+// This is an API usage error - the local side is closed.
+// ErrPersistEOF (above) reports that the remote side is closed.
+var errClosed = errors.New("i/o operation on closed connection")
+
+// ServerConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Server in package net/http instead.
+type ServerConn struct {
+ mu sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+}
+
+// NewServerConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Server in package net/http instead.
+func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
+}
+
+// Hijack detaches the ServerConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before Read has signaled the end of the keep-alive logic. The user
+// should not call Hijack while Read or Write is in progress.
+func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ c := sc.c
+ r := sc.r
+ sc.c = nil
+ sc.r = nil
+ return c, r
+}
+
+// Close calls Hijack and then also closes the underlying connection.
+func (sc *ServerConn) Close() error {
+ c, _ := sc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Read returns the next request on the wire. An ErrPersistEOF is returned if
+// it is gracefully determined that there are no more requests (e.g. after the
+// first request on an HTTP/1.0 connection, or after a Connection:close on an
+// HTTP/1.1 connection).
+func (sc *ServerConn) Read() (*http.Request, error) {
+ var req *http.Request
+ var err error
+
+ // Ensure ordered execution of Reads and Writes
+ id := sc.pipe.Next()
+ sc.pipe.StartRequest(id)
+ defer func() {
+ sc.pipe.EndRequest(id)
+ if req == nil {
+ sc.pipe.StartResponse(id)
+ sc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ sc.mu.Lock()
+ sc.pipereq[req] = id
+ sc.mu.Unlock()
+ }
+ }()
+
+ sc.mu.Lock()
+ if sc.we != nil { // no point receiving if write-side broken or closed
+ defer sc.mu.Unlock()
+ return nil, sc.we
+ }
+ if sc.re != nil {
+ defer sc.mu.Unlock()
+ return nil, sc.re
+ }
+ if sc.r == nil { // connection closed by user in the meantime
+ defer sc.mu.Unlock()
+ return nil, errClosed
+ }
+ r := sc.r
+ lastbody := sc.lastbody
+ sc.lastbody = nil
+ sc.mu.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ sc.re = err
+ return nil, err
+ }
+ }
+
+ req, err = http.ReadRequest(r)
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ if err != nil {
+ if err == io.ErrUnexpectedEOF {
+ // A close from the opposing client is treated as a
+ // graceful close, even if there was some unparse-able
+ // data before the close.
+ sc.re = ErrPersistEOF
+ return nil, sc.re
+ } else {
+ sc.re = err
+ return req, err
+ }
+ }
+ sc.lastbody = req.Body
+ sc.nread++
+ if req.Close {
+ sc.re = ErrPersistEOF
+ return req, sc.re
+ }
+ return req, err
+}
+
+// Pending returns the number of unanswered requests
+// that have been received on the connection.
+func (sc *ServerConn) Pending() int {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ return sc.nread - sc.nwritten
+}
+
+// Write writes resp in response to req. To close the connection gracefully, set the
+// Response.Close field to true. Write should be considered operational until
+// it returns an error, regardless of any errors returned on the Read side.
+func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
+
+ // Retrieve the pipeline ID of this request/response pair
+ sc.mu.Lock()
+ id, ok := sc.pipereq[req]
+ delete(sc.pipereq, req)
+ if !ok {
+ sc.mu.Unlock()
+ return ErrPipeline
+ }
+ sc.mu.Unlock()
+
+ // Ensure pipeline order
+ sc.pipe.StartResponse(id)
+ defer sc.pipe.EndResponse(id)
+
+ sc.mu.Lock()
+ if sc.we != nil {
+ defer sc.mu.Unlock()
+ return sc.we
+ }
+ if sc.c == nil { // connection closed by user in the meantime
+ defer sc.mu.Unlock()
+ return ErrClosed
+ }
+ c := sc.c
+ if sc.nread <= sc.nwritten {
+ defer sc.mu.Unlock()
+ return errors.New("persist server pipe count")
+ }
+ if resp.Close {
+ // After signaling a keep-alive close, any pipelined unread
+ // requests will be lost. It is up to the user to drain them
+ // before signaling.
+ sc.re = ErrPersistEOF
+ }
+ sc.mu.Unlock()
+
+ err := resp.Write(c)
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ if err != nil {
+ sc.we = err
+ return err
+ }
+ sc.nwritten++
+
+ return nil
+}
+
+// ClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use Client or Transport in package net/http instead.
+type ClientConn struct {
+ mu sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+ writeReq func(*http.Request, io.Writer) error
+}
+
+// NewClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Client or Transport in package net/http instead.
+func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ClientConn{
+ c: c,
+ r: r,
+ pipereq: make(map[*http.Request]uint),
+ writeReq: (*http.Request).Write,
+ }
+}
+
+// NewProxyClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Client or Transport in package net/http instead.
+func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ cc := NewClientConn(c, r)
+ cc.writeReq = (*http.Request).WriteProxy
+ return cc
+}
+
+// Hijack detaches the ClientConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before the user or Read have signaled the end of the keep-alive
+// logic. The user should not call Hijack while Read or Write is in progress.
+func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ c = cc.c
+ r = cc.r
+ cc.c = nil
+ cc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection.
+func (cc *ClientConn) Close() error {
+ c, _ := cc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Write writes a request. An ErrPersistEOF error is returned if the connection
+// has been closed in an HTTP keep-alive sense. If req.Close equals true, the
+// keep-alive connection is logically closed after this request and the opposing
+// server is informed. An ErrUnexpectedEOF indicates the remote closed the
+// underlying TCP connection, which is usually considered a graceful close.
+func (cc *ClientConn) Write(req *http.Request) error {
+ var err error
+
+ // Ensure ordered execution of Writes
+ id := cc.pipe.Next()
+ cc.pipe.StartRequest(id)
+ defer func() {
+ cc.pipe.EndRequest(id)
+ if err != nil {
+ cc.pipe.StartResponse(id)
+ cc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ cc.mu.Lock()
+ cc.pipereq[req] = id
+ cc.mu.Unlock()
+ }
+ }()
+
+ cc.mu.Lock()
+ if cc.re != nil { // no point sending if read-side closed or broken
+ defer cc.mu.Unlock()
+ return cc.re
+ }
+ if cc.we != nil {
+ defer cc.mu.Unlock()
+ return cc.we
+ }
+ if cc.c == nil { // connection closed by user in the meantime
+ defer cc.mu.Unlock()
+ return errClosed
+ }
+ c := cc.c
+ if req.Close {
+ // We write the EOF to the write-side error, because there
+ // still might be some pipelined reads
+ cc.we = ErrPersistEOF
+ }
+ cc.mu.Unlock()
+
+ err = cc.writeReq(req, c)
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if err != nil {
+ cc.we = err
+ return err
+ }
+ cc.nwritten++
+
+ return nil
+}
+
+// Pending returns the number of unanswered requests
+// that have been sent on the connection.
+func (cc *ClientConn) Pending() int {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.nwritten - cc.nread
+}
+
+// Read reads the next response from the wire. A valid response might be
+// returned together with an ErrPersistEOF, which means that the remote
+// requested that this be the last request serviced. Read can be called
+// concurrently with Write, but not with another Read.
+func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
+ // Retrieve the pipeline ID of this request/response pair
+ cc.mu.Lock()
+ id, ok := cc.pipereq[req]
+ delete(cc.pipereq, req)
+ if !ok {
+ cc.mu.Unlock()
+ return nil, ErrPipeline
+ }
+ cc.mu.Unlock()
+
+ // Ensure pipeline order
+ cc.pipe.StartResponse(id)
+ defer cc.pipe.EndResponse(id)
+
+ cc.mu.Lock()
+ if cc.re != nil {
+ defer cc.mu.Unlock()
+ return nil, cc.re
+ }
+ if cc.r == nil { // connection closed by user in the meantime
+ defer cc.mu.Unlock()
+ return nil, errClosed
+ }
+ r := cc.r
+ lastbody := cc.lastbody
+ cc.lastbody = nil
+ cc.mu.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.re = err
+ return nil, err
+ }
+ }
+
+ resp, err = http.ReadResponse(r, req)
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if err != nil {
+ cc.re = err
+ return resp, err
+ }
+ cc.lastbody = resp.Body
+
+ cc.nread++
+
+ if resp.Close {
+ cc.re = ErrPersistEOF // don't send any more requests
+ return resp, cc.re
+ }
+ return resp, err
+}
+
+// Do is a convenience method that writes a request and reads a response.
+func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) {
+ err := cc.Write(req)
+ if err != nil {
+ return nil, err
+ }
+ return cc.Read(req)
+}
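Although deprecated, the ClientConn API above can still be exercised end to end; a minimal sketch against a hypothetical host (net/http.Client is the supported replacement):

package main

import (
	"log"
	"net"
	"net/http"
	"net/http/httputil"
)

func main() {
	conn, err := net.Dial("tcp", "example.com:80") // hypothetical host
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	cc := httputil.NewClientConn(conn, nil)
	defer cc.Close()

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}

	// Do writes the request and reads the matching response. A response
	// delivered together with ErrPersistEOF is still valid; the error only
	// means the server asked to close the connection afterwards.
	resp, err := cc.Do(req)
	if err != nil && err != httputil.ErrPersistEOF {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println(resp.Status)
}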
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
new file mode 100644
index 0000000..2a76b0b
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy.go
@@ -0,0 +1,834 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package httputil
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "mime"
+ "net"
+ "net/http"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// A ProxyRequest contains a request to be rewritten by a ReverseProxy.
+type ProxyRequest struct {
+ // In is the request received by the proxy.
+ // The Rewrite function must not modify In.
+ In *http.Request
+
+ // Out is the request which will be sent by the proxy.
+ // The Rewrite function may modify or replace this request.
+ // Hop-by-hop headers are removed from this request
+ // before Rewrite is called.
+ Out *http.Request
+}
+
+// SetURL routes the outbound request to the scheme, host, and base path
+// provided in target. If the target's path is "/base" and the incoming
+// request was for "/dir", the target request will be for "/base/dir".
+//
+// SetURL rewrites the outbound Host header to match the target's host.
+// To preserve the inbound request's Host header (the default behavior
+// of NewSingleHostReverseProxy):
+//
+// rewriteFunc := func(r *httputil.ProxyRequest) {
+// r.SetURL(url)
+// r.Out.Host = r.In.Host
+// }
+func (r *ProxyRequest) SetURL(target *url.URL) {
+ rewriteRequestURL(r.Out, target)
+ r.Out.Host = ""
+}
+
+// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and
+// X-Forwarded-Proto headers of the outbound request.
+//
+// - The X-Forwarded-For header is set to the client IP address.
+// - The X-Forwarded-Host header is set to the host name requested
+// by the client.
+// - The X-Forwarded-Proto header is set to "http" or "https", depending
+// on whether the inbound request was made on a TLS-enabled connection.
+//
+// If the outbound request contains an existing X-Forwarded-For header,
+// SetXForwarded appends the client IP address to it. To append to the
+// inbound request's X-Forwarded-For header (the default behavior of
+// ReverseProxy when using a Director function), copy the header
+// from the inbound request before calling SetXForwarded:
+//
+// rewriteFunc := func(r *httputil.ProxyRequest) {
+// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
+// r.SetXForwarded()
+// }
+func (r *ProxyRequest) SetXForwarded() {
+ clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr)
+ if err == nil {
+ prior := r.Out.Header["X-Forwarded-For"]
+ if len(prior) > 0 {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ r.Out.Header.Set("X-Forwarded-For", clientIP)
+ } else {
+ r.Out.Header.Del("X-Forwarded-For")
+ }
+ r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
+ if r.In.TLS == nil {
+ r.Out.Header.Set("X-Forwarded-Proto", "http")
+ } else {
+ r.Out.Header.Set("X-Forwarded-Proto", "https")
+ }
+}
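The snippets in the SetURL and SetXForwarded comments compose into a single Rewrite function; a minimal sketch, assuming a hypothetical backend URL, that keeps the inbound Host header and appends to the inbound X-Forwarded-For chain:

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

func main() {
	// target is a hypothetical backend; substitute the real one.
	target, err := url.Parse("http://backend.internal:8080")
	if err != nil {
		log.Fatal(err)
	}

	proxy := &httputil.ReverseProxy{
		Rewrite: func(r *httputil.ProxyRequest) {
			r.SetURL(target)       // route to the backend; rewrites Host by default
			r.Out.Host = r.In.Host // keep the inbound Host header instead

			// Carry over the inbound chain so SetXForwarded appends to it.
			r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
			r.SetXForwarded()
		},
	}

	log.Fatal(http.ListenAndServe(":8000", proxy))
}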
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+//
+// 1xx responses are forwarded to the client if the underlying
+// transport supports ClientTrace.Got1xxResponse.
+type ReverseProxy struct {
+ // Rewrite must be a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ // Rewrite must not access the provided ProxyRequest
+ // or its contents after returning.
+ //
+ // The Forwarded, X-Forwarded, X-Forwarded-Host,
+ // and X-Forwarded-Proto headers are removed from the
+ // outbound request before Rewrite is called. See also
+ // the ProxyRequest.SetXForwarded method.
+ //
+ // Unparsable query parameters are removed from the
+ // outbound request before Rewrite is called.
+ // The Rewrite function may copy the inbound URL's
+ // RawQuery to the outbound URL to preserve the original
+ // parameter string. Note that this can lead to security
+ // issues if the proxy's interpretation of query parameters
+ // does not match that of the downstream server.
+ //
+ // At most one of Rewrite or Director may be set.
+ Rewrite func(*ProxyRequest)
+
+ // Director is a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ // Director must not access the provided Request
+ // after returning.
+ //
+ // By default, the X-Forwarded-For header is set to the
+ // value of the client IP address. If an X-Forwarded-For
+ // header already exists, the client IP is appended to the
+ // existing values. As a special case, if the header
+ // exists in the Request.Header map but has a nil value
+ // (such as when set by the Director func), the X-Forwarded-For
+ // header is not modified.
+ //
+ // To prevent IP spoofing, be sure to delete any pre-existing
+ // X-Forwarded-For header coming from the client or
+ // an untrusted proxy.
+ //
+ // Hop-by-hop headers are removed from the request after
+ // Director returns, which can remove headers added by
+ // Director. Use a Rewrite function instead to ensure
+ // modifications to the request are preserved.
+ //
+ // Unparsable query parameters are removed from the outbound
+ // request if Request.Form is set after Director returns.
+ //
+ // At most one of Rewrite or Director may be set.
+ Director func(*http.Request)
+
+ // The transport used to perform proxy requests.
+ // If nil, http.DefaultTransport is used.
+ Transport http.RoundTripper
+
+ // FlushInterval specifies the flush interval
+ // to flush to the client while copying the
+ // response body.
+ // If zero, no periodic flushing is done.
+ // A negative value means to flush immediately
+ // after each write to the client.
+ // The FlushInterval is ignored when ReverseProxy
+ // recognizes a response as a streaming response, or
+ // if its ContentLength is -1; for such responses, writes
+ // are flushed to the client immediately.
+ FlushInterval time.Duration
+
+ // ErrorLog specifies an optional logger for errors
+ // that occur when attempting to proxy the request.
+ // If nil, logging is done via the log package's standard logger.
+ ErrorLog *log.Logger
+
+ // BufferPool optionally specifies a buffer pool to
+ // get byte slices for use by io.CopyBuffer when
+ // copying HTTP response bodies.
+ BufferPool BufferPool
+
+ // ModifyResponse is an optional function that modifies the
+ // Response from the backend. It is called if the backend
+ // returns a response at all, with any HTTP status code.
+ // If the backend is unreachable, the optional ErrorHandler is
+ // called without any call to ModifyResponse.
+ //
+ // If ModifyResponse returns an error, ErrorHandler is called
+ // with its error value. If ErrorHandler is nil, its default
+ // implementation is used.
+ ModifyResponse func(*http.Response) error
+
+ // ErrorHandler is an optional function that handles errors
+ // reaching the backend or errors from ModifyResponse.
+ //
+ // If nil, the default is to log the provided error and return
+ // a 502 Status Bad Gateway response.
+ ErrorHandler func(http.ResponseWriter, *http.Request, error)
+}
+
+// A BufferPool is an interface for getting and returning temporary
+// byte slices for use by io.CopyBuffer.
+type BufferPool interface {
+ Get() []byte
+ Put([]byte)
+}
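A minimal BufferPool sketch backed by sync.Pool, assuming 32 KiB buffers to match the fallback size copyBuffer uses further down when no pool is configured; the names are illustrative:

package main

import (
	"fmt"
	"sync"
)

// bufPool is an illustrative BufferPool implementation backed by sync.Pool.
type bufPool struct{ p sync.Pool }

func newBufPool() *bufPool {
	return &bufPool{p: sync.Pool{
		New: func() any { return make([]byte, 32*1024) },
	}}
}

func (b *bufPool) Get() []byte  { return b.p.Get().([]byte) }
func (b *bufPool) Put(v []byte) { b.p.Put(v) }

func main() {
	pool := newBufPool()
	buf := pool.Get()
	fmt.Println(len(buf)) // 32768
	pool.Put(buf)
	// Wired into a proxy as: proxy.BufferPool = newBufPool()
}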
+
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+func joinURLPath(a, b *url.URL) (path, rawpath string) {
+ if a.RawPath == "" && b.RawPath == "" {
+ return singleJoiningSlash(a.Path, b.Path), ""
+ }
+ // Same as singleJoiningSlash, but uses EscapedPath to determine
+ // whether a slash should be added
+ apath := a.EscapedPath()
+ bpath := b.EscapedPath()
+
+ aslash := strings.HasSuffix(apath, "/")
+ bslash := strings.HasPrefix(bpath, "/")
+
+ switch {
+ case aslash && bslash:
+ return a.Path + b.Path[1:], apath + bpath[1:]
+ case !aslash && !bslash:
+ return a.Path + "/" + b.Path, apath + "/" + bpath
+ }
+ return a.Path + b.Path, apath + bpath
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that routes
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+//
+// NewSingleHostReverseProxy does not rewrite the Host header.
+//
+// To customize the ReverseProxy behavior beyond what
+// NewSingleHostReverseProxy provides, use ReverseProxy directly
+// with a Rewrite function. The ProxyRequest SetURL method
+// may be used to route the outbound request. (Note that SetURL,
+// unlike NewSingleHostReverseProxy, rewrites the Host header
+// of the outbound request by default.)
+//
+// proxy := &ReverseProxy{
+// Rewrite: func(r *ProxyRequest) {
+// r.SetURL(target)
+// r.Out.Host = r.In.Host // if desired
+// },
+// }
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+ director := func(req *http.Request) {
+ rewriteRequestURL(req, target)
+ }
+ return &ReverseProxy{Director: director}
+}
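A minimal sketch of the constructor in use, with a hypothetical upstream address and a few of the optional fields documented above filled in:

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"time"
)

func main() {
	// backend is a hypothetical upstream address.
	backend, err := url.Parse("http://127.0.0.1:9000/base")
	if err != nil {
		log.Fatal(err)
	}

	// Requests for /dir are forwarded to http://127.0.0.1:9000/base/dir;
	// the client's Host header is left untouched.
	proxy := httputil.NewSingleHostReverseProxy(backend)

	// The remaining fields are optional.
	proxy.FlushInterval = 100 * time.Millisecond // periodic flush while copying the body
	proxy.ErrorLog = log.New(os.Stderr, "proxy: ", log.LstdFlags)
	proxy.ModifyResponse = func(resp *http.Response) error {
		resp.Header.Set("X-Proxied", "1") // hypothetical marker header
		return nil
	}
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		log.Printf("proxy error for %s: %v", r.URL.Path, err)
		http.Error(w, "upstream unavailable", http.StatusBadGateway)
	}

	log.Fatal(http.ListenAndServe(":8080", proxy))
}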
+
+func rewriteRequestURL(req *http.Request, target *url.URL) {
+ targetQuery := target.RawQuery
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
+ if targetQuery == "" || req.URL.RawQuery == "" {
+ req.URL.RawQuery = targetQuery + req.URL.RawQuery
+ } else {
+ req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+ }
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
+var hopHeaders = []string{
+ "Connection",
+ "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+}
+
+func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
+ if p.ErrorHandler != nil {
+ return p.ErrorHandler
+ }
+ return p.defaultErrorHandler
+}
+
+// modifyResponse conditionally runs the optional ModifyResponse hook
+// and reports whether the request should proceed.
+func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
+ if p.ModifyResponse == nil {
+ return true
+ }
+ if err := p.ModifyResponse(res); err != nil {
+ res.Body.Close()
+ p.getErrorHandler()(rw, req, err)
+ return false
+ }
+ return true
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ transport := p.Transport
+ if transport == nil {
+ transport = http.DefaultTransport
+ }
+
+ ctx := req.Context()
+ if ctx.Done() != nil {
+ // CloseNotifier predates context.Context, and has been
+ // entirely superseded by it. If the request contains
+ // a Context that carries a cancellation signal, don't
+ // bother spinning up a goroutine to watch the CloseNotify
+ // channel (if any).
+ //
+ // If the request Context has a nil Done channel (which
+ // means it is either context.Background, or a custom
+ // Context implementation with no cancellation signal),
+ // then consult the CloseNotifier if available.
+ } else if cn, ok := rw.(http.CloseNotifier); ok {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ notifyChan := cn.CloseNotify()
+ go func() {
+ select {
+ case <-notifyChan:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+ }
+
+ outreq := req.Clone(ctx)
+ if req.ContentLength == 0 {
+ outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ }
+ if outreq.Body != nil {
+ // Reading from the request body after returning from a handler is not
+ // allowed, and the RoundTrip goroutine that reads the Body can outlive
+ // this handler. This can lead to a crash if the handler panics (see
+ // Issue 46866). Although calling Close doesn't guarantee there isn't
+ // any Read in flight after the handler returns, in practice it's safe to
+ // read after closing it.
+ defer outreq.Body.Close()
+ }
+ if outreq.Header == nil {
+ outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
+ }
+
+ if (p.Director != nil) == (p.Rewrite != nil) {
+ p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set"))
+ return
+ }
+
+ if p.Director != nil {
+ p.Director(outreq)
+ if outreq.Form != nil {
+ outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
+ }
+ }
+ outreq.Close = false
+
+ reqUpType := upgradeType(outreq.Header)
+ if !ascii.IsPrint(reqUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
+ return
+ }
+ removeHopByHopHeaders(outreq.Header)
+
+ // Issue 21096: tell backend applications that care about trailer support
+ // that we support trailers. (We do, but we don't go out of our way to
+ // advertise that unless the incoming client request thought it was worth
+ // mentioning.) Note that we look at req.Header, not outreq.Header, since
+ // the latter has passed through removeHopByHopHeaders.
+ if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
+ outreq.Header.Set("Te", "trailers")
+ }
+
+ // After stripping all the hop-by-hop connection headers above, add back any
+ // necessary for protocol upgrades, such as for websockets.
+ if reqUpType != "" {
+ outreq.Header.Set("Connection", "Upgrade")
+ outreq.Header.Set("Upgrade", reqUpType)
+ }
+
+ if p.Rewrite != nil {
+ // Strip client-provided forwarding headers.
+ // The Rewrite func may use SetXForwarded to set new values
+ // for these or copy the previous values from the inbound request.
+ outreq.Header.Del("Forwarded")
+ outreq.Header.Del("X-Forwarded-For")
+ outreq.Header.Del("X-Forwarded-Host")
+ outreq.Header.Del("X-Forwarded-Proto")
+
+ // Remove unparsable query parameters from the outbound request.
+ outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
+
+ pr := &ProxyRequest{
+ In: req,
+ Out: outreq,
+ }
+ p.Rewrite(pr)
+ outreq = pr.Out
+ } else {
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ prior, ok := outreq.Header["X-Forwarded-For"]
+ omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
+ if len(prior) > 0 {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ if !omit {
+ outreq.Header.Set("X-Forwarded-For", clientIP)
+ }
+ }
+ }
+
+ if _, ok := outreq.Header["User-Agent"]; !ok {
+ // If the outbound request doesn't have a User-Agent header set,
+ // don't send the default Go HTTP client User-Agent.
+ outreq.Header.Set("User-Agent", "")
+ }
+
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ h := rw.Header()
+ copyHeader(h, http.Header(header))
+ rw.WriteHeader(code)
+
+ // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
+ for k := range h {
+ delete(h, k)
+ }
+
+ return nil
+ },
+ }
+ outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
+
+ res, err := transport.RoundTrip(outreq)
+ if err != nil {
+ p.getErrorHandler()(rw, outreq, err)
+ return
+ }
+
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ if !p.modifyResponse(rw, res, outreq) {
+ return
+ }
+ p.handleUpgradeResponse(rw, outreq, res)
+ return
+ }
+
+ removeHopByHopHeaders(res.Header)
+
+ if !p.modifyResponse(rw, res, outreq) {
+ return
+ }
+
+ copyHeader(rw.Header(), res.Header)
+
+ // The "Trailer" header isn't included in the Transport's response,
+ // at least for *http.Transport. Build it up from Trailer.
+ announcedTrailers := len(res.Trailer)
+ if announcedTrailers > 0 {
+ trailerKeys := make([]string, 0, len(res.Trailer))
+ for k := range res.Trailer {
+ trailerKeys = append(trailerKeys, k)
+ }
+ rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
+ }
+
+ rw.WriteHeader(res.StatusCode)
+
+ err = p.copyResponse(rw, res.Body, p.flushInterval(res))
+ if err != nil {
+ defer res.Body.Close()
+ // Since we're streaming the response, if we run into an error all we can do
+ // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
+ // on read error while copying body.
+ if !shouldPanicOnCopyError(req) {
+ p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
+ return
+ }
+ panic(http.ErrAbortHandler)
+ }
+ res.Body.Close() // close now, instead of defer, to populate res.Trailer
+
+ if len(res.Trailer) > 0 {
+ // Force chunking if we saw a response trailer.
+ // This prevents net/http from calculating the length for short
+ // bodies and adding a Content-Length.
+ http.NewResponseController(rw).Flush()
+ }
+
+ if len(res.Trailer) == announcedTrailers {
+ copyHeader(rw.Header(), res.Trailer)
+ return
+ }
+
+ for k, vv := range res.Trailer {
+ k = http.TrailerPrefix + k
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+}
+
+var inOurTests bool // whether we're in our own tests
+
+// shouldPanicOnCopyError reports whether the reverse proxy should
+// panic with http.ErrAbortHandler. This is the right thing to do by
+// default, but Go 1.10 and earlier did not, so existing unit tests
+// weren't expecting panics. Only panic in our own tests, or when
+// running under the HTTP server.
+func shouldPanicOnCopyError(req *http.Request) bool {
+ if inOurTests {
+ // Our tests know to handle this panic.
+ return true
+ }
+ if req.Context().Value(http.ServerContextKey) != nil {
+ // We seem to be running under an HTTP server, so
+ // it'll recover the panic.
+ return true
+ }
+ // Otherwise act like Go 1.10 and earlier to not break
+ // existing tests.
+ return false
+}
+
+// removeHopByHopHeaders removes hop-by-hop headers.
+func removeHopByHopHeaders(h http.Header) {
+ // RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
+ for _, f := range h["Connection"] {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = textproto.TrimString(sf); sf != "" {
+ h.Del(sf)
+ }
+ }
+ }
+ // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
+ // This behavior is superseded by the RFC 7230 Connection header, but
+ // preserve it for backwards compatibility.
+ for _, f := range hopHeaders {
+ h.Del(f)
+ }
+}
+
+// flushInterval returns the p.FlushInterval value, conditionally
+// overriding its value for a specific request/response.
+func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
+ resCT := res.Header.Get("Content-Type")
+
+ // For Server-Sent Events responses, flush immediately.
+ // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
+ if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
+ return -1 // negative means immediately
+ }
+
+ // We might be streaming a response whose Content-Length is unset.
+ if res.ContentLength == -1 {
+ return -1
+ }
+
+ return p.FlushInterval
+}
+
+func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error {
+ var w io.Writer = dst
+
+ if flushInterval != 0 {
+ mlw := &maxLatencyWriter{
+ dst: dst,
+ flush: http.NewResponseController(dst).Flush,
+ latency: flushInterval,
+ }
+ defer mlw.stop()
+
+ // set up initial timer so headers get flushed even if body writes are delayed
+ mlw.flushPending = true
+ mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
+
+ w = mlw
+ }
+
+ var buf []byte
+ if p.BufferPool != nil {
+ buf = p.BufferPool.Get()
+ defer p.BufferPool.Put(buf)
+ }
+ _, err := p.copyBuffer(w, src, buf)
+ return err
+}
+
+// copyBuffer returns any write errors or non-EOF read errors, and the amount
+// of bytes written.
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
+ p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ if rerr == io.EOF {
+ rerr = nil
+ }
+ return written, rerr
+ }
+ }
+}
+
+func (p *ReverseProxy) logf(format string, args ...any) {
+ if p.ErrorLog != nil {
+ p.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+type maxLatencyWriter struct {
+ dst io.Writer
+ flush func() error
+ latency time.Duration // non-zero; negative means to flush immediately
+
+ mu sync.Mutex // protects t, flushPending, and dst.Flush
+ t *time.Timer
+ flushPending bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ n, err = m.dst.Write(p)
+ if m.latency < 0 {
+ m.flush()
+ return
+ }
+ if m.flushPending {
+ return
+ }
+ if m.t == nil {
+ m.t = time.AfterFunc(m.latency, m.delayedFlush)
+ } else {
+ m.t.Reset(m.latency)
+ }
+ m.flushPending = true
+ return
+}
+
+func (m *maxLatencyWriter) delayedFlush() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ return
+ }
+ m.flush()
+ m.flushPending = false
+}
+
+func (m *maxLatencyWriter) stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.flushPending = false
+ if m.t != nil {
+ m.t.Stop()
+ }
+}
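+
+// In effect, a maxLatencyWriter with a positive latency coalesces a burst of
+// Writes into a single Flush at most latency after the first unflushed Write,
+// while a negative latency flushes after every Write.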
+
+func upgradeType(h http.Header) string {
+ if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+ return ""
+ }
+ return h.Get("Upgrade")
+}
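+
+// For example (illustrative): a header set carrying
+//
+//	Connection: Upgrade
+//	Upgrade: websocket
+//
+// yields "websocket", whereas an Upgrade header without the "Upgrade"
+// Connection token yields "".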
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
+ }
+ if !ascii.EqualFold(reqUpType, resUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+
+ rc := http.NewResponseController(rw)
+ conn, brw, hijackErr := rc.Hijack()
+ if errors.Is(hijackErr, http.ErrNotSupported) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+
+ backConnCloseCh := make(chan bool)
+ go func() {
+ // Ensure that the cancellation of a request closes the backend.
+ // See issue https://golang.org/issue/35559.
+ select {
+ case <-req.Context().Done():
+ case <-backConnCloseCh:
+ }
+ backConn.Close()
+ }()
+ defer close(backConnCloseCh)
+
+ if hijackErr != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr))
+ return
+ }
+ defer conn.Close()
+
+ copyHeader(rw.Header(), res.Header)
+
+ res.Header = rw.Header()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+}
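+
+// Note that handleUpgradeResponse returns as soon as either copy direction
+// finishes or fails; its deferred closes then tear down both connections,
+// which unblocks the remaining copier (errc is buffered, so nothing leaks).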
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
+
+func cleanQueryParams(s string) string {
+ reencode := func(s string) string {
+ v, _ := url.ParseQuery(s)
+ return v.Encode()
+ }
+ for i := 0; i < len(s); {
+ switch s[i] {
+ case ';':
+ return reencode(s)
+ case '%':
+ if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
+ return reencode(s)
+ }
+ i += 3
+ default:
+ i++
+ }
+ }
+ return s
+}
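+
+// For example (illustrative; the first two inputs mirror
+// testReverseProxyQueryParameterSmuggling in reverseproxy_test.go):
+//
+//	cleanQueryParams("a=1&a=2;b=3")   // "a=1": the semicolon-joined pair is dropped on re-encode
+//	cleanQueryParams("a=1&a=%zz&b=3") // "a=1&b=3": the pair with the invalid escape is dropped
+//	cleanQueryParams("a=1&b=2")       // "a=1&b=2": already clean, returned unchanged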
+
+func ishex(c byte) bool {
+ switch {
+ case '0' <= c && c <= '9':
+ return true
+ case 'a' <= c && c <= 'f':
+ return true
+ case 'A' <= c && c <= 'F':
+ return true
+ }
+ return false
+}
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
new file mode 100644
index 0000000..dd3330b
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -0,0 +1,1863 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
+
+func init() {
+ inOurTests = true
+ hopHeaders = append(hopHeaders, fakeHopHeader)
+}
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" && r.FormValue("mode") == "hangup" {
+ c, _, _ := w.(http.Hijacker).Hijack()
+ c.Close()
+ return
+ }
+ if len(r.TransferEncoding) > 0 {
+ t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
+ }
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got Connection header value %q", c)
+ }
+ if c := r.Header.Get("Te"); c != "trailers" {
+ t.Errorf("handler got Te header value %q; want 'trailers'", c)
+ }
+ if c := r.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got Upgrade header value %q", c)
+ }
+ if c := r.Header.Get("Proxy-Connection"); c != "" {
+ t.Errorf("handler got Proxy-Connection header value %q", c)
+ }
+ if g, e := r.Host, "some-name"; g != e {
+ t.Errorf("backend got Host header %q, want %q", g, e)
+ }
+ w.Header().Set("Trailers", "not a special header field name")
+ w.Header().Set("Trailer", "X-Trailer")
+ w.Header().Set("X-Foo", "bar")
+ w.Header().Set("Upgrade", "foo")
+ w.Header().Set(fakeHopHeader, "foo")
+ w.Header().Add("X-Multi-Value", "foo")
+ w.Header().Add("X-Multi-Value", "bar")
+ http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ w.Header().Set("X-Trailer", "trailer_value")
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close, TE")
+ getReq.Header.Add("Te", "foo")
+ getReq.Header.Add("Te", "bar, trailers")
+ getReq.Header.Set("Proxy-Connection", "should be deleted")
+ getReq.Header.Set("Upgrade", "foo")
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ if c := res.Header.Get(fakeHopHeader); c != "" {
+ t.Errorf("got %s header value %q", fakeHopHeader, c)
+ }
+ if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
+ t.Errorf("header Trailers = %q; want %q", g, e)
+ }
+ if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
+ t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
+ }
+ if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
+ t.Fatalf("got %d SetCookies, want %d", g, e)
+ }
+ if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
+ t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
+ }
+ if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
+ t.Errorf("unexpected cookie %q", cookie.Name)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
+ t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
+ }
+
+ // Test that a backend failing to be reached or one which doesn't return
+ // a response results in a StatusBadGateway.
+ getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
+ getReq.Close = true
+ res, err = frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != http.StatusBadGateway {
+ t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
+ }
+
+}
+
+// Issue 16875: remove any proxied headers mentioned in the "Connection"
+// header value.
+func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
+ const fakeConnectionToken = "X-Fake-Connection-Token"
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ w.Header().Set(fakeConnectionToken, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
+ }
+ c := r.Header["Connection"]
+ var cf []string
+ for _, f := range c {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = strings.TrimSpace(sf); sf != "" {
+ cf = append(cf, sf)
+ }
+ }
+ }
+ sort.Strings(cf)
+ expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
+ sort.Strings(expectedValues)
+ if !reflect.DeepEqual(cf, expectedValues) {
+ t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ getReq.Header.Set(fakeConnectionToken, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ if c := res.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+}
+
+func TestReverseProxyStripEmptyConnection(t *testing.T) {
+ // See Issue 46313.
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Values("Connection"); len(c) != 0 {
+ t.Errorf("handler got header %q = %v; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "")
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "")
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+}
+
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+// Issue 38079: don't append to X-Forwarded-For if it's present but nil
+func TestXForwardedFor_Omit(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if v := r.Header.Get("X-Forwarded-For"); v != "" {
+ t.Errorf("got X-Forwarded-For header: %q", v)
+ }
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ r.Header["X-Forwarded-For"] = nil
+ oldDirector(r)
+ }
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
+ headers := []string{
+ "Forwarded",
+ "X-Forwarded-For",
+ "X-Forwarded-Host",
+ "X-Forwarded-Proto",
+ }
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ for _, h := range headers {
+ if v := r.Header.Get(h); v != "" {
+ t.Errorf("got %v header: %q", h, v)
+ }
+ }
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(backendURL)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Close = true
+ for _, h := range headers {
+ getReq.Header.Set(h, "x")
+ }
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+var proxyQueryTests = []struct {
+ baseSuffix string // suffix to add to backend URL
+ reqSuffix string // suffix to add to frontend's request URL
+ want string // what backend should see for final request URL (without ?)
+}{
+ {"", "", ""},
+ {"?sta=tic", "?us=er", "sta=tic&us=er"},
+ {"", "?us=er", "us=er"},
+ {"?sta=tic", "", "sta=tic"},
+}
+
+func TestReverseProxyQuery(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Got-Query", r.URL.RawQuery)
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ for i, tt := range proxyQueryTests {
+ backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
+ req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("%d. Get: %v", i, err)
+ }
+ if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
+ t.Errorf("%d. got query %q; expected %q", i, g, e)
+ }
+ res.Body.Close()
+ frontend.Close()
+ }
+}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+}
+
+type mockFlusher struct {
+ http.ResponseWriter
+ flushed bool
+}
+
+func (m *mockFlusher) Flush() {
+ m.flushed = true
+}
+
+type wrappedRW struct {
+ http.ResponseWriter
+}
+
+func (w *wrappedRW) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mf := &mockFlusher{}
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = -1 // flush immediately
+ proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ mf.ResponseWriter = w
+ w = &wrappedRW{mf}
+ proxyHandler.ServeHTTP(w, r)
+ })
+
+ frontend := httptest.NewServer(proxyWithMiddleware)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+ if !mf.flushed {
+ t.Errorf("response writer was not flushed")
+ }
+}
+
+func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
+ const expected = "hi"
+ stopCh := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("MyHeader", expected)
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ <-stopCh
+ }))
+ defer backend.Close()
+ defer close(stopCh)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+
+ ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
+ defer cancel()
+ req = req.WithContext(ctx)
+
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ if res.Header.Get("MyHeader") != expected {
+ t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
+ }
+}
+
+func TestReverseProxyCancellation(t *testing.T) {
+ const backendResponse = "I am the backend"
+
+ reqInFlight := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ close(reqInFlight) // cause the client to cancel its request
+
+ select {
+ case <-time.After(10 * time.Second):
+ // Note: this should only happen in broken implementations, and the
+ // closenotify case should be instantaneous.
+ t.Error("Handler never saw CloseNotify")
+ return
+ case <-w.(http.CloseNotifier).CloseNotify():
+ }
+
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(backendResponse))
+ }))
+
+ defer backend.Close()
+
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+
+ // Discards errors of the form:
+ // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ go func() {
+ <-reqInFlight
+ frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
+ }()
+ res, err := frontendClient.Do(getReq)
+ if res != nil {
+ t.Errorf("got response %v; want nil", res.Status)
+ }
+ if err == nil {
+ // This should be an error like:
+ // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
+ // use of closed network connection
+ t.Error("Server.Client().Do() returned nil error; want non-nil error")
+ }
+}
+
+func req(t *testing.T, v string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
+ if err != nil {
+ t.Fatal(err)
+ }
+ return req
+}
+
+// Issue 12344
+func TestNilBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ backURL, _ := url.Parse(backend.URL)
+ rp := NewSingleHostReverseProxy(backURL)
+ r := req(t, "GET / HTTP/1.0\r\n\r\n")
+ r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
+ rp.ServeHTTP(w, r)
+ }))
+ defer frontend.Close()
+
+ res, err := http.Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "hi" {
+ t.Errorf("Got %q; want %q", slurp, "hi")
+ }
+}
+
+// Issue 15524
+func TestUserAgentHeader(t *testing.T) {
+ var gotUA string
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotUA = r.Header.Get("User-Agent")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := new(ReverseProxy)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Director = func(req *http.Request) {
+ req.URL = backendURL
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ for _, sentUA := range []string{"explicit UA", ""} {
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("User-Agent", sentUA)
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+ if got, want := gotUA, sentUA; got != want {
+ t.Errorf("got forwarded User-Agent %q, want %q", got, want)
+ }
+ }
+}
+
+type bufferPool struct {
+ get func() []byte
+ put func([]byte)
+}
+
+func (bp bufferPool) Get() []byte { return bp.get() }
+func (bp bufferPool) Put(v []byte) { bp.put(v) }
+
+func TestReverseProxyGetPutBuffer(t *testing.T) {
+ const msg = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, msg)
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ mu sync.Mutex
+ log []string
+ )
+ addLog := func(event string) {
+ mu.Lock()
+ defer mu.Unlock()
+ log = append(log, event)
+ }
+ rp := NewSingleHostReverseProxy(backendURL)
+ const size = 1234
+ rp.BufferPool = bufferPool{
+ get: func() []byte {
+ addLog("getBuf")
+ return make([]byte, size)
+ },
+ put: func(p []byte) {
+ addLog("putBuf-" + strconv.Itoa(len(p)))
+ },
+ }
+ frontend := httptest.NewServer(rp)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if string(slurp) != msg {
+ t.Errorf("msg = %q; want %q", slurp, msg)
+ }
+ wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
+ mu.Lock()
+ defer mu.Unlock()
+ if !reflect.DeepEqual(log, wantLog) {
+ t.Errorf("Log events = %q; want %q", log, wantLog)
+ }
+}
+
+func TestReverseProxy_Post(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 200
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Backend body read = %v", err)
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ }
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
+ res, err := frontend.Client().Do(postReq)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+type RoundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+// Issue 16036: send a Request with a nil Body when possible
+func TestReverseProxy_NilBody(t *testing.T) {
+ backendURL, _ := url.Parse("http://fake.tld/")
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Body != nil {
+ t.Error("Body != nil; want a nil Body")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 502 {
+ t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
+ }
+}
+
+// Issue 33142: always allocate the request headers
+func TestReverseProxy_AllocatedHeader(t *testing.T) {
+ proxyHandler := new(ReverseProxy)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Director = func(*http.Request) {} // noop
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Header == nil {
+ t.Error("Header == nil; want a non-nil Header")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+
+ proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
+ Method: "GET",
+ URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ })
+}
+
+// Issue 14237. Test ModifyResponse and that an error from it
+// causes the proxy to return StatusBadGateway, or StatusOK otherwise.
+func TestReverseProxyModifyResponse(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
+ }))
+ defer backendServer.Close()
+
+ rpURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(resp *http.Response) error {
+ if resp.Header.Get("X-Hit-Mod") != "true" {
+ return fmt.Errorf("tried to by-pass proxy")
+ }
+ return nil
+ }
+
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ tests := []struct {
+ url string
+ wantCode int
+ }{
+ {frontendProxy.URL + "/mod", http.StatusOK},
+ {frontendProxy.URL + "/schedule", http.StatusBadGateway},
+ }
+
+ for i, tt := range tests {
+ resp, err := http.Get(tt.url)
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
+ }
+ resp.Body.Close()
+ }
+}
+
+type failingRoundTripper struct{}
+
+func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return nil, errors.New("some error")
+}
+
+type staticResponseRoundTripper struct{ res *http.Response }
+
+func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return rt.res, nil
+}
+
+func TestReverseProxyErrorHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ wantCode int
+ errorHandler func(http.ResponseWriter, *http.Request, error)
+ transport http.RoundTripper // defaults to failingRoundTripper
+ modifyResponse func(*http.Response) error
+ }{
+ {
+ name: "default",
+ wantCode: http.StatusBadGateway,
+ },
+ {
+ name: "errorhandler",
+ wantCode: http.StatusTeapot,
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ },
+ {
+ name: "modifyresponse_noerr",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return nil
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: 346,
+ },
+ {
+ name: "modifyresponse_err",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return errors.New("some error to trigger errorHandler")
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: http.StatusTeapot,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ target := &url.URL{
+ Scheme: "http",
+ Host: "dummy.tld",
+ Path: "/",
+ }
+ rproxy := NewSingleHostReverseProxy(target)
+ rproxy.Transport = tt.transport
+ rproxy.ModifyResponse = tt.modifyResponse
+ if rproxy.Transport == nil {
+ rproxy.Transport = failingRoundTripper{}
+ }
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ if tt.errorHandler != nil {
+ rproxy.ErrorHandler = tt.errorHandler
+ }
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL + "/test")
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ resp.Body.Close()
+ })
+ }
+}
+
+// Issue 16659: log errors from short read
+func TestReverseProxy_CopyBuffer(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var proxyLog bytes.Buffer
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
+ donec := make(chan bool, 1)
+ frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() { donec <- true }()
+ rproxy.ServeHTTP(w, r)
+ }))
+ defer frontendProxy.Close()
+
+ if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
+ t.Fatalf("want non-nil error")
+ }
+ // The race detector complains about the proxyLog usage in logf in copyBuffer
+ // and our usage below with proxyLog.Bytes() so we're explicitly using a
+ // channel to ensure that the ReverseProxy's ServeHTTP is done before we
+ // continue after Get.
+ <-donec
+
+ expected := []string{
+ "EOF",
+ "read",
+ }
+ for _, phrase := range expected {
+ if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
+ t.Errorf("expected log to contain phrase %q", phrase)
+ }
+ }
+}
+
+type staticTransport struct {
+ res *http.Response
+}
+
+func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+ return t.res, nil
+}
+
+func BenchmarkServeHTTP(b *testing.B) {
+ res := &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader("")),
+ }
+ proxy := &ReverseProxy{
+ Director: func(*http.Request) {},
+ Transport: &staticTransport{res},
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("GET", "/", nil)
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ proxy.ServeHTTP(w, r)
+ }
+}
+
+func TestServeHTTPDeepCopy(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("Hello Gopher!"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type result struct {
+ before, after string
+ }
+
+ resultChan := make(chan result, 1)
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ before := r.URL.String()
+ proxyHandler.ServeHTTP(w, r)
+ after := r.URL.String()
+ resultChan <- result{before: before, after: after}
+ }))
+ defer frontend.Close()
+
+ want := result{before: "/", after: "/"}
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ res.Body.Close()
+
+ got := <-resultChan
+ if got != want {
+ t.Errorf("got = %+v; want = %+v", got, want)
+ }
+}
+
+// Issue 18327: verify we always do a deep copy of the Request.Header map
+// before any mutations.
+func TestClonesRequestHeaders(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {
+ req.Header.Set("From-Director", "1")
+ },
+ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if v := req.Header.Get("From-Director"); v != "1" {
+ t.Errorf("From-Directory value = %q; want 1", v)
+ }
+ return nil, io.EOF
+ }),
+ }
+ rp.ServeHTTP(httptest.NewRecorder(), req)
+
+ for _, h := range []string{
+ "From-Director",
+ "X-Forwarded-For",
+ } {
+ if req.Header.Get(h) != "" {
+ t.Errorf("%v header mutation modified caller's request", h)
+ }
+ }
+}
+
+type roundTripperFunc func(req *http.Request) (*http.Response, error)
+
+func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+func TestModifyResponseClosesBody(t *testing.T) {
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ closeCheck := new(checkCloser)
+ logBuf := new(strings.Builder)
+ outErr := errors.New("ModifyResponse error")
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {},
+ Transport: &staticTransport{&http.Response{
+ StatusCode: 200,
+ Body: closeCheck,
+ }},
+ ErrorLog: log.New(logBuf, "", 0),
+ ModifyResponse: func(*http.Response) error {
+ return outErr
+ },
+ }
+ rec := httptest.NewRecorder()
+ rp.ServeHTTP(rec, req)
+ res := rec.Result()
+ if g, e := res.StatusCode, http.StatusBadGateway; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if !closeCheck.closed {
+ t.Errorf("body should have been closed")
+ }
+ if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
+ t.Errorf("ErrorLog %q does not contain %q", g, e)
+ }
+}
+
+type checkCloser struct {
+ closed bool
+}
+
+func (cc *checkCloser) Close() error {
+ cc.closed = true
+ return nil
+}
+
+func (cc *checkCloser) Read(b []byte) (int, error) {
+ return len(b), nil
+}
+
+// Issue 23643: panic on body copy error
+func TestReverseProxy_PanicBodyError(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rproxy := NewSingleHostReverseProxy(rpURL)
+
+ // Ensure that the handler panics when the body read encounters an
+ // io.ErrUnexpectedEOF
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatal("handler should have panicked")
+ }
+ if err != http.ErrAbortHandler {
+ t.Fatal("expected ErrAbortHandler, got", err)
+ }
+ }()
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ rproxy.ServeHTTP(httptest.NewRecorder(), req)
+}
+
+// Issue #46866: panicking without closing the incoming request body causes a further panic
+func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := frontendClient.Transport.RoundTrip(req)
+ if resp != nil {
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestSelectFlushInterval(t *testing.T) {
+ tests := []struct {
+ name string
+ p *ReverseProxy
+ res *http.Response
+ want time.Duration
+ }{
+ {
+ name: "default",
+ res: &http.Response{},
+ p: &ReverseProxy{FlushInterval: 123},
+ want: 123,
+ },
+ {
+ name: "server-sent events overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides non-zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.flushInterval(tt.res)
+ if got != tt.want {
+ t.Errorf("flushLatency = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestReverseProxyWebSocket(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if upgradeType(r.Header) != "websocket" {
+ t.Error("unexpected backend request")
+ http.Error(w, "unexpected request", 400)
+ return
+ }
+ c, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
+ bs := bufio.NewScanner(c)
+ if !bs.Scan() {
+ t.Errorf("backend failed to read line from client: %v", bs.Err())
+ return
+ }
+ fmt.Fprintf(c, "backend got %q\n", bs.Text())
+ }))
+ defer backendServer.Close()
+
+ backURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(backURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ rproxy.ServeHTTP(rw, req)
+ if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
+ t.Errorf("response writer X-Modified header = %q; want %q", got, want)
+ }
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ c := frontendProxy.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("status = %v; want 101", res.Status)
+ }
+
+ got := res.Header.Get("X-Header")
+ want := "X-Value"
+ if got != want {
+ t.Errorf("Header(XHeader) = %q; want %q", got, want)
+ }
+
+ if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
+ t.Fatalf("not websocket upgrade; got %#v", res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
+ }
+ defer rwc.Close()
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ io.WriteString(rwc, "Hello\n")
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("Scan: %v", bs.Err())
+ }
+ got = bs.Text()
+ want = `backend got "Hello"`
+ if got != want {
+ t.Errorf("got %#q, want %#q", got, want)
+ }
+}
+
+func TestReverseProxyWebSocketCancellation(t *testing.T) {
+ n := 5
+ triggerCancelCh := make(chan bool, n)
+ nthResponse := func(i int) string {
+ return fmt.Sprintf("backend response #%d\n", i)
+ }
+ terminalMsg := "final message"
+
+ cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+ t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
+ http.Error(w, "Unexpected request", 400)
+ return
+ }
+ conn, bufrw, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+ if _, err := io.WriteString(conn, upgradeMsg); err != nil {
+ t.Error(err)
+ return
+ }
+ if _, _, err := bufrw.ReadLine(); err != nil {
+ t.Errorf("Failed to read line from client: %v", err)
+ return
+ }
+
+ for i := 0; i < n; i++ {
+ if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Writing response #%d failed: %v", i, err)
+ }
+ return
+ }
+ bufrw.Flush()
+ time.Sleep(time.Second)
+ }
+ if _, err := bufrw.WriteString(terminalMsg); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Failed to write terminal message: %v", err)
+ }
+ }
+ bufrw.Flush()
+ }))
+ defer cst.Close()
+
+ backendURL, _ := url.Parse(cst.URL)
+ rproxy := NewSingleHostReverseProxy(backendURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ ctx, cancel := context.WithCancel(req.Context())
+ go func() {
+ <-triggerCancelCh
+ cancel()
+ }()
+ rproxy.ServeHTTP(rw, req.WithContext(ctx))
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ res, err := frontendProxy.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Dialing to frontend proxy: %v", err)
+ }
+ defer res.Body.Close()
+ if g, w := res.StatusCode, 101; g != w {
+ t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
+ }
+
+ if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
+ t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
+ t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
+ }
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
+ t.Fatalf("Failed to write first message: %v", err)
+ }
+
+ // Read loop.
+
+ br := bufio.NewReader(rwc)
+ for {
+ line, err := br.ReadString('\n')
+ switch {
+ case line == terminalMsg: // this case before "err == io.EOF"
+ t.Fatalf("The websocket request was not canceled, unfortunately!")
+
+ case err == io.EOF:
+ return
+
+ case err != nil:
+ t.Fatalf("Unexpected error: %v", err)
+
+ case line == nthResponse(0): // We've gotten the first response back
+ // Let's trigger a cancel.
+ close(triggerCancelCh)
+ }
+ }
+}
+
+func TestUnannouncedTrailer(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.(http.Flusher).Flush()
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ res, err := frontendClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ io.ReadAll(res.Body)
+
+ if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
+ }
+
+}
+
+func TestSetURL(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.Host))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(backendURL)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ res, err := frontendClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Reading body: %v", err)
+ }
+
+ if got, want := string(body), backendURL.Host; got != want {
+ t.Errorf("backend got Host %q, want %q", got, want)
+ }
+}
+
+func TestSingleJoinSlash(t *testing.T) {
+ tests := []struct {
+ slasha string
+ slashb string
+ expected string
+ }{
+ {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "", "https://www.google.com/"},
+ {"", "favicon.ico", "/favicon.ico"},
+ }
+ for _, tt := range tests {
+ if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
+ t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
+ tt.slasha,
+ tt.slashb,
+ tt.expected,
+ got)
+ }
+ }
+}
+
+func TestJoinURLPath(t *testing.T) {
+ tests := []struct {
+ a *url.URL
+ b *url.URL
+ wantPath string
+ wantRaw string
+ }{
+ {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
+ {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
+ }
+
+ for _, tt := range tests {
+ p, rp := joinURLPath(tt.a, tt.b)
+ if p != tt.wantPath || rp != tt.wantRaw {
+ t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
+ tt.a.Path, tt.a.RawPath,
+ tt.b.Path, tt.b.RawPath,
+ tt.wantPath, tt.wantRaw,
+ p, rp)
+ }
+ }
+}
+
+func TestReverseProxyRewriteReplacesOut(t *testing.T) {
+ const content = "response_content"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(content))
+ }))
+ defer backend.Close()
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.Out, _ = http.NewRequest("GET", backend.URL, nil)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ if got, want := string(body), content; got != want {
+ t.Errorf("got response %q, want %q", got, want)
+ }
+}
+
+func Test1xxResponses(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusEarlyHints)
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusProcessing)
+
+ w.Write([]byte("Hello"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ checkLinkHeaders := func(t *testing.T, expected, got []string) {
+ t.Helper()
+
+ if len(expected) != len(got) {
+ t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
+ }
+
+ for i := range expected {
+ if i >= len(got) {
+ t.Errorf("Expected %q link header; got nothing", expected[i])
+
+ continue
+ }
+
+ if expected[i] != got[i] {
+ t.Errorf("Expected %q link header; got %q", expected[i], got[i])
+ }
+ }
+ }
+
+ var respCounter uint8
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ switch code {
+ case http.StatusEarlyHints:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
+ case http.StatusProcessing:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
+ default:
+ t.Error("Unexpected 1xx response")
+ }
+
+ respCounter++
+
+ return nil
+ },
+ }
+ req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
+
+ res, err := frontendClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ defer res.Body.Close()
+
+ if respCounter != 2 {
+ t.Errorf("Expected 2 1xx responses; got %d", respCounter)
+ }
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
+
+ body, _ := io.ReadAll(res.Body)
+ if string(body) != "Hello" {
+ t.Errorf("Read body %q; want Hello", body)
+ }
+}
+
+const (
+ testWantsCleanQuery = true
+ testWantsRawQuery = false
+)
+
+func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ // Parsing the form causes ReverseProxy to remove unparsable
+ // query parameters before forwarding.
+ r.FormValue("a")
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
+ return &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(u)
+ },
+ }
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
+ return &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(u)
+ r.Out.URL.RawQuery = r.In.URL.RawQuery
+ },
+ }
+ })
+}
+
+func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
+ const content = "response_content"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.URL.RawQuery))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := newProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ // Don't spam output with logs of queries containing semicolons.
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+ frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ for _, test := range []struct {
+ rawQuery string
+ cleanQuery string
+ }{{
+ rawQuery: "a=1&a=2;b=3",
+ cleanQuery: "a=1",
+ }, {
+ rawQuery: "a=1&a=%zz&b=3",
+ cleanQuery: "a=1&b=3",
+ }} {
+ res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ wantQuery := test.rawQuery
+ if wantCleanQuery {
+ wantQuery = test.cleanQuery
+ }
+ if got, want := string(body), wantQuery; got != want {
+ t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
+ }
+ }
+}
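The tests above pin down how ReverseProxy treats unparsable query parameters: a plain Rewrite with SetURL forwards a cleaned query, while copying RawQuery opts back into byte-for-byte forwarding. A minimal sketch of that opt-in from an external caller's point of view follows; the backend and listen addresses are placeholders, not part of this patch.

	package main

	import (
		"log"
		"net/http"
		"net/http/httputil"
		"net/url"
	)

	func main() {
		// Placeholder backend; in the tests above this is an httptest.Server.
		backendURL, err := url.Parse("http://127.0.0.1:8080")
		if err != nil {
			log.Fatal(err)
		}

		proxy := &httputil.ReverseProxy{
			Rewrite: func(r *httputil.ProxyRequest) {
				r.SetURL(backendURL)
				// By default the outbound query has already been cleaned of
				// unparsable parameters before Rewrite runs; copying RawQuery
				// from the inbound request forwards it byte-for-byte instead,
				// as the RewritePreservesRawQuery test above does. Only do this
				// when the backend is known to parse ambiguous queries safely.
				r.Out.URL.RawQuery = r.In.URL.RawQuery
			},
		}

		// The listen address is likewise illustrative.
		log.Fatal(http.ListenAndServe("127.0.0.1:8081", proxy))
	}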
diff --git a/src/net/http/internal/ascii/print.go b/src/net/http/internal/ascii/print.go
new file mode 100644
index 0000000..585e5ba
--- /dev/null
+++ b/src/net/http/internal/ascii/print.go
@@ -0,0 +1,61 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ascii
+
+import (
+ "strings"
+ "unicode"
+)
+
+// EqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func EqualFold(s, t string) bool {
+ if len(s) != len(t) {
+ return false
+ }
+ for i := 0; i < len(s); i++ {
+ if lower(s[i]) != lower(t[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// lower returns the ASCII lowercase version of b.
+func lower(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+}
+
+// IsPrint returns whether s is ASCII and printable according to
+// https://tools.ietf.org/html/rfc20#section-4.2.
+func IsPrint(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] < ' ' || s[i] > '~' {
+ return false
+ }
+ }
+ return true
+}
+
+// Is returns whether s is ASCII.
+func Is(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] > unicode.MaxASCII {
+ return false
+ }
+ }
+ return true
+}
+
+// ToLower returns the lowercase version of s if s is ASCII and printable.
+func ToLower(s string) (lower string, ok bool) {
+ if !IsPrint(s) {
+ return "", false
+ }
+ return strings.ToLower(s), true
+}
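The point of this ASCII-only fold is that HTTP tokens must not be matched with full Unicode case folding, which strings.EqualFold performs. A small sketch of the difference follows; because the package is internal and not importable from user code, it carries a local copy of the logic above.

	package main

	import (
		"fmt"
		"strings"
	)

	// asciiEqualFold mirrors the internal ascii.EqualFold above:
	// byte-wise comparison with ASCII-only lowercasing.
	func asciiEqualFold(s, t string) bool {
		if len(s) != len(t) {
			return false
		}
		for i := 0; i < len(s); i++ {
			if lower(s[i]) != lower(t[i]) {
				return false
			}
		}
		return true
	}

	// lower returns the ASCII lowercase version of b.
	func lower(b byte) byte {
		if 'A' <= b && b <= 'Z' {
			return b + ('a' - 'A')
		}
		return b
	}

	func main() {
		a := "chun\u212Aed" // KELVIN SIGN, not the ASCII letter K
		b := "chunked"
		fmt.Println(strings.EqualFold(a, b)) // true: Unicode folding maps U+212A to 'k'
		fmt.Println(asciiEqualFold(a, b))    // false: not a byte-for-byte ASCII match
	}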
diff --git a/src/net/http/internal/ascii/print_test.go b/src/net/http/internal/ascii/print_test.go
new file mode 100644
index 0000000..0b7767c
--- /dev/null
+++ b/src/net/http/internal/ascii/print_test.go
@@ -0,0 +1,95 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ascii
+
+import "testing"
+
+func TestEqualFold(t *testing.T) {
+ var tests = []struct {
+ name string
+ a, b string
+ want bool
+ }{
+ {
+ name: "empty",
+ want: true,
+ },
+ {
+ name: "simple match",
+ a: "CHUNKED",
+ b: "chunked",
+ want: true,
+ },
+ {
+ name: "same string",
+ a: "chunked",
+ b: "chunked",
+ want: true,
+ },
+ {
+ name: "Unicode Kelvin symbol",
+ a: "chunKed", // This "K" is 'KELVIN SIGN' (\u212A)
+ b: "chunked",
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := EqualFold(tt.a, tt.b); got != tt.want {
+ t.Errorf("AsciiEqualFold(%q,%q): got %v want %v", tt.a, tt.b, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsPrint(t *testing.T) {
+ var tests = []struct {
+ name string
+ in string
+ want bool
+ }{
+ {
+ name: "empty",
+ want: true,
+ },
+ {
+ name: "ASCII low",
+ in: "This is a space: ' '",
+ want: true,
+ },
+ {
+ name: "ASCII high",
+ in: "This is a tilde: '~'",
+ want: true,
+ },
+ {
+ name: "ASCII low non-print",
+ in: "This is a unit separator: \x1F",
+ want: false,
+ },
+ {
+ name: "Ascii high non-print",
+ in: "This is a Delete: \x7F",
+ want: false,
+ },
+ {
+ name: "Unicode letter",
+ in: "Today it's 280K outside: it's freezing!", // This "K" is 'KELVIN SIGN' (\u212A)
+ want: false,
+ },
+ {
+ name: "Unicode emoji",
+ in: "Gophers like 🧀",
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := IsPrint(tt.in); got != tt.want {
+ t.Errorf("IsASCIIPrint(%q): got %v want %v", tt.in, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/src/net/http/internal/chunked.go b/src/net/http/internal/chunked.go
new file mode 100644
index 0000000..aad8e5a
--- /dev/null
+++ b/src/net/http/internal/chunked.go
@@ -0,0 +1,284 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The wire protocol for HTTP's "chunked" Transfer-Encoding.
+
+// Package internal contains HTTP internals shared by net/http and
+// net/http/httputil.
+package internal
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+)
+
+const maxLineLength = 4096 // assumed <= bufio.defaultBufSize
+
+var ErrLineTooLong = errors.New("header line too long")
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ br, ok := r.(*bufio.Reader)
+ if !ok {
+ br = bufio.NewReader(r)
+ }
+ return &chunkedReader{r: br}
+}
+
+type chunkedReader struct {
+ r *bufio.Reader
+ n uint64 // unread bytes in chunk
+ err error
+ buf [2]byte
+ checkEnd bool // whether need to check for \r\n chunk footer
+ excess int64 // "excessive" chunk overhead, for malicious sender detection
+}
+
+func (cr *chunkedReader) beginChunk() {
+ // chunk-size CRLF
+ var line []byte
+ line, cr.err = readChunkLine(cr.r)
+ if cr.err != nil {
+ return
+ }
+ cr.excess += int64(len(line)) + 2 // header, plus \r\n after the chunk data
+ line = trimTrailingWhitespace(line)
+ line, cr.err = removeChunkExtension(line)
+ if cr.err != nil {
+ return
+ }
+ cr.n, cr.err = parseHexUint(line)
+ if cr.err != nil {
+ return
+ }
+ // A sender who sends one byte per chunk will send 5 bytes of overhead
+ // for every byte of data. ("1\r\nX\r\n" to send "X".)
+ // We want to allow this, since streaming a byte at a time can be legitimate.
+ //
+ // A sender can use chunk extensions to add arbitrary amounts of additional
+ // data per byte read. ("1;very long extension\r\nX\r\n" to send "X".)
+ // We don't want to disallow extensions (although we discard them),
+ // but we also don't want to allow a sender to reduce the signal/noise ratio
+ // arbitrarily.
+ //
+ // We track the amount of excess overhead read,
+ // and produce an error if it grows too large.
+ //
+ // Currently, we say that we're willing to accept 16 bytes of overhead per chunk,
+ // plus twice the amount of real data in the chunk.
+ cr.excess -= 16 + (2 * int64(cr.n))
+ cr.excess = max(cr.excess, 0)
+ if cr.excess > 16*1024 {
+ cr.err = errors.New("chunked encoding contains too much non-data")
+ }
+ if cr.n == 0 {
+ cr.err = io.EOF
+ }
+}
+
+func (cr *chunkedReader) chunkHeaderAvailable() bool {
+ n := cr.r.Buffered()
+ if n > 0 {
+ peek, _ := cr.r.Peek(n)
+ return bytes.IndexByte(peek, '\n') >= 0
+ }
+ return false
+}
+
+func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
+ for cr.err == nil {
+ if cr.checkEnd {
+ if n > 0 && cr.r.Buffered() < 2 {
+ // We have some data. Return early (per the io.Reader
+ // contract) instead of potentially blocking while
+ // reading more.
+ break
+ }
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
+ if string(cr.buf[:]) != "\r\n" {
+ cr.err = errors.New("malformed chunked encoding")
+ break
+ }
+ } else {
+ if cr.err == io.EOF {
+ cr.err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+ cr.checkEnd = false
+ }
+ if cr.n == 0 {
+ if n > 0 && !cr.chunkHeaderAvailable() {
+ // We've read enough. Don't potentially block
+ // reading a new chunk header.
+ break
+ }
+ cr.beginChunk()
+ continue
+ }
+ if len(b) == 0 {
+ break
+ }
+ rbuf := b
+ if uint64(len(rbuf)) > cr.n {
+ rbuf = rbuf[:cr.n]
+ }
+ var n0 int
+ n0, cr.err = cr.r.Read(rbuf)
+ n += n0
+ b = b[n0:]
+ cr.n -= uint64(n0)
+ // If we're at the end of a chunk, read the next two
+ // bytes to verify they are "\r\n".
+ if cr.n == 0 && cr.err == nil {
+ cr.checkEnd = true
+ } else if cr.err == io.EOF {
+ cr.err = io.ErrUnexpectedEOF
+ }
+ }
+ return n, cr.err
+}
+
+// Read a line of bytes (up to \n) from b.
+// Give up if the line exceeds maxLineLength.
+// The returned bytes are owned by the bufio.Reader
+// so they are only valid until the next bufio read.
+func readChunkLine(b *bufio.Reader) ([]byte, error) {
+ p, err := b.ReadSlice('\n')
+ if err != nil {
+ // We always know when EOF is coming.
+ // If the caller asked for a line, there should be a line.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ } else if err == bufio.ErrBufferFull {
+ err = ErrLineTooLong
+ }
+ return nil, err
+ }
+ if len(p) >= maxLineLength {
+ return nil, ErrLineTooLong
+ }
+ return p, nil
+}
+
+func trimTrailingWhitespace(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
+ }
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+var semi = []byte(";")
+
+// removeChunkExtension removes any chunk-extension from p.
+// For example,
+//
+// "0" => "0"
+// "0;token" => "0"
+// "0;token=val" => "0"
+// `0;token="quoted string"` => "0"
+func removeChunkExtension(p []byte) ([]byte, error) {
+ p, _, _ = bytes.Cut(p, semi)
+ // TODO: care about exact syntax of chunk extensions? We're
+ // ignoring and stripping them anyway. For now just never
+ // return an error.
+ return p, nil
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream but does
+// not send the final CRLF that appears after trailers; trailers and the last
+// CRLF must be written separately.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using newChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return &chunkedWriter{w}
+}
+
+// Writing to chunkedWriter translates to writing in HTTP chunked Transfer
+// Encoding wire format to the underlying Wire writer.
+type chunkedWriter struct {
+ Wire io.Writer
+}
+
+// Write the contents of data as one chunk to Wire.
+// NOTE: the corresponding chunk-writing procedure in Conn.Write has a bug
+// since it does not check for the success of io.WriteString.
+func (cw *chunkedWriter) Write(data []byte) (n int, err error) {
+
+ // Don't send 0-length data. It looks like EOF for chunked encoding.
+ if len(data) == 0 {
+ return 0, nil
+ }
+
+ if _, err = fmt.Fprintf(cw.Wire, "%x\r\n", len(data)); err != nil {
+ return 0, err
+ }
+ if n, err = cw.Wire.Write(data); err != nil {
+ return
+ }
+ if n != len(data) {
+ err = io.ErrShortWrite
+ return
+ }
+ if _, err = io.WriteString(cw.Wire, "\r\n"); err != nil {
+ return
+ }
+ if bw, ok := cw.Wire.(*FlushAfterChunkWriter); ok {
+ err = bw.Flush()
+ }
+ return
+}
+
+func (cw *chunkedWriter) Close() error {
+ _, err := io.WriteString(cw.Wire, "0\r\n")
+ return err
+}
+
+// FlushAfterChunkWriter signals from the caller of NewChunkedWriter
+// that each chunk should be followed by a flush. It is used by the
+// http.Transport code to keep the buffering behavior for headers and
+// trailers, but flush out chunks aggressively in the middle for
+// request bodies which may be generated slowly. See Issue 6574.
+type FlushAfterChunkWriter struct {
+ *bufio.Writer
+}
+
+func parseHexUint(v []byte) (n uint64, err error) {
+ for i, b := range v {
+ switch {
+ case '0' <= b && b <= '9':
+ b = b - '0'
+ case 'a' <= b && b <= 'f':
+ b = b - 'a' + 10
+ case 'A' <= b && b <= 'F':
+ b = b - 'A' + 10
+ default:
+ return 0, errors.New("invalid byte in chunk length")
+ }
+ if i == 16 {
+ return 0, errors.New("http chunk length too large")
+ }
+ n <<= 4
+ n |= uint64(b)
+ }
+ return
+}
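The reader and writer above are re-exported for users as NewChunkedReader and NewChunkedWriter in net/http/httputil; a minimal round trip through the wire format might look like the sketch below (outputs noted in comments; normal clients and servers never need these, the sketch only illustrates the framing).

	package main

	import (
		"bytes"
		"fmt"
		"io"
		"net/http/httputil"
	)

	func main() {
		var wire bytes.Buffer

		// Encode two writes as HTTP/1.1 chunks, then the terminating 0-length chunk.
		w := httputil.NewChunkedWriter(&wire)
		io.WriteString(w, "hello, ")
		io.WriteString(w, "world")
		w.Close()

		// Each chunk is "<hex length>\r\n<data>\r\n"; Close writes "0\r\n".
		fmt.Printf("%q\n", wire.String()) // "7\r\nhello, \r\n5\r\nworld\r\n0\r\n"

		// Decode it back; the reader reports io.EOF at the 0-length chunk.
		body, err := io.ReadAll(httputil.NewChunkedReader(&wire))
		if err != nil {
			panic(err)
		}
		fmt.Println(string(body)) // hello, world
	}

On the overhead heuristic in beginChunk: a chunk such as "1;" followed by a 100-byte extension spends roughly 106 bytes of framing to carry one data byte, and after the 16+2n allowance the excess grows by about 88 bytes per chunk, so the 16 KiB cap rejects the stream after a couple hundred such chunks, which is the behavior TestChunkReaderTooMuchOverhead below exercises.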
diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go
new file mode 100644
index 0000000..b99090c
--- /dev/null
+++ b/src/net/http/internal/chunked_test.go
@@ -0,0 +1,300 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package internal
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "strings"
+ "testing"
+ "testing/iotest"
+)
+
+func TestChunk(t *testing.T) {
+ var b bytes.Buffer
+
+ w := NewChunkedWriter(&b)
+ const chunk1 = "hello, "
+ const chunk2 = "world! 0123456789abcdef"
+ w.Write([]byte(chunk1))
+ w.Write([]byte(chunk2))
+ w.Close()
+
+ if g, e := b.String(), "7\r\nhello, \r\n17\r\nworld! 0123456789abcdef\r\n0\r\n"; g != e {
+ t.Fatalf("chunk writer wrote %q; want %q", g, e)
+ }
+
+ r := NewChunkedReader(&b)
+ data, err := io.ReadAll(r)
+ if err != nil {
+ t.Logf(`data: "%s"`, data)
+ t.Fatalf("ReadAll from reader: %v", err)
+ }
+ if g, e := string(data), chunk1+chunk2; g != e {
+ t.Errorf("chunk reader read %q; want %q", g, e)
+ }
+}
+
+func TestChunkReadMultiple(t *testing.T) {
+ // Bunch of small chunks, all read together.
+ {
+ var b bytes.Buffer
+ w := NewChunkedWriter(&b)
+ w.Write([]byte("foo"))
+ w.Write([]byte("bar"))
+ w.Close()
+
+ r := NewChunkedReader(&b)
+ buf := make([]byte, 10)
+ n, err := r.Read(buf)
+ if n != 6 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 6, EOF", n, err)
+ }
+ buf = buf[:n]
+ if string(buf) != "foobar" {
+ t.Errorf("Read = %q; want %q", buf, "foobar")
+ }
+ }
+
+ // One big chunk followed by a little chunk, but the small bufio.Reader size
+ // should prevent the second chunk header from being read.
+ {
+ var b bytes.Buffer
+ w := NewChunkedWriter(&b)
+ // fillBufChunk is 11 bytes + 3 bytes header + 2 bytes footer = 16 bytes,
+ // the same as the bufio ReaderSize below (the minimum), so even
+ // though we're going to try to Read with a buffer large enough to also
+ // receive "foo", the second chunk header won't be read yet.
+ const fillBufChunk = "0123456789a"
+ const shortChunk = "foo"
+ w.Write([]byte(fillBufChunk))
+ w.Write([]byte(shortChunk))
+ w.Close()
+
+ r := NewChunkedReader(bufio.NewReaderSize(&b, 16))
+ buf := make([]byte, len(fillBufChunk)+len(shortChunk))
+ n, err := r.Read(buf)
+ if n != len(fillBufChunk) || err != nil {
+ t.Errorf("Read = %d, %v; want %d, nil", n, err, len(fillBufChunk))
+ }
+ buf = buf[:n]
+ if string(buf) != fillBufChunk {
+ t.Errorf("Read = %q; want %q", buf, fillBufChunk)
+ }
+
+ n, err = r.Read(buf)
+ if n != len(shortChunk) || err != io.EOF {
+ t.Errorf("Read = %d, %v; want %d, EOF", n, err, len(shortChunk))
+ }
+ }
+
+ // And test that we see an EOF chunk, even though our buffer is already full:
+ {
+ r := NewChunkedReader(bufio.NewReader(strings.NewReader("3\r\nfoo\r\n0\r\n")))
+ buf := make([]byte, 3)
+ n, err := r.Read(buf)
+ if n != 3 || err != io.EOF {
+ t.Errorf("Read = %d, %v; want 3, EOF", n, err)
+ }
+ if string(buf) != "foo" {
+ t.Errorf("buf = %q; want foo", buf)
+ }
+ }
+}
+
+func TestChunkReaderAllocs(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ var buf bytes.Buffer
+ w := NewChunkedWriter(&buf)
+ a, b, c := []byte("aaaaaa"), []byte("bbbbbbbbbbbb"), []byte("cccccccccccccccccccccccc")
+ w.Write(a)
+ w.Write(b)
+ w.Write(c)
+ w.Close()
+
+ readBuf := make([]byte, len(a)+len(b)+len(c)+1)
+ byter := bytes.NewReader(buf.Bytes())
+ bufr := bufio.NewReader(byter)
+ mallocs := testing.AllocsPerRun(100, func() {
+ byter.Seek(0, io.SeekStart)
+ bufr.Reset(byter)
+ r := NewChunkedReader(bufr)
+ n, err := io.ReadFull(r, readBuf)
+ if n != len(readBuf)-1 {
+ t.Fatalf("read %d bytes; want %d", n, len(readBuf)-1)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Fatalf("read error = %v; want ErrUnexpectedEOF", err)
+ }
+ })
+ if mallocs > 1.5 {
+ t.Errorf("mallocs = %v; want 1", mallocs)
+ }
+}
+
+func TestParseHexUint(t *testing.T) {
+ type testCase struct {
+ in string
+ want uint64
+ wantErr string
+ }
+ tests := []testCase{
+ {"x", 0, "invalid byte in chunk length"},
+ {"0000000000000000", 0, ""},
+ {"0000000000000001", 1, ""},
+ {"ffffffffffffffff", 1<<64 - 1, ""},
+ {"000000000000bogus", 0, "invalid byte in chunk length"},
+ {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted
+ {"10000000000000000", 0, "http chunk length too large"},
+ {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted
+ }
+ for i := uint64(0); i <= 1234; i++ {
+ tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i})
+ }
+ for _, tt := range tests {
+ got, err := parseHexUint([]byte(tt.in))
+ if tt.wantErr != "" {
+ if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
+ t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr)
+ }
+ } else {
+ if err != nil || got != tt.want {
+ t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want)
+ }
+ }
+ }
+}
+
+func TestChunkReadingIgnoresExtensions(t *testing.T) {
+ in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string
+ "hello, \r\n" +
+ "17;someext\r\n" + // token without value
+ "world! 0123456789abcdef\r\n" +
+ "0;someextension=sometoken\r\n" // token=token
+ data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in)))
+ if err != nil {
+ t.Fatalf("ReadAll = %q, %v", data, err)
+ }
+ if g, e := string(data), "hello, world! 0123456789abcdef"; g != e {
+ t.Errorf("read %q; want %q", g, e)
+ }
+}
+
+// Issue 17355: ChunkedReader shouldn't block waiting for more data
+// if it can return something.
+func TestChunkReadPartial(t *testing.T) {
+ pr, pw := io.Pipe()
+ go func() {
+ pw.Write([]byte("7\r\n1234567"))
+ }()
+ cr := NewChunkedReader(pr)
+ readBuf := make([]byte, 7)
+ n, err := cr.Read(readBuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "1234567"
+ if n != 7 || string(readBuf) != want {
+ t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want)
+ }
+ go func() {
+ pw.Write([]byte("xx"))
+ }()
+ _, err = cr.Read(readBuf)
+ if got := fmt.Sprint(err); !strings.Contains(got, "malformed") {
+ t.Fatalf("second read = %v; want malformed error", err)
+ }
+
+}
+
+// Issue 48861: ChunkedReader should report incomplete chunks
+func TestIncompleteChunk(t *testing.T) {
+ const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n"
+
+ for i := 0; i < len(valid); i++ {
+ incomplete := valid[:i]
+ r := NewChunkedReader(strings.NewReader(incomplete))
+ if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF {
+ t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err)
+ }
+ }
+
+ r := NewChunkedReader(strings.NewReader(valid))
+ if _, err := io.ReadAll(r); err != nil {
+ t.Errorf("unexpected error for %q: %v", valid, err)
+ }
+}
+
+func TestChunkEndReadError(t *testing.T) {
+ readErr := fmt.Errorf("chunk end read error")
+
+ r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr)))
+ if _, err := io.ReadAll(r); err != readErr {
+ t.Errorf("expected %v, got %v", readErr, err)
+ }
+}
+
+func TestChunkReaderTooMuchOverhead(t *testing.T) {
+ // If the sender is sending 100x as many chunk header bytes as chunk data,
+ // we should reject the stream at some point.
+ chunk := []byte("1;")
+ for i := 0; i < 100; i++ {
+ chunk = append(chunk, 'a') // chunk extension
+ }
+ chunk = append(chunk, "\r\nX\r\n"...)
+ const bodylen = 1 << 20
+ r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) {
+ if i < bodylen {
+ return chunk, nil
+ }
+ return []byte("0\r\n"), nil
+ }})
+ _, err := io.ReadAll(r)
+ if err == nil {
+ t.Fatalf("successfully read body with excessive overhead; want error")
+ }
+}
+
+func TestChunkReaderByteAtATime(t *testing.T) {
+ // Sending one byte per chunk should not trip the excess-overhead detection.
+ const bodylen = 1 << 20
+ r := NewChunkedReader(&funcReader{f: func(i int) ([]byte, error) {
+ if i < bodylen {
+ return []byte("1\r\nX\r\n"), nil
+ }
+ return []byte("0\r\n"), nil
+ }})
+ got, err := io.ReadAll(r)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if len(got) != bodylen {
+ t.Errorf("read %v bytes, want %v", len(got), bodylen)
+ }
+}
+
+type funcReader struct {
+ f func(iteration int) ([]byte, error)
+ i int
+ b []byte
+ err error
+}
+
+func (r *funcReader) Read(p []byte) (n int, err error) {
+ if len(r.b) == 0 && r.err == nil {
+ r.b, r.err = r.f(r.i)
+ r.i++
+ }
+ n = copy(p, r.b)
+ r.b = r.b[n:]
+ if len(r.b) > 0 {
+ return n, nil
+ }
+ return n, r.err
+}
diff --git a/src/net/http/internal/testcert/testcert.go b/src/net/http/internal/testcert/testcert.go
new file mode 100644
index 0000000..d510e79
--- /dev/null
+++ b/src/net/http/internal/testcert/testcert.go
@@ -0,0 +1,65 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package testcert contains a test-only localhost certificate.
+package testcert
+
+import "strings"
+
+// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
+// generated from src/crypto/tls:
+// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var LocalhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIDOTCCAiGgAwIBAgIQSRJrEpBGFc7tNb1fb5pKFzANBgkqhkiG9w0BAQsFADAS
+MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
+MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
+MIIBCgKCAQEA6Gba5tHV1dAKouAaXO3/ebDUU4rvwCUg/CNaJ2PT5xLD4N1Vcb8r
+bFSW2HXKq+MPfVdwIKR/1DczEoAGf/JWQTW7EgzlXrCd3rlajEX2D73faWJekD0U
+aUgz5vtrTXZ90BQL7WvRICd7FlEZ6FPOcPlumiyNmzUqtwGhO+9ad1W5BqJaRI6P
+YfouNkwR6Na4TzSj5BrqUfP0FwDizKSJ0XXmh8g8G9mtwxOSN3Ru1QFc61Xyeluk
+POGKBV/q6RBNklTNe0gI8usUMlYyoC7ytppNMW7X2vodAelSu25jgx2anj9fDVZu
+h7AXF5+4nJS4AAt0n1lNY7nGSsdZas8PbQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE
+AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud
+DgQWBBStsdjh3/JCXXYlQryOrL4Sh7BW5TAuBgNVHREEJzAlggtleGFtcGxlLmNv
+bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAxWGI
+5NhpF3nwwy/4yB4i/CwwSpLrWUa70NyhvprUBC50PxiXav1TeDzwzLx/o5HyNwsv
+cxv3HdkLW59i/0SlJSrNnWdfZ19oTcS+6PtLoVyISgtyN6DpkKpdG1cOkW3Cy2P2
++tK/tKHRP1Y/Ra0RiDpOAmqn0gCOFGz8+lqDIor/T7MTpibL3IxqWfPrvfVRHL3B
+grw/ZQTTIVjjh4JBSW3WyWgNo/ikC1lrVxzl4iPUGptxT36Cr7Zk2Bsg0XqwbOvK
+5d+NTDREkSnUbie4GeutujmX3Dsx88UiV6UY/4lHJa6I5leHUNOHahRbpbWeOfs/
+WkBKOclmOV2xlTVuPw==
+-----END CERTIFICATE-----`)
+
+// LocalhostKey is the private key for LocalhostCert.
+var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY-----
+MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDoZtrm0dXV0Aqi
+4Bpc7f95sNRTiu/AJSD8I1onY9PnEsPg3VVxvytsVJbYdcqr4w99V3AgpH/UNzMS
+gAZ/8lZBNbsSDOVesJ3euVqMRfYPvd9pYl6QPRRpSDPm+2tNdn3QFAvta9EgJ3sW
+URnoU85w+W6aLI2bNSq3AaE771p3VbkGolpEjo9h+i42TBHo1rhPNKPkGupR8/QX
+AOLMpInRdeaHyDwb2a3DE5I3dG7VAVzrVfJ6W6Q84YoFX+rpEE2SVM17SAjy6xQy
+VjKgLvK2mk0xbtfa+h0B6VK7bmODHZqeP18NVm6HsBcXn7iclLgAC3SfWU1jucZK
+x1lqzw9tAgMBAAECggEABWzxS1Y2wckblnXY57Z+sl6YdmLV+gxj2r8Qib7g4ZIk
+lIlWR1OJNfw7kU4eryib4fc6nOh6O4AWZyYqAK6tqNQSS/eVG0LQTLTTEldHyVJL
+dvBe+MsUQOj4nTndZW+QvFzbcm2D8lY5n2nBSxU5ypVoKZ1EqQzytFcLZpTN7d89
+EPj0qDyrV4NZlWAwL1AygCwnlwhMQjXEalVF1ylXwU3QzyZ/6MgvF6d3SSUlh+sq
+XefuyigXw484cQQgbzopv6niMOmGP3of+yV4JQqUSb3IDmmT68XjGd2Dkxl4iPki
+6ZwXf3CCi+c+i/zVEcufgZ3SLf8D99kUGE7v7fZ6AQKBgQD1ZX3RAla9hIhxCf+O
+3D+I1j2LMrdjAh0ZKKqwMR4JnHX3mjQI6LwqIctPWTU8wYFECSh9klEclSdCa64s
+uI/GNpcqPXejd0cAAdqHEEeG5sHMDt0oFSurL4lyud0GtZvwlzLuwEweuDtvT9cJ
+Wfvl86uyO36IW8JdvUprYDctrQKBgQDycZ697qutBieZlGkHpnYWUAeImVA878sJ
+w44NuXHvMxBPz+lbJGAg8Cn8fcxNAPqHIraK+kx3po8cZGQywKHUWsxi23ozHoxo
++bGqeQb9U661TnfdDspIXia+xilZt3mm5BPzOUuRqlh4Y9SOBpSWRmEhyw76w4ZP
+OPxjWYAgwQKBgA/FehSYxeJgRjSdo+MWnK66tjHgDJE8bYpUZsP0JC4R9DL5oiaA
+brd2fI6Y+SbyeNBallObt8LSgzdtnEAbjIH8uDJqyOmknNePRvAvR6mP4xyuR+Bv
+m+Lgp0DMWTw5J9CKpydZDItc49T/mJ5tPhdFVd+am0NAQnmr1MCZ6nHxAoGABS3Y
+LkaC9FdFUUqSU8+Chkd/YbOkuyiENdkvl6t2e52jo5DVc1T7mLiIrRQi4SI8N9bN
+/3oJWCT+uaSLX2ouCtNFunblzWHBrhxnZzTeqVq4SLc8aESAnbslKL4i8/+vYZlN
+s8xtiNcSvL+lMsOBORSXzpj/4Ot8WwTkn1qyGgECgYBKNTypzAHeLE6yVadFp3nQ
+Ckq9yzvP/ib05rvgbvrne00YeOxqJ9gtTrzgh7koqJyX1L4NwdkEza4ilDWpucn0
+xiUZS4SoaJq6ZvcBYS62Yr1t8n09iG47YL8ibgtmH3L+svaotvpVxVK+d7BLevA/
+ZboOWVe3icTy64BT3OQhmg==
+-----END RSA TESTING KEY-----`))
+
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
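This certificate is consumed indirectly: net/http/httptest uses it for NewTLSServer, so clients obtained from such a server already trust it. A small usage sketch follows; the handler body and printed output are illustrative only.

	package main

	import (
		"fmt"
		"io"
		"net/http"
		"net/http/httptest"
	)

	func main() {
		// NewTLSServer serves over TLS using the localhost certificate above,
		// and ts.Client() returns a client whose transport already trusts it.
		ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintln(w, "hello over TLS")
		}))
		defer ts.Close()

		res, err := ts.Client().Get(ts.URL)
		if err != nil {
			panic(err)
		}
		defer res.Body.Close()
		body, _ := io.ReadAll(res.Body)
		fmt.Print(string(body)) // hello over TLS
	}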
diff --git a/src/net/http/jar.go b/src/net/http/jar.go
new file mode 100644
index 0000000..5c3de0d
--- /dev/null
+++ b/src/net/http/jar.go
@@ -0,0 +1,27 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+)
+
+// A CookieJar manages storage and use of cookies in HTTP requests.
+//
+// Implementations of CookieJar must be safe for concurrent use by multiple
+// goroutines.
+//
+// The net/http/cookiejar package provides a CookieJar implementation.
+type CookieJar interface {
+ // SetCookies handles the receipt of the cookies in a reply for the
+ // given URL. It may or may not choose to save the cookies, depending
+ // on the jar's policy and implementation.
+ SetCookies(u *url.URL, cookies []*Cookie)
+
+ // Cookies returns the cookies to send in a request for the given URL.
+ // It is up to the implementation to honor the standard cookie use
+ // restrictions such as in RFC 6265.
+ Cookies(u *url.URL) []*Cookie
+}
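The usual way to satisfy this interface is the net/http/cookiejar implementation named in the comment; a minimal sketch of wiring it into a client follows (the request URL is a placeholder).

	package main

	import (
		"fmt"
		"log"
		"net/http"
		"net/http/cookiejar"
	)

	func main() {
		// cookiejar.New returns the CookieJar implementation mentioned in the
		// interface comment; a nil Options means no public suffix list is used.
		jar, err := cookiejar.New(nil)
		if err != nil {
			log.Fatal(err)
		}
		client := &http.Client{Jar: jar}

		// With a Jar set, the client passes response cookies to SetCookies and
		// calls Cookies to attach matching ones to subsequent requests.
		resp, err := client.Get("https://example.com/")
		if err != nil {
			log.Fatal(err)
		}
		resp.Body.Close()
		fmt.Println(jar.Cookies(resp.Request.URL))
	}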
diff --git a/src/net/http/main_test.go b/src/net/http/main_test.go
new file mode 100644
index 0000000..ff56ef8
--- /dev/null
+++ b/src/net/http/main_test.go
@@ -0,0 +1,175 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "runtime"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+)
+
+var quietLog = log.New(io.Discard, "", 0)
+
+func TestMain(m *testing.M) {
+ *http.MaxWriteWaitBeforeConnReuse = 60 * time.Minute
+ v := m.Run()
+ if v == 0 && goroutineLeaked() {
+ os.Exit(1)
+ }
+ os.Exit(v)
+}
+
+func interestingGoroutines() (gs []string) {
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ for _, g := range strings.Split(string(buf), "\n\n") {
+ _, stack, _ := strings.Cut(g, "\n")
+ stack = strings.TrimSpace(stack)
+ if stack == "" ||
+ strings.Contains(stack, "testing.(*M).before.func1") ||
+ strings.Contains(stack, "os/signal.signal_recv") ||
+ strings.Contains(stack, "created by net.startServer") ||
+ strings.Contains(stack, "created by testing.RunTests") ||
+ strings.Contains(stack, "closeWriteAndWait") ||
+ strings.Contains(stack, "testing.Main(") ||
+ // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28)
+ strings.Contains(stack, "runtime.goexit") ||
+ strings.Contains(stack, "created by runtime.gc") ||
+ strings.Contains(stack, "interestingGoroutines") ||
+ strings.Contains(stack, "runtime.MHeap_Scavenger") {
+ continue
+ }
+ gs = append(gs, stack)
+ }
+ sort.Strings(gs)
+ return
+}
+
+// Verify the other tests didn't leave any goroutines running.
+func goroutineLeaked() bool {
+ if testing.Short() || runningBenchmarks() {
+ // Don't worry about goroutine leaks in -short mode or in
+ // benchmark mode. Too distracting when there are false positives.
+ return false
+ }
+
+ var stackCount map[string]int
+ for i := 0; i < 5; i++ {
+ n := 0
+ stackCount = make(map[string]int)
+ gs := interestingGoroutines()
+ for _, g := range gs {
+ stackCount[g]++
+ n++
+ }
+ if n == 0 {
+ return false
+ }
+ // Wait for goroutines to schedule and die off:
+ time.Sleep(100 * time.Millisecond)
+ }
+ fmt.Fprintf(os.Stderr, "Too many goroutines running after net/http test(s).\n")
+ for stack, count := range stackCount {
+ fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack)
+ }
+ return true
+}
+
+// setParallel marks t as a parallel test if we're in short mode
+// (all.bash), but as a serial test otherwise. Using t.Parallel isn't
+// compatible with the afterTest func in non-short mode.
+func setParallel(t *testing.T) {
+ if strings.Contains(t.Name(), "HTTP2") {
+ http.CondSkipHTTP2(t)
+ }
+ if testing.Short() {
+ t.Parallel()
+ }
+}
+
+func runningBenchmarks() bool {
+ for i, arg := range os.Args {
+ if strings.HasPrefix(arg, "-test.bench=") && !strings.HasSuffix(arg, "=") {
+ return true
+ }
+ if arg == "-test.bench" && i < len(os.Args)-1 && os.Args[i+1] != "" {
+ return true
+ }
+ }
+ return false
+}
+
+var leakReported bool
+
+func afterTest(t testing.TB) {
+ http.DefaultTransport.(*http.Transport).CloseIdleConnections()
+ if testing.Short() {
+ return
+ }
+ if leakReported {
+ // To avoid confusion, only report the first leak of each test run.
+ // After the first leak has been reported, we can't tell whether the leaked
+ // goroutines are a new leak from a subsequent test or just the same
+ // goroutines from the first leak still hanging around, and we may add a lot
+ // of latency waiting for them to exit at the end of each test.
+ return
+ }
+
+ // We shouldn't be running the leak check for parallel tests, because we might
+ // report the goroutines from a test that is still running as a leak from a
+ // completely separate test that has just finished. So we use non-atomic loads
+ // and stores for the leakReported variable, and store every time we start a
+ // leak check so that the race detector will flag concurrent leak checks as a
+ // race even if we don't detect any leaks.
+ leakReported = true
+
+ var bad string
+ badSubstring := map[string]string{
+ ").readLoop(": "a Transport",
+ ").writeLoop(": "a Transport",
+ "created by net/http/httptest.(*Server).Start": "an httptest.Server",
+ "timeoutHandler": "a TimeoutHandler",
+ "net.(*netFD).connect(": "a timing out dial",
+ ").noteClientGone(": "a closenotifier sender",
+ }
+ var stacks string
+ for i := 0; i < 10; i++ {
+ bad = ""
+ stacks = strings.Join(interestingGoroutines(), "\n\n")
+ for substr, what := range badSubstring {
+ if strings.Contains(stacks, substr) {
+ bad = what
+ }
+ }
+ if bad == "" {
+ leakReported = false
+ return
+ }
+ // Bad stuff found, but goroutines might just still be
+ // shutting down, so give it some time.
+ time.Sleep(250 * time.Millisecond)
+ }
+ t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks)
+}
+
+// waitCondition waits for fn to return true,
+// checking immediately and then at exponentially increasing intervals.
+func waitCondition(t testing.TB, delay time.Duration, fn func(time.Duration) bool) {
+ t.Helper()
+ start := time.Now()
+ var since time.Duration
+ for !fn(since) {
+ time.Sleep(delay)
+ delay = 2*delay - (delay / 2) // 1.5x, rounded up
+ since = time.Since(start)
+ }
+}
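The backoff update in waitCondition relies on integer division: 2*delay - delay/2 is one and a half times the delay, rounded up when the duration is an odd number of nanoseconds. A tiny sketch of the resulting sequence, with illustrative starting value:

	package main

	import (
		"fmt"
		"time"
	)

	func main() {
		// Same update as waitCondition above: 1.5x per iteration,
		// rounded up for odd durations.
		delay := 3 * time.Nanosecond
		for i := 0; i < 4; i++ {
			fmt.Println(delay) // 3ns, 5ns, 8ns, 12ns
			delay = 2*delay - (delay / 2)
		}
	}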
diff --git a/src/net/http/method.go b/src/net/http/method.go
new file mode 100644
index 0000000..6f46155
--- /dev/null
+++ b/src/net/http/method.go
@@ -0,0 +1,20 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// Common HTTP methods.
+//
+// Unless otherwise noted, these are defined in RFC 7231 section 4.3.
+const (
+ MethodGet = "GET"
+ MethodHead = "HEAD"
+ MethodPost = "POST"
+ MethodPut = "PUT"
+ MethodPatch = "PATCH" // RFC 5789
+ MethodDelete = "DELETE"
+ MethodConnect = "CONNECT"
+ MethodOptions = "OPTIONS"
+ MethodTrace = "TRACE"
+)
diff --git a/src/net/http/omithttp2.go b/src/net/http/omithttp2.go
new file mode 100644
index 0000000..ca08ddf
--- /dev/null
+++ b/src/net/http/omithttp2.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build nethttpomithttp2
+
+package http
+
+import (
+ "errors"
+ "sync"
+ "time"
+)
+
+func init() {
+ omitBundledHTTP2 = true
+}
+
+const noHTTP2 = "no bundled HTTP/2" // should never see this
+
+var http2errRequestCanceled = errors.New("net/http: request canceled")
+
+var http2goAwayTimeout = 1 * time.Second
+
+const http2NextProtoTLS = "h2"
+
+type http2Transport struct {
+ MaxHeaderListSize uint32
+ ConnPool any
+}
+
+func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
+func (*http2Transport) CloseIdleConnections() {}
+
+type http2noDialH2RoundTripper struct{}
+
+func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) }
+
+type http2noDialClientConnPool struct {
+ http2clientConnPool http2clientConnPool
+}
+
+type http2clientConnPool struct {
+ mu *sync.Mutex
+ conns map[string][]*http2clientConn
+}
+
+type http2clientConn struct{}
+
+type http2clientConnIdleState struct {
+ canTakeNewRequest bool
+}
+
+func (cc *http2clientConn) idleState() http2clientConnIdleState { return http2clientConnIdleState{} }
+
+func http2configureTransports(*Transport) (*http2Transport, error) { panic(noHTTP2) }
+
+func http2isNoCachedConnError(err error) bool {
+ _, ok := err.(interface{ IsHTTP2NoCachedConnError() })
+ return ok
+}
+
+type http2Server struct {
+ NewWriteScheduler func() http2WriteScheduler
+}
+
+type http2WriteScheduler any
+
+func http2NewPriorityWriteScheduler(any) http2WriteScheduler { panic(noHTTP2) }
+
+func http2ConfigureServer(s *Server, conf *http2Server) error { panic(noHTTP2) }
+
+var http2ErrNoCachedConn = http2noCachedConnError{}
+
+type http2noCachedConnError struct{}
+
+func (http2noCachedConnError) IsHTTP2NoCachedConnError() {}
+
+func (http2noCachedConnError) Error() string { return "http2: no cached connection was available" }
diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go
new file mode 100644
index 0000000..bc3225d
--- /dev/null
+++ b/src/net/http/pprof/pprof.go
@@ -0,0 +1,464 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pprof serves via its HTTP server runtime profiling data
+// in the format expected by the pprof visualization tool.
+//
+// The package is typically only imported for the side effect of
+// registering its HTTP handlers.
+// The handled paths all begin with /debug/pprof/.
+//
+// To use pprof, link this package into your program:
+//
+// import _ "net/http/pprof"
+//
+// If your application is not already running an http server, you
+// need to start one. Add "net/http" and "log" to your imports and
+// the following code to your main function:
+//
+// go func() {
+// log.Println(http.ListenAndServe("localhost:6060", nil))
+// }()
+//
+// By default, all the profiles listed in [runtime/pprof.Profile] are
+// available (via [Handler]), in addition to the [Cmdline], [Profile], [Symbol],
+// and [Trace] profiles defined in this package.
+// If you are not using DefaultServeMux, you will have to register handlers
+// with the mux you are using.
+//
+// # Parameters
+//
+// Parameters can be passed via GET query params:
+//
+// - debug=N (all profiles): response format: N = 0: binary (default), N > 0: plaintext
+// - gc=N (heap profile): N > 0: run a garbage collection cycle before profiling
+// - seconds=N (allocs, block, goroutine, heap, mutex, threadcreate profiles): return a delta profile
+// - seconds=N (cpu (profile), trace profiles): profile for the given duration
+//
+// # Usage examples
+//
+// Use the pprof tool to look at the heap profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/heap
+//
+// Or to look at a 30-second CPU profile:
+//
+// go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
+//
+// Or to look at the goroutine blocking profile, after calling
+// runtime.SetBlockProfileRate in your program:
+//
+// go tool pprof http://localhost:6060/debug/pprof/block
+//
+// Or to look at the holders of contended mutexes, after calling
+// runtime.SetMutexProfileFraction in your program:
+//
+// go tool pprof http://localhost:6060/debug/pprof/mutex
+//
+// The package also exports a handler that serves execution trace data
+// for the "go tool trace" command. To collect a 5-second execution trace:
+//
+// curl -o trace.out http://localhost:6060/debug/pprof/trace?seconds=5
+// go tool trace trace.out
+//
+// To view all available profiles, open http://localhost:6060/debug/pprof/
+// in your browser.
+//
+// For a study of the facility in action, visit
+// https://blog.golang.org/2011/06/profiling-go-programs.html.
+package pprof
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "html"
+ "internal/profile"
+ "io"
+ "log"
+ "net/http"
+ "net/url"
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "runtime/trace"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func init() {
+ http.HandleFunc("/debug/pprof/", Index)
+ http.HandleFunc("/debug/pprof/cmdline", Cmdline)
+ http.HandleFunc("/debug/pprof/profile", Profile)
+ http.HandleFunc("/debug/pprof/symbol", Symbol)
+ http.HandleFunc("/debug/pprof/trace", Trace)
+}
+
+// Cmdline responds with the running program's
+// command line, with arguments separated by NUL bytes.
+// The package initialization registers it as /debug/pprof/cmdline.
+func Cmdline(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprint(w, strings.Join(os.Args, "\x00"))
+}
+
+func sleep(r *http.Request, d time.Duration) {
+ select {
+ case <-time.After(d):
+ case <-r.Context().Done():
+ }
+}
+
+func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool {
+ srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server)
+ return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds()
+}
+
+func serveError(w http.ResponseWriter, status int, txt string) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.Header().Set("X-Go-Pprof", "1")
+ w.Header().Del("Content-Disposition")
+ w.WriteHeader(status)
+ fmt.Fprintln(w, txt)
+}
+
+// Profile responds with the pprof-formatted cpu profile.
+// Profiling lasts for the duration specified in the seconds GET parameter, or for 30 seconds if not specified.
+// The package initialization registers it as /debug/pprof/profile.
+func Profile(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
+ if sec <= 0 || err != nil {
+ sec = 30
+ }
+
+ if durationExceedsWriteTimeout(r, float64(sec)) {
+ serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
+ return
+ }
+
+ // Set Content Type assuming StartCPUProfile will work,
+ // because if it does it starts writing.
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Header().Set("Content-Disposition", `attachment; filename="profile"`)
+ if err := pprof.StartCPUProfile(w); err != nil {
+ // StartCPUProfile failed, so no writes yet.
+ serveError(w, http.StatusInternalServerError,
+ fmt.Sprintf("Could not enable CPU profiling: %s", err))
+ return
+ }
+ sleep(r, time.Duration(sec)*time.Second)
+ pprof.StopCPUProfile()
+}
+
+// Trace responds with the execution trace in binary form.
+// Tracing lasts for the duration specified in the seconds GET parameter, or for 1 second if not specified.
+// The package initialization registers it as /debug/pprof/trace.
+func Trace(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ sec, err := strconv.ParseFloat(r.FormValue("seconds"), 64)
+ if sec <= 0 || err != nil {
+ sec = 1
+ }
+
+ if durationExceedsWriteTimeout(r, sec) {
+ serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
+ return
+ }
+
+ // Set Content Type assuming trace.Start will work,
+ // because if it does it starts writing.
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Header().Set("Content-Disposition", `attachment; filename="trace"`)
+ if err := trace.Start(w); err != nil {
+ // trace.Start failed, so no writes yet.
+ serveError(w, http.StatusInternalServerError,
+ fmt.Sprintf("Could not enable tracing: %s", err))
+ return
+ }
+ sleep(r, time.Duration(sec*float64(time.Second)))
+ trace.Stop()
+}
+
+// Symbol looks up the program counters listed in the request,
+// responding with a table mapping program counters to function names.
+// The package initialization registers it as /debug/pprof/symbol.
+func Symbol(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+
+ // We have to read the whole POST body before
+ // writing any output. Buffer the output here.
+ var buf bytes.Buffer
+
+ // We don't know how many symbols we have, but we
+ // do have symbol information. Pprof only cares whether
+ // this number is 0 (no symbols available) or > 0.
+ fmt.Fprintf(&buf, "num_symbols: 1\n")
+
+ var b *bufio.Reader
+ if r.Method == "POST" {
+ b = bufio.NewReader(r.Body)
+ } else {
+ b = bufio.NewReader(strings.NewReader(r.URL.RawQuery))
+ }
+
+ for {
+ word, err := b.ReadSlice('+')
+ if err == nil {
+ word = word[0 : len(word)-1] // trim +
+ }
+ pc, _ := strconv.ParseUint(string(word), 0, 64)
+ if pc != 0 {
+ f := runtime.FuncForPC(uintptr(pc))
+ if f != nil {
+ fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name())
+ }
+ }
+
+ // Wait until here to check for err; the last
+ // symbol will have an err because it doesn't end in +.
+ if err != nil {
+ if err != io.EOF {
+ fmt.Fprintf(&buf, "reading request: %v\n", err)
+ }
+ break
+ }
+ }
+
+ w.Write(buf.Bytes())
+}
+
+// Handler returns an HTTP handler that serves the named profile.
+// Available profiles can be found in [runtime/pprof.Profile].
+func Handler(name string) http.Handler {
+ return handler(name)
+}
+
+type handler string
+
+func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ p := pprof.Lookup(string(name))
+ if p == nil {
+ serveError(w, http.StatusNotFound, "Unknown profile")
+ return
+ }
+ if sec := r.FormValue("seconds"); sec != "" {
+ name.serveDeltaProfile(w, r, p, sec)
+ return
+ }
+ gc, _ := strconv.Atoi(r.FormValue("gc"))
+ if name == "heap" && gc > 0 {
+ runtime.GC()
+ }
+ debug, _ := strconv.Atoi(r.FormValue("debug"))
+ if debug != 0 {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ } else {
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, name))
+ }
+ p.WriteTo(w, debug)
+}
+
+func (name handler) serveDeltaProfile(w http.ResponseWriter, r *http.Request, p *pprof.Profile, secStr string) {
+ sec, err := strconv.ParseInt(secStr, 10, 64)
+ if err != nil || sec <= 0 {
+ serveError(w, http.StatusBadRequest, `invalid value for "seconds" - must be a positive integer`)
+ return
+ }
+ if !profileSupportsDelta[name] {
+ serveError(w, http.StatusBadRequest, `"seconds" parameter is not supported for this profile type`)
+ return
+ }
+ // 'name' should be a key in profileSupportsDelta.
+ if durationExceedsWriteTimeout(r, float64(sec)) {
+ serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout")
+ return
+ }
+ debug, _ := strconv.Atoi(r.FormValue("debug"))
+ if debug != 0 {
+ serveError(w, http.StatusBadRequest, "seconds and debug params are incompatible")
+ return
+ }
+ p0, err := collectProfile(p)
+ if err != nil {
+ serveError(w, http.StatusInternalServerError, "failed to collect profile")
+ return
+ }
+
+ t := time.NewTimer(time.Duration(sec) * time.Second)
+ defer t.Stop()
+
+ select {
+ case <-r.Context().Done():
+ err := r.Context().Err()
+ if err == context.DeadlineExceeded {
+ serveError(w, http.StatusRequestTimeout, err.Error())
+ } else { // TODO: what's a good status code for canceled requests? 400?
+ serveError(w, http.StatusInternalServerError, err.Error())
+ }
+ return
+ case <-t.C:
+ }
+
+ p1, err := collectProfile(p)
+ if err != nil {
+ serveError(w, http.StatusInternalServerError, "failed to collect profile")
+ return
+ }
+ ts := p1.TimeNanos
+ dur := p1.TimeNanos - p0.TimeNanos
+
+ p0.Scale(-1)
+
+ p1, err = profile.Merge([]*profile.Profile{p0, p1})
+ if err != nil {
+ serveError(w, http.StatusInternalServerError, "failed to compute delta")
+ return
+ }
+
+ p1.TimeNanos = ts // set since we don't know what profile.Merge set for TimeNanos.
+ p1.DurationNanos = dur
+
+ w.Header().Set("Content-Type", "application/octet-stream")
+ w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s-delta"`, name))
+ p1.Write(w)
+}
+
+func collectProfile(p *pprof.Profile) (*profile.Profile, error) {
+ var buf bytes.Buffer
+ if err := p.WriteTo(&buf, 0); err != nil {
+ return nil, err
+ }
+ ts := time.Now().UnixNano()
+ p0, err := profile.Parse(&buf)
+ if err != nil {
+ return nil, err
+ }
+ p0.TimeNanos = ts
+ return p0, nil
+}
+
+var profileSupportsDelta = map[handler]bool{
+ "allocs": true,
+ "block": true,
+ "goroutine": true,
+ "heap": true,
+ "mutex": true,
+ "threadcreate": true,
+}
+
+var profileDescriptions = map[string]string{
+ "allocs": "A sampling of all past memory allocations",
+ "block": "Stack traces that led to blocking on synchronization primitives",
+ "cmdline": "The command line invocation of the current program",
+ "goroutine": "Stack traces of all current goroutines. Use debug=2 as a query parameter to export in the same format as an unrecovered panic.",
+ "heap": "A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.",
+ "mutex": "Stack traces of holders of contended mutexes",
+ "profile": "CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.",
+ "threadcreate": "Stack traces that led to the creation of new OS threads",
+ "trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.",
+}
+
+type profileEntry struct {
+ Name string
+ Href string
+ Desc string
+ Count int
+}
+
+// Index responds with the pprof-formatted profile named by the request.
+// For example, "/debug/pprof/heap" serves the "heap" profile.
+// Index responds to a request for "/debug/pprof/" with an HTML page
+// listing the available profiles.
+func Index(w http.ResponseWriter, r *http.Request) {
+ if name, found := strings.CutPrefix(r.URL.Path, "/debug/pprof/"); found {
+ if name != "" {
+ handler(name).ServeHTTP(w, r)
+ return
+ }
+ }
+
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+
+ var profiles []profileEntry
+ for _, p := range pprof.Profiles() {
+ profiles = append(profiles, profileEntry{
+ Name: p.Name(),
+ Href: p.Name(),
+ Desc: profileDescriptions[p.Name()],
+ Count: p.Count(),
+ })
+ }
+
+ // Add the other profiles exposed from within this package.
+ for _, p := range []string{"cmdline", "profile", "trace"} {
+ profiles = append(profiles, profileEntry{
+ Name: p,
+ Href: p,
+ Desc: profileDescriptions[p],
+ })
+ }
+
+ sort.Slice(profiles, func(i, j int) bool {
+ return profiles[i].Name < profiles[j].Name
+ })
+
+ if err := indexTmplExecute(w, profiles); err != nil {
+ log.Print(err)
+ }
+}
+
+func indexTmplExecute(w io.Writer, profiles []profileEntry) error {
+ var b bytes.Buffer
+ b.WriteString(`<html>
+<head>
+<title>/debug/pprof/</title>
+<style>
+.profile-name{
+ display:inline-block;
+ width:6rem;
+}
+</style>
+</head>
+<body>
+/debug/pprof/
+<br>
+<p>Set debug=1 as a query parameter to export in legacy text format</p>
+<br>
+Types of profiles available:
+<table>
+<thead><td>Count</td><td>Profile</td></thead>
+`)
+
+ for _, profile := range profiles {
+ link := &url.URL{Path: profile.Href, RawQuery: "debug=1"}
+ fmt.Fprintf(&b, "<tr><td>%d</td><td><a href='%s'>%s</a></td></tr>\n", profile.Count, link, html.EscapeString(profile.Name))
+ }
+
+ b.WriteString(`</table>
+<a href="goroutine?debug=2">full goroutine stack dump</a>
+<br>
+<p>
+Profile Descriptions:
+<ul>
+`)
+ for _, profile := range profiles {
+ fmt.Fprintf(&b, "<li><div class=profile-name>%s: </div> %s</li>\n", html.EscapeString(profile.Name), html.EscapeString(profile.Desc))
+ }
+ b.WriteString(`</ul>
+</p>
+</body>
+</html>`)
+
+ _, err := w.Write(b.Bytes())
+ return err
+}
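As the package comment notes, programs that do not use DefaultServeMux must register these handlers themselves. A minimal sketch of doing so on a private mux follows; the listen address is illustrative.

	package main

	import (
		"log"
		"net/http"
		"net/http/pprof"
	)

	func main() {
		// Mirror the init function above on a private mux instead of
		// http.DefaultServeMux.
		mux := http.NewServeMux()
		mux.HandleFunc("/debug/pprof/", pprof.Index)
		mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
		mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
		mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
		mux.HandleFunc("/debug/pprof/trace", pprof.Trace)

		// Individual runtime/pprof profiles can also be mounted directly;
		// Index already dispatches them by name, so this is optional.
		mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))

		log.Println(http.ListenAndServe("localhost:6060", mux))
	}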
diff --git a/src/net/http/pprof/pprof_test.go b/src/net/http/pprof/pprof_test.go
new file mode 100644
index 0000000..f82ad45
--- /dev/null
+++ b/src/net/http/pprof/pprof_test.go
@@ -0,0 +1,263 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pprof
+
+import (
+ "bytes"
+ "fmt"
+ "internal/profile"
+ "internal/testenv"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "runtime"
+ "runtime/pprof"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// TestDescriptions checks that the profile names under runtime/pprof package
+// have a key in the description map.
+func TestDescriptions(t *testing.T) {
+ for _, p := range pprof.Profiles() {
+ _, ok := profileDescriptions[p.Name()]
+ if !ok {
+ t.Errorf("%s does not exist in profileDescriptions map\n", p.Name())
+ }
+ }
+}
+
+func TestHandlers(t *testing.T) {
+ testCases := []struct {
+ path string
+ handler http.HandlerFunc
+ statusCode int
+ contentType string
+ contentDisposition string
+ resp []byte
+ }{
+ {"/debug/pprof/<script>scripty<script>", Index, http.StatusNotFound, "text/plain; charset=utf-8", "", []byte("Unknown profile\n")},
+ {"/debug/pprof/heap", Index, http.StatusOK, "application/octet-stream", `attachment; filename="heap"`, nil},
+ {"/debug/pprof/heap?debug=1", Index, http.StatusOK, "text/plain; charset=utf-8", "", nil},
+ {"/debug/pprof/cmdline", Cmdline, http.StatusOK, "text/plain; charset=utf-8", "", nil},
+ {"/debug/pprof/profile?seconds=1", Profile, http.StatusOK, "application/octet-stream", `attachment; filename="profile"`, nil},
+ {"/debug/pprof/symbol", Symbol, http.StatusOK, "text/plain; charset=utf-8", "", nil},
+ {"/debug/pprof/trace", Trace, http.StatusOK, "application/octet-stream", `attachment; filename="trace"`, nil},
+ {"/debug/pprof/mutex", Index, http.StatusOK, "application/octet-stream", `attachment; filename="mutex"`, nil},
+ {"/debug/pprof/block?seconds=1", Index, http.StatusOK, "application/octet-stream", `attachment; filename="block-delta"`, nil},
+ {"/debug/pprof/goroutine?seconds=1", Index, http.StatusOK, "application/octet-stream", `attachment; filename="goroutine-delta"`, nil},
+ {"/debug/pprof/", Index, http.StatusOK, "text/html; charset=utf-8", "", []byte("Types of profiles available:")},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.path, func(t *testing.T) {
+ req := httptest.NewRequest("GET", "http://example.com"+tc.path, nil)
+ w := httptest.NewRecorder()
+ tc.handler(w, req)
+
+ resp := w.Result()
+ if got, want := resp.StatusCode, tc.statusCode; got != want {
+ t.Errorf("status code: got %d; want %d", got, want)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("when reading response body, expected non-nil err; got %v", err)
+ }
+ if got, want := resp.Header.Get("X-Content-Type-Options"), "nosniff"; got != want {
+ t.Errorf("X-Content-Type-Options: got %q; want %q", got, want)
+ }
+ if got, want := resp.Header.Get("Content-Type"), tc.contentType; got != want {
+ t.Errorf("Content-Type: got %q; want %q", got, want)
+ }
+ if got, want := resp.Header.Get("Content-Disposition"), tc.contentDisposition; got != want {
+ t.Errorf("Content-Disposition: got %q; want %q", got, want)
+ }
+
+ if resp.StatusCode == http.StatusOK {
+ return
+ }
+ if got, want := resp.Header.Get("X-Go-Pprof"), "1"; got != want {
+ t.Errorf("X-Go-Pprof: got %q; want %q", got, want)
+ }
+ if !bytes.Equal(body, tc.resp) {
+ t.Errorf("response: got %q; want %q", body, tc.resp)
+ }
+ })
+ }
+}
+
+var Sink uint32
+
+func mutexHog1(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration) {
+ atomic.AddUint32(&Sink, 1)
+ for time.Since(start) < dt {
+ // When using gccgo the loop of mutex operations is
+ // not preemptible. This can cause the loop to block a GC,
+ // causing the time limits in TestDeltaProfile to fail.
+ // Since this loop is not very realistic, when using
+ // gccgo add preemption points 100 times a second.
+ t1 := time.Now()
+ for time.Since(start) < dt && time.Since(t1) < 10*time.Millisecond {
+ mu1.Lock()
+ mu2.Lock()
+ mu1.Unlock()
+ mu2.Unlock()
+ }
+ if runtime.Compiler == "gccgo" {
+ runtime.Gosched()
+ }
+ }
+}
+
+// mutexHog2 is almost identical to mutexHog1, but we keep them separate
+// in order to distinguish them with function names in the stack trace.
+// We make them slightly different, using Sink, because otherwise
+// gccgo -c opt will merge them.
+func mutexHog2(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration) {
+ atomic.AddUint32(&Sink, 2)
+ for time.Since(start) < dt {
+ // See comment in mutexHog1.
+ t1 := time.Now()
+ for time.Since(start) < dt && time.Since(t1) < 10*time.Millisecond {
+ mu1.Lock()
+ mu2.Lock()
+ mu1.Unlock()
+ mu2.Unlock()
+ }
+ if runtime.Compiler == "gccgo" {
+ runtime.Gosched()
+ }
+ }
+}
+
+// mutexHog starts multiple goroutines that run the given hogger function for the specified duration.
+// The hogger function will be given two mutexes to lock & unlock.
+func mutexHog(duration time.Duration, hogger func(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration)) {
+ start := time.Now()
+ mu1 := new(sync.Mutex)
+ mu2 := new(sync.Mutex)
+ var wg sync.WaitGroup
+ wg.Add(10)
+ for i := 0; i < 10; i++ {
+ go func() {
+ defer wg.Done()
+ hogger(mu1, mu2, start, duration)
+ }()
+ }
+ wg.Wait()
+}
+
+func TestDeltaProfile(t *testing.T) {
+ if strings.HasPrefix(runtime.GOARCH, "arm") {
+ testenv.SkipFlaky(t, 50218)
+ }
+
+ rate := runtime.SetMutexProfileFraction(1)
+ defer func() {
+ runtime.SetMutexProfileFraction(rate)
+ }()
+
+ // mutexHog1 will appear in non-delta mutex profile
+ // if the mutex profile works.
+ mutexHog(20*time.Millisecond, mutexHog1)
+
+ // If mutexHog1 does not appear in the mutex profile,
+ // skip this test. The mutex profile is likely not working,
+ // and neither is the delta profile.
+
+ p, err := query("/debug/pprof/mutex")
+ if err != nil {
+ t.Skipf("mutex profile is unsupported: %v", err)
+ }
+
+ if !seen(p, "mutexHog1") {
+ t.Skipf("mutex profile is not working: %v", p)
+ }
+
+ // Run mutexHog2 in the background so its call stacks appear in the mutex profile.
+ done := make(chan bool)
+ go func() {
+ for {
+ mutexHog(20*time.Millisecond, mutexHog2)
+ select {
+ case <-done:
+ done <- true
+ return
+ default:
+ time.Sleep(10 * time.Millisecond)
+ }
+ }
+ }()
+ defer func() { // cleanup the above goroutine.
+ done <- true
+ <-done // wait for the goroutine to exit.
+ }()
+
+ for _, d := range []int{1, 4, 16, 32} {
+ endpoint := fmt.Sprintf("/debug/pprof/mutex?seconds=%d", d)
+ p, err := query(endpoint)
+ if err != nil {
+ t.Fatalf("failed to query %q: %v", endpoint, err)
+ }
+ if !seen(p, "mutexHog1") && seen(p, "mutexHog2") && p.DurationNanos > 0 {
+ break // pass
+ }
+ if d == 32 {
+ t.Errorf("want mutexHog2 but no mutexHog1 in the profile, and non-zero p.DurationNanos, got %v", p)
+ }
+ }
+ p, err = query("/debug/pprof/mutex")
+ if err != nil {
+ t.Fatalf("failed to query mutex profile: %v", err)
+ }
+ if !seen(p, "mutexHog1") || !seen(p, "mutexHog2") {
+ t.Errorf("want both mutexHog1 and mutexHog2 in the profile, got %v", p)
+ }
+}
+
+var srv = httptest.NewServer(nil)
+
+func query(endpoint string) (*profile.Profile, error) {
+ url := srv.URL + endpoint
+ r, err := http.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch %q: %v", url, err)
+ }
+ if r.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to fetch %q: %v", url, r.Status)
+ }
+
+ b, err := io.ReadAll(r.Body)
+ r.Body.Close()
+ if err != nil {
+ return nil, fmt.Errorf("failed to read and parse the result from %q: %v", url, err)
+ }
+ return profile.Parse(bytes.NewBuffer(b))
+}
+
+// seen returns true if the profile includes samples whose stacks include
+// the specified function name (fname).
+func seen(p *profile.Profile, fname string) bool {
+ locIDs := map[*profile.Location]bool{}
+ for _, loc := range p.Location {
+ for _, l := range loc.Line {
+ if strings.Contains(l.Function.Name, fname) {
+ locIDs[loc] = true
+ break
+ }
+ }
+ }
+ for _, sample := range p.Sample {
+ for _, loc := range sample.Location {
+ if locIDs[loc] {
+ return true
+ }
+ }
+ }
+ return false
+}
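+
+// A minimal sketch of fetching a delta mutex profile outside of this test,
+// assuming a server that imports net/http/pprof and listens on the
+// (hypothetical) address localhost:6060, and a profile parser with the same
+// Parse API as used above (e.g. github.com/google/pprof/profile):
+//
+//	resp, err := http.Get("http://localhost:6060/debug/pprof/mutex?seconds=5")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer resp.Body.Close()
+//	p, err := profile.Parse(resp.Body)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	// The server collects for roughly 5 seconds and reports only the delta.
+//	fmt.Println(p.DurationNanos)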
diff --git a/src/net/http/proxy_test.go b/src/net/http/proxy_test.go
new file mode 100644
index 0000000..0dd57b4
--- /dev/null
+++ b/src/net/http/proxy_test.go
@@ -0,0 +1,50 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "net/url"
+ "os"
+ "testing"
+)
+
+// TODO(mattn):
+// test ProxyAuth
+
+var cacheKeysTests = []struct {
+ proxy string
+ scheme string
+ addr string
+ key string
+}{
+ {"", "http", "foo.com", "|http|foo.com"},
+ {"", "https", "foo.com", "|https|foo.com"},
+ {"http://foo.com", "http", "foo.com", "http://foo.com|http|"},
+ {"http://foo.com", "https", "foo.com", "http://foo.com|https|foo.com"},
+}
+
+func TestCacheKeys(t *testing.T) {
+ for _, tt := range cacheKeysTests {
+ var proxy *url.URL
+ if tt.proxy != "" {
+ u, err := url.Parse(tt.proxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxy = u
+ }
+ cm := connectMethod{proxyURL: proxy, targetScheme: tt.scheme, targetAddr: tt.addr}
+ if got := cm.key().String(); got != tt.key {
+ t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key)
+ }
+ }
+}
+
+func ResetProxyEnv() {
+ for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} {
+ os.Unsetenv(v)
+ }
+ ResetCachedEnvironment()
+}
diff --git a/src/net/http/range_test.go b/src/net/http/range_test.go
new file mode 100644
index 0000000..114987e
--- /dev/null
+++ b/src/net/http/range_test.go
@@ -0,0 +1,79 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "testing"
+)
+
+var ParseRangeTests = []struct {
+ s string
+ length int64
+ r []httpRange
+}{
+ {"", 0, nil},
+ {"", 1000, nil},
+ {"foo", 0, nil},
+ {"bytes=", 0, nil},
+ {"bytes=7", 10, nil},
+ {"bytes= 7 ", 10, nil},
+ {"bytes=1-", 0, nil},
+ {"bytes=5-4", 10, nil},
+ {"bytes=0-2,5-4", 10, nil},
+ {"bytes=2-5,4-3", 10, nil},
+ {"bytes=--5,4--3", 10, nil},
+ {"bytes=A-", 10, nil},
+ {"bytes=A- ", 10, nil},
+ {"bytes=A-Z", 10, nil},
+ {"bytes= -Z", 10, nil},
+ {"bytes=5-Z", 10, nil},
+ {"bytes=Ran-dom, garbage", 10, nil},
+ {"bytes=0x01-0x02", 10, nil},
+ {"bytes= ", 10, nil},
+ {"bytes= , , , ", 10, nil},
+
+ {"bytes=0-9", 10, []httpRange{{0, 10}}},
+ {"bytes=0-", 10, []httpRange{{0, 10}}},
+ {"bytes=5-", 10, []httpRange{{5, 5}}},
+ {"bytes=0-20", 10, []httpRange{{0, 10}}},
+ {"bytes=15-,0-5", 10, []httpRange{{0, 6}}},
+ {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
+ {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
+ {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
+ {"bytes=-5", 10, []httpRange{{5, 5}}},
+ {"bytes=-15", 10, []httpRange{{0, 10}}},
+ {"bytes=0-499", 10000, []httpRange{{0, 500}}},
+ {"bytes=500-999", 10000, []httpRange{{500, 500}}},
+ {"bytes=-500", 10000, []httpRange{{9500, 500}}},
+ {"bytes=9500-", 10000, []httpRange{{9500, 500}}},
+ {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}},
+ {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}},
+ {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}},
+
+ // Match Apache laxity:
+ {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}},
+}
+
+func TestParseRange(t *testing.T) {
+ for _, test := range ParseRangeTests {
+ r := test.r
+ ranges, err := parseRange(test.s, test.length)
+ if err != nil && r != nil {
+ t.Errorf("parseRange(%q) returned error %q", test.s, err)
+ }
+ if len(ranges) != len(r) {
+ t.Errorf("len(parseRange(%q)) = %d, want %d", test.s, len(ranges), len(r))
+ continue
+ }
+ for i := range r {
+ if ranges[i].start != r[i].start {
+ t.Errorf("parseRange(%q)[%d].start = %d, want %d", test.s, i, ranges[i].start, r[i].start)
+ }
+ if ranges[i].length != r[i].length {
+ t.Errorf("parseRange(%q)[%d].length = %d, want %d", test.s, i, ranges[i].length, r[i].length)
+ }
+ }
+ }
+}
diff --git a/src/net/http/readrequest_test.go b/src/net/http/readrequest_test.go
new file mode 100644
index 0000000..5aaf3b9
--- /dev/null
+++ b/src/net/http/readrequest_test.go
@@ -0,0 +1,475 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "net/url"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type reqTest struct {
+ Raw string
+ Req *Request
+ Body string
+ Trailer Header
+ Error string
+}
+
+var noError = ""
+var noBodyStr = ""
+var noTrailer Header = nil
+
+var reqTests = []reqTest{
+	// Baseline test; all Request fields included for template use
+ {
+ "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Content-Length: 7\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n" +
+ "abcdef\n???",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "Content-Length": {"7"},
+ "User-Agent": {"Fake"},
+ },
+ Close: false,
+ ContentLength: 7,
+ Host: "www.techcrunch.com",
+ RequestURI: "http://www.techcrunch.com/",
+ },
+
+ "abcdef\n",
+
+ noTrailer,
+ noError,
+ },
+
+ // GET request with no body (the normal case)
+ {
+ "GET / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // Tests that we don't parse a path that looks like a
+ // scheme-relative URI as a scheme-relative URI.
+ {
+ "GET //user@host/is/actually/a/path/ HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "//user@host/is/actually/a/path/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "test",
+ RequestURI: "//user@host/is/actually/a/path/",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // Tests a bogus absolute-path on the Request-Line (RFC 7230 section 5.3.1)
+ {
+ "GET ../../../../etc/passwd HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBodyStr,
+ noTrailer,
+ `parse "../../../../etc/passwd": invalid URI for request`,
+ },
+
+ // Tests missing URL:
+ {
+ "GET HTTP/1.1\r\n" +
+ "Host: test\r\n\r\n",
+ nil,
+ noBodyStr,
+ noTrailer,
+ `parse "": empty url`,
+ },
+
+ // Tests chunked body with trailer:
+ {
+ "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "3\r\nfoo\r\n" +
+ "3\r\nbar\r\n" +
+ "0\r\n" +
+ "Trailer-Key: Trailer-Value\r\n" +
+ "\r\n",
+ &Request{
+ Method: "POST",
+ URL: &url.URL{
+ Path: "/",
+ },
+ TransferEncoding: []string{"chunked"},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ ContentLength: -1,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ "foobar",
+ Header{
+ "Trailer-Key": {"Trailer-Value"},
+ },
+ noError,
+ },
+
+ // Tests chunked body and a bogus Content-Length which should be deleted.
+ {
+ "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Content-Length: 9999\r\n\r\n" + // to be removed.
+ "3\r\nfoo\r\n" +
+ "3\r\nbar\r\n" +
+ "0\r\n" +
+ "\r\n",
+ &Request{
+ Method: "POST",
+ URL: &url.URL{
+ Path: "/",
+ },
+ TransferEncoding: []string{"chunked"},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ ContentLength: -1,
+ Host: "foo.com",
+ RequestURI: "/",
+ },
+
+ "foobar",
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request with domain name:
+ {
+ "CONNECT www.google.com:443 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "www.google.com:443",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "www.google.com:443",
+ RequestURI: "www.google.com:443",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request with IP address:
+ {
+ "CONNECT 127.0.0.1:6060 HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Host: "127.0.0.1:6060",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "127.0.0.1:6060",
+ RequestURI: "127.0.0.1:6060",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // CONNECT request for RPC:
+ {
+ "CONNECT /_goRPC_ HTTP/1.1\r\n\r\n",
+
+ &Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Path: "/_goRPC_",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: false,
+ ContentLength: 0,
+ Host: "",
+ RequestURI: "/_goRPC_",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // SSDP Notify request. golang.org/issue/3692
+ {
+ "NOTIFY * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "NOTIFY",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // OPTIONS request. Similar to golang.org/issue/3692
+ {
+ "OPTIONS * HTTP/1.1\r\nServer: foo\r\n\r\n",
+ &Request{
+ Method: "OPTIONS",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Server": []string{"foo"},
+ },
+ Close: false,
+ ContentLength: 0,
+ RequestURI: "*",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // Connection: close. golang.org/issue/8261
+ {
+ "GET / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\n\r\n",
+ &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Header: Header{
+				// This header wasn't removed from Go 1.0
+				// through Go 1.3, so lock in here that we
+				// keep it:
+ "Connection": []string{"close"},
+ },
+ Host: "issue8261.com",
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Close: true,
+ RequestURI: "/",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // HEAD with Content-Length 0. Make sure this is permitted,
+ // since I think we used to send it.
+ {
+ "HEAD / HTTP/1.1\r\nHost: issue8261.com\r\nConnection: close\r\nContent-Length: 0\r\n\r\n",
+ &Request{
+ Method: "HEAD",
+ URL: &url.URL{
+ Path: "/",
+ },
+ Header: Header{
+ "Connection": []string{"close"},
+ "Content-Length": []string{"0"},
+ },
+ Host: "issue8261.com",
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Close: true,
+ RequestURI: "/",
+ },
+
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+
+ // http2 client preface:
+ {
+ "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n",
+ &Request{
+ Method: "PRI",
+ URL: &url.URL{
+ Path: "*",
+ },
+ Header: Header{},
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ ProtoMinor: 0,
+ RequestURI: "*",
+ ContentLength: -1,
+ Close: true,
+ },
+ noBodyStr,
+ noTrailer,
+ noError,
+ },
+}
+
+func TestReadRequest(t *testing.T) {
+ for i := range reqTests {
+ tt := &reqTests[i]
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.Raw)))
+ if err != nil {
+ if err.Error() != tt.Error {
+ t.Errorf("#%d: error %q, want error %q", i, err.Error(), tt.Error)
+ }
+ continue
+ }
+ rbody := req.Body
+ req.Body = nil
+ testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw)
+ diff(t, testName, req, tt.Req)
+ var bout strings.Builder
+ if rbody != nil {
+ _, err := io.Copy(&bout, rbody)
+ if err != nil {
+ t.Fatalf("%s: copying body: %v", testName, err)
+ }
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("%s: Body = %q want %q", testName, body, tt.Body)
+ }
+ if !reflect.DeepEqual(tt.Trailer, req.Trailer) {
+ t.Errorf("%s: Trailers differ.\n got: %v\nwant: %v", testName, req.Trailer, tt.Trailer)
+ }
+ }
+}
+
+// reqBytes treats req as a request (with \n delimiters) and returns it with \r\n delimiters,
+// ending in \r\n\r\n
+func reqBytes(req string) []byte {
+ return []byte(strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n") + "\r\n\r\n")
+}
+
+var badRequestTests = []struct {
+ name string
+ req []byte
+}{
+ {"bad_connect_host", reqBytes("CONNECT []%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a HTTP/1.0")},
+ {"smuggle_two_contentlen", reqBytes(`POST / HTTP/1.1
+Content-Length: 3
+Content-Length: 4
+
+abc`)},
+ {"smuggle_two_content_len_head", reqBytes(`HEAD / HTTP/1.1
+Host: foo
+Content-Length: 4
+Content-Length: 5
+
+1234`)},
+
+ // golang.org/issue/22464
+ {"leading_space_in_header", reqBytes(`GET / HTTP/1.1
+ Host: foo`)},
+ {"leading_tab_in_header", reqBytes(`GET / HTTP/1.1
+` + "\t" + `Host: foo`)},
+}
+
+func TestReadRequest_Bad(t *testing.T) {
+ for _, tt := range badRequestTests {
+ got, err := ReadRequest(bufio.NewReader(bytes.NewReader(tt.req)))
+ if err == nil {
+ all, err := io.ReadAll(got.Body)
+ t.Errorf("%s: got unexpected request = %#v\n Body = %q, %v", tt.name, got, all, err)
+ }
+ }
+}
diff --git a/src/net/http/request.go b/src/net/http/request.go
new file mode 100644
index 0000000..81f7956
--- /dev/null
+++ b/src/net/http/request.go
@@ -0,0 +1,1488 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Request reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ urlpkg "net/url"
+ "strconv"
+ "strings"
+ "sync"
+
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/idna"
+)
+
+const (
+ defaultMaxMemory = 32 << 20 // 32 MB
+)
+
+// ErrMissingFile is returned by FormFile when the provided file field name
+// is either not present in the request or not a file field.
+var ErrMissingFile = errors.New("http: no such file")
+
+// ProtocolError represents an HTTP protocol error.
+//
+// Deprecated: Not all errors in the http package related to protocol errors
+// are of type ProtocolError.
+type ProtocolError struct {
+ ErrorString string
+}
+
+func (pe *ProtocolError) Error() string { return pe.ErrorString }
+
+// Is lets http.ErrNotSupported match errors.ErrUnsupported.
+func (pe *ProtocolError) Is(err error) bool {
+ return pe == ErrNotSupported && err == errors.ErrUnsupported
+}
+
+var (
+ // ErrNotSupported indicates that a feature is not supported.
+ //
+ // It is returned by ResponseController methods to indicate that
+ // the handler does not support the method, and by the Push method
+ // of Pusher implementations to indicate that HTTP/2 Push support
+ // is not available.
+ ErrNotSupported = &ProtocolError{"feature not supported"}
+
+ // Deprecated: ErrUnexpectedTrailer is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
+ ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
+
+ // ErrMissingBoundary is returned by Request.MultipartReader when the
+ // request's Content-Type does not include a "boundary" parameter.
+ ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
+
+ // ErrNotMultipart is returned by Request.MultipartReader when the
+ // request's Content-Type is not multipart/form-data.
+ ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
+
+ // Deprecated: ErrHeaderTooLong is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
+ ErrHeaderTooLong = &ProtocolError{"header too long"}
+
+ // Deprecated: ErrShortBody is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
+ ErrShortBody = &ProtocolError{"entity body too short"}
+
+ // Deprecated: ErrMissingContentLength is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
+ ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
+)
+
+func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) }
+
+// Headers that Request.Write handles itself and should be skipped.
+var reqWriteExcludeHeader = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "User-Agent": true,
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// A Request represents an HTTP request received by a server
+// or to be sent by a client.
+//
+// The field semantics differ slightly between client and server
+// usage. In addition to the notes on the fields below, see the
+// documentation for Request.Write and RoundTripper.
+type Request struct {
+ // Method specifies the HTTP method (GET, POST, PUT, etc.).
+ // For client requests, an empty string means GET.
+ //
+ // Go's HTTP client does not support sending a request with
+ // the CONNECT method. See the documentation on Transport for
+ // details.
+ Method string
+
+ // URL specifies either the URI being requested (for server
+ // requests) or the URL to access (for client requests).
+ //
+ // For server requests, the URL is parsed from the URI
+ // supplied on the Request-Line as stored in RequestURI. For
+ // most requests, fields other than Path and RawQuery will be
+ // empty. (See RFC 7230, Section 5.3)
+ //
+ // For client requests, the URL's Host specifies the server to
+ // connect to, while the Request's Host field optionally
+ // specifies the Host header value to send in the HTTP
+ // request.
+ URL *url.URL
+
+ // The protocol version for incoming server requests.
+ //
+ // For client requests, these fields are ignored. The HTTP
+ // client code always uses either HTTP/1.1 or HTTP/2.
+ // See the docs on Transport for details.
+ Proto string // "HTTP/1.0"
+ ProtoMajor int // 1
+ ProtoMinor int // 0
+
+ // Header contains the request header fields either received
+ // by the server or to be sent by the client.
+ //
+ // If a server received a request with header lines,
+ //
+ // Host: example.com
+ // accept-encoding: gzip, deflate
+ // Accept-Language: en-us
+ // fOO: Bar
+ // foo: two
+ //
+ // then
+ //
+ // Header = map[string][]string{
+ // "Accept-Encoding": {"gzip, deflate"},
+ // "Accept-Language": {"en-us"},
+ // "Foo": {"Bar", "two"},
+ // }
+ //
+ // For incoming requests, the Host header is promoted to the
+ // Request.Host field and removed from the Header map.
+ //
+ // HTTP defines that header names are case-insensitive. The
+ // request parser implements this by using CanonicalHeaderKey,
+ // making the first character and any characters following a
+ // hyphen uppercase and the rest lowercase.
+ //
+ // For client requests, certain headers such as Content-Length
+ // and Connection are automatically written when needed and
+ // values in Header may be ignored. See the documentation
+ // for the Request.Write method.
+ Header Header
+
+ // Body is the request's body.
+ //
+ // For client requests, a nil body means the request has no
+ // body, such as a GET request. The HTTP Client's Transport
+ // is responsible for calling the Close method.
+ //
+ // For server requests, the Request Body is always non-nil
+ // but will return EOF immediately when no body is present.
+ // The Server will close the request body. The ServeHTTP
+ // Handler does not need to.
+ //
+ // Body must allow Read to be called concurrently with Close.
+ // In particular, calling Close should unblock a Read waiting
+ // for input.
+ Body io.ReadCloser
+
+ // GetBody defines an optional func to return a new copy of
+ // Body. It is used for client requests when a redirect requires
+ // reading the body more than once. Use of GetBody still
+ // requires setting Body.
+ //
+ // For server requests, it is unused.
+ GetBody func() (io.ReadCloser, error)
+
+ // ContentLength records the length of the associated content.
+ // The value -1 indicates that the length is unknown.
+ // Values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ //
+ // For client requests, a value of 0 with a non-nil Body is
+ // also treated as unknown.
+ ContentLength int64
+
+ // TransferEncoding lists the transfer encodings from outermost to
+ // innermost. An empty list denotes the "identity" encoding.
+ // TransferEncoding can usually be ignored; chunked encoding is
+ // automatically added and removed as necessary when sending and
+ // receiving requests.
+ TransferEncoding []string
+
+ // Close indicates whether to close the connection after
+ // replying to this request (for servers) or after sending this
+ // request and reading its response (for clients).
+ //
+ // For server requests, the HTTP server handles this automatically
+ // and this field is not needed by Handlers.
+ //
+ // For client requests, setting this field prevents re-use of
+ // TCP connections between requests to the same hosts, as if
+ // Transport.DisableKeepAlives were set.
+ Close bool
+
+ // For server requests, Host specifies the host on which the
+ // URL is sought. For HTTP/1 (per RFC 7230, section 5.4), this
+ // is either the value of the "Host" header or the host name
+ // given in the URL itself. For HTTP/2, it is the value of the
+ // ":authority" pseudo-header field.
+ // It may be of the form "host:port". For international domain
+ // names, Host may be in Punycode or Unicode form. Use
+ // golang.org/x/net/idna to convert it to either format if
+ // needed.
+ // To prevent DNS rebinding attacks, server Handlers should
+ // validate that the Host header has a value for which the
+ // Handler considers itself authoritative. The included
+ // ServeMux supports patterns registered to particular host
+ // names and thus protects its registered Handlers.
+ //
+ // For client requests, Host optionally overrides the Host
+ // header to send. If empty, the Request.Write method uses
+ // the value of URL.Host. Host may contain an international
+ // domain name.
+ Host string
+
+ // Form contains the parsed form data, including both the URL
+ // field's query parameters and the PATCH, POST, or PUT form data.
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores Form and uses Body instead.
+ Form url.Values
+
+ // PostForm contains the parsed form data from PATCH, POST
+ // or PUT body parameters.
+ //
+ // This field is only available after ParseForm is called.
+ // The HTTP client ignores PostForm and uses Body instead.
+ PostForm url.Values
+
+ // MultipartForm is the parsed multipart form, including file uploads.
+ // This field is only available after ParseMultipartForm is called.
+ // The HTTP client ignores MultipartForm and uses Body instead.
+ MultipartForm *multipart.Form
+
+ // Trailer specifies additional headers that are sent after the request
+ // body.
+ //
+ // For server requests, the Trailer map initially contains only the
+ // trailer keys, with nil values. (The client declares which trailers it
+ // will later send.) While the handler is reading from Body, it must
+ // not reference Trailer. After reading from Body returns EOF, Trailer
+ // can be read again and will contain non-nil values, if they were sent
+ // by the client.
+ //
+ // For client requests, Trailer must be initialized to a map containing
+ // the trailer keys to later send. The values may be nil or their final
+ // values. The ContentLength must be 0 or -1, to send a chunked request.
+ // After the HTTP request is sent the map values can be updated while
+ // the request body is read. Once the body returns EOF, the caller must
+ // not mutate Trailer.
+ //
+ // Few HTTP clients, servers, or proxies support HTTP trailers.
+ Trailer Header
+
+ // RemoteAddr allows HTTP servers and other software to record
+ // the network address that sent the request, usually for
+ // logging. This field is not filled in by ReadRequest and
+ // has no defined format. The HTTP server in this package
+ // sets RemoteAddr to an "IP:port" address before invoking a
+ // handler.
+ // This field is ignored by the HTTP client.
+ RemoteAddr string
+
+ // RequestURI is the unmodified request-target of the
+ // Request-Line (RFC 7230, Section 3.1.1) as sent by the client
+ // to a server. Usually the URL field should be used instead.
+ // It is an error to set this field in an HTTP client request.
+ RequestURI string
+
+ // TLS allows HTTP servers and other software to record
+ // information about the TLS connection on which the request
+ // was received. This field is not filled in by ReadRequest.
+ // The HTTP server in this package sets the field for
+ // TLS-enabled connections before invoking a handler;
+ // otherwise it leaves the field nil.
+ // This field is ignored by the HTTP client.
+ TLS *tls.ConnectionState
+
+ // Cancel is an optional channel whose closure indicates that the client
+ // request should be regarded as canceled. Not all implementations of
+ // RoundTripper may support Cancel.
+ //
+ // For server requests, this field is not applicable.
+ //
+ // Deprecated: Set the Request's context with NewRequestWithContext
+ // instead. If a Request's Cancel field and context are both
+ // set, it is undefined whether Cancel is respected.
+ Cancel <-chan struct{}
+
+ // Response is the redirect response which caused this request
+ // to be created. This field is only populated during client
+ // redirects.
+ Response *Response
+
+ // ctx is either the client or server context. It should only
+ // be modified via copying the whole Request using Clone or WithContext.
+ // It is unexported to prevent people from using Context wrong
+ // and mutating the contexts held by callers of the same request.
+ ctx context.Context
+}
+
+// Context returns the request's context. To change the context, use
+// Clone or WithContext.
+//
+// The returned context is always non-nil; it defaults to the
+// background context.
+//
+// For outgoing client requests, the context controls cancellation.
+//
+// For incoming server requests, the context is canceled when the
+// client's connection closes, the request is canceled (with HTTP/2),
+// or when the ServeHTTP method returns.
+func (r *Request) Context() context.Context {
+ if r.ctx != nil {
+ return r.ctx
+ }
+ return context.Background()
+}
+
+// WithContext returns a shallow copy of r with its context changed
+// to ctx. The provided ctx must be non-nil.
+//
+// For outgoing client request, the context controls the entire
+// lifetime of a request and its response: obtaining a connection,
+// sending the request, and reading the response headers and body.
+//
+// To create a new request with a context, use NewRequestWithContext.
+// To make a deep copy of a request with a new context, use Request.Clone.
+func (r *Request) WithContext(ctx context.Context) *Request {
+ if ctx == nil {
+ panic("nil context")
+ }
+ r2 := new(Request)
+ *r2 = *r
+ r2.ctx = ctx
+ return r2
+}
+
+// Clone returns a deep copy of r with its context changed to ctx.
+// The provided ctx must be non-nil.
+//
+// For an outgoing client request, the context controls the entire
+// lifetime of a request and its response: obtaining a connection,
+// sending the request, and reading the response headers and body.
+func (r *Request) Clone(ctx context.Context) *Request {
+ if ctx == nil {
+ panic("nil context")
+ }
+ r2 := new(Request)
+ *r2 = *r
+ r2.ctx = ctx
+ r2.URL = cloneURL(r.URL)
+ if r.Header != nil {
+ r2.Header = r.Header.Clone()
+ }
+ if r.Trailer != nil {
+ r2.Trailer = r.Trailer.Clone()
+ }
+ if s := r.TransferEncoding; s != nil {
+ s2 := make([]string, len(s))
+ copy(s2, s)
+ r2.TransferEncoding = s2
+ }
+ r2.Form = cloneURLValues(r.Form)
+ r2.PostForm = cloneURLValues(r.PostForm)
+ r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
+ return r2
+}
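+
+// A minimal sketch of deriving a copy of an incoming request with its own
+// deadline, e.g. before forwarding it upstream; the timeout value is
+// illustrative:
+//
+//	ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
+//	defer cancel()
+//	r2 := r.Clone(ctx)
+//	// r2 has its own copies of URL, Header, Trailer, and form data,
+//	// so mutating them does not affect r.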
+
+// ProtoAtLeast reports whether the HTTP protocol used
+// in the request is at least major.minor.
+func (r *Request) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// UserAgent returns the client's User-Agent, if sent in the request.
+func (r *Request) UserAgent() string {
+ return r.Header.Get("User-Agent")
+}
+
+// Cookies parses and returns the HTTP cookies sent with the request.
+func (r *Request) Cookies() []*Cookie {
+ return readCookies(r.Header, "")
+}
+
+// ErrNoCookie is returned by Request's Cookie method when a cookie is not found.
+var ErrNoCookie = errors.New("http: named cookie not present")
+
+// Cookie returns the named cookie provided in the request or
+// ErrNoCookie if not found.
+// If multiple cookies match the given name, only one cookie will
+// be returned.
+func (r *Request) Cookie(name string) (*Cookie, error) {
+ if name == "" {
+ return nil, ErrNoCookie
+ }
+ for _, c := range readCookies(r.Header, name) {
+ return c, nil
+ }
+ return nil, ErrNoCookie
+}
+
+// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4,
+// AddCookie does not attach more than one Cookie header field. That
+// means all cookies, if any, are written into the same line,
+// separated by semicolon.
+// AddCookie only sanitizes c's name and value, and does not sanitize
+// a Cookie header already present in the request.
+func (r *Request) AddCookie(c *Cookie) {
+ s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
+ if c := r.Header.Get("Cookie"); c != "" {
+ r.Header.Set("Cookie", c+"; "+s)
+ } else {
+ r.Header.Set("Cookie", s)
+ }
+}
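+
+// A minimal sketch of how AddCookie folds cookies into a single header;
+// the URL and cookie names are placeholders:
+//
+//	req, _ := NewRequest("GET", "https://example.com/", nil)
+//	req.AddCookie(&Cookie{Name: "session", Value: "abc123"})
+//	req.AddCookie(&Cookie{Name: "theme", Value: "dark"})
+//	// The header now reads: Cookie: session=abc123; theme=dark
+//	fmt.Println(req.Header.Get("Cookie"))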
+
+// Referer returns the referring URL, if sent in the request.
+//
+// Referer is misspelled as in the request itself, a mistake from the
+// earliest days of HTTP. This value can also be fetched from the
+// Header map as Header["Referer"]; the benefit of making it available
+// as a method is that the compiler can diagnose programs that use the
+// alternate (correct English) spelling req.Referrer() but cannot
+// diagnose programs that use Header["Referrer"].
+func (r *Request) Referer() string {
+ return r.Header.Get("Referer")
+}
+
+// multipartByReader is a sentinel value.
+// Its presence in Request.MultipartForm indicates that parsing of the request
+// body has been handed off to a MultipartReader instead of ParseMultipartForm.
+var multipartByReader = &multipart.Form{
+ Value: make(map[string][]string),
+ File: make(map[string][]*multipart.FileHeader),
+}
+
+// MultipartReader returns a MIME multipart reader if this is a
+// multipart/form-data or a multipart/mixed POST request, else returns nil and an error.
+// Use this function instead of ParseMultipartForm to
+// process the request body as a stream.
+func (r *Request) MultipartReader() (*multipart.Reader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, errors.New("http: MultipartReader called twice")
+ }
+ if r.MultipartForm != nil {
+ return nil, errors.New("http: multipart handled by ParseMultipartForm")
+ }
+ r.MultipartForm = multipartByReader
+ return r.multipartReader(true)
+}
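+
+// A minimal streaming sketch; the handler name is a placeholder and each
+// part is simply discarded here:
+//
+//	func streamUpload(w ResponseWriter, r *Request) {
+//		mr, err := r.MultipartReader()
+//		if err != nil {
+//			Error(w, err.Error(), StatusBadRequest)
+//			return
+//		}
+//		for {
+//			part, err := mr.NextPart()
+//			if err == io.EOF {
+//				break
+//			}
+//			if err != nil {
+//				Error(w, err.Error(), StatusBadRequest)
+//				return
+//			}
+//			// Process each part as a stream instead of buffering
+//			// the whole form in memory.
+//			io.Copy(io.Discard, part)
+//		}
+//	}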
+
+func (r *Request) multipartReader(allowMixed bool) (*multipart.Reader, error) {
+ v := r.Header.Get("Content-Type")
+ if v == "" {
+ return nil, ErrNotMultipart
+ }
+ if r.Body == nil {
+ return nil, errors.New("missing form body")
+ }
+ d, params, err := mime.ParseMediaType(v)
+ if err != nil || !(d == "multipart/form-data" || allowMixed && d == "multipart/mixed") {
+ return nil, ErrNotMultipart
+ }
+ boundary, ok := params["boundary"]
+ if !ok {
+ return nil, ErrMissingBoundary
+ }
+ return multipart.NewReader(r.Body, boundary), nil
+}
+
+// isH2Upgrade reports whether r represents the http2 "client preface"
+// magic string.
+func (r *Request) isH2Upgrade() bool {
+ return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0"
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+// NOTE: This is not intended to reflect the actual Go version being used.
+// It was changed at the time of Go 1.1 release because the former User-Agent
+// had ended up blocked by some intrusion detection systems.
+// See https://codereview.appspot.com/7532043.
+const defaultUserAgent = "Go-http-client/1.1"
+
+// Write writes an HTTP/1.1 request, which is the header and body, in wire format.
+// This method consults the following fields of the request:
+//
+// Host
+// URL
+// Method (defaults to "GET")
+// Header
+// ContentLength
+// TransferEncoding
+// Body
+//
+// If Body is present, Content-Length is <= 0 and TransferEncoding
+// hasn't been set to "identity", Write adds "Transfer-Encoding:
+// chunked" to the header. Body is closed after it is sent.
+func (r *Request) Write(w io.Writer) error {
+ return r.write(w, false, nil, nil)
+}
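+
+// A minimal sketch of serializing a request to wire format; the URL is a
+// placeholder and the commented output is approximate:
+//
+//	req, _ := NewRequest("GET", "http://example.com/path?q=1", nil)
+//	var buf bytes.Buffer
+//	if err := req.Write(&buf); err != nil {
+//		log.Fatal(err)
+//	}
+//	// buf now holds something like:
+//	//   GET /path?q=1 HTTP/1.1
+//	//   Host: example.com
+//	//   User-Agent: Go-http-client/1.1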
+
+// WriteProxy is like Write but writes the request in the form
+// expected by an HTTP proxy. In particular, WriteProxy writes the
+// initial Request-URI line of the request with an absolute URI, per
+// section 5.3 of RFC 7230, including the scheme and host.
+// In either case, WriteProxy also writes a Host header, using
+// either r.Host or r.URL.Host.
+func (r *Request) WriteProxy(w io.Writer) error {
+ return r.write(w, true, nil, nil)
+}
+
+// errMissingHost is returned by Write when there is no Host or URL present in
+// the Request.
+var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set")
+
+// extraHeaders may be nil
+// waitForContinue may be nil
+// always closes body
+func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) {
+ trace := httptrace.ContextClientTrace(r.Context())
+ if trace != nil && trace.WroteRequest != nil {
+ defer func() {
+ trace.WroteRequest(httptrace.WroteRequestInfo{
+ Err: err,
+ })
+ }()
+ }
+ closed := false
+ defer func() {
+ if closed {
+ return
+ }
+ if closeErr := r.closeBody(); closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }()
+
+ // Find the target host. Prefer the Host: header, but if that
+ // is not given, use the host from the request URL.
+ //
+ // Clean the host, in case it arrives with unexpected stuff in it.
+ host := r.Host
+ if host == "" {
+ if r.URL == nil {
+ return errMissingHost
+ }
+ host = r.URL.Host
+ }
+ host, err = httpguts.PunycodeHostPort(host)
+ if err != nil {
+ return err
+ }
+ // Validate that the Host header is a valid header in general,
+ // but don't validate the host itself. This is sufficient to avoid
+ // header or request smuggling via the Host field.
+ // The server can (and will, if it's a net/http server) reject
+ // the request if it doesn't consider the host valid.
+ if !httpguts.ValidHostHeader(host) {
+ // Historically, we would truncate the Host header after '/' or ' '.
+ // Some users have relied on this truncation to convert a network
+ // address such as Unix domain socket path into a valid, ignored
+ // Host header (see https://go.dev/issue/61431).
+ //
+ // We don't preserve the truncation, because sending an altered
+ // header field opens a smuggling vector. Instead, zero out the
+ // Host header entirely if it isn't valid. (An empty Host is valid;
+ // see RFC 9112 Section 3.2.)
+ //
+ // Return an error if we're sending to a proxy, since the proxy
+ // probably can't do anything useful with an empty Host header.
+ if !usingProxy {
+ host = ""
+ } else {
+ return errors.New("http: invalid Host header")
+ }
+ }
+
+ // According to RFC 6874, an HTTP client, proxy, or other
+ // intermediary must remove any IPv6 zone identifier attached
+ // to an outgoing URI.
+ host = removeZone(host)
+
+ ruri := r.URL.RequestURI()
+ if usingProxy && r.URL.Scheme != "" && r.URL.Opaque == "" {
+ ruri = r.URL.Scheme + "://" + host + ruri
+ } else if r.Method == "CONNECT" && r.URL.Path == "" {
+ // CONNECT requests normally give just the host and port, not a full URL.
+ ruri = host
+ if r.URL.Opaque != "" {
+ ruri = r.URL.Opaque
+ }
+ }
+ if stringContainsCTLByte(ruri) {
+ return errors.New("net/http: can't write control character in Request.URL")
+ }
+ // TODO: validate r.Method too? At least it's less likely to
+ // come from an attacker (more likely to be a constant in
+ // code).
+
+ // Wrap the writer in a bufio Writer if it's not already buffered.
+ // Don't always call NewWriter, as that forces a bytes.Buffer
+ // and other small bufio Writers to have a minimum 4k buffer
+ // size.
+ var bw *bufio.Writer
+ if _, ok := w.(io.ByteWriter); !ok {
+ bw = bufio.NewWriter(w)
+ w = bw
+ }
+
+ _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(r.Method, "GET"), ruri)
+ if err != nil {
+ return err
+ }
+
+ // Header lines
+ _, err = fmt.Fprintf(w, "Host: %s\r\n", host)
+ if err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Host", []string{host})
+ }
+
+ // Use the defaultUserAgent unless the Header contains one, which
+ // may be blank to not send the header.
+ userAgent := defaultUserAgent
+ if r.Header.has("User-Agent") {
+ userAgent = r.Header.Get("User-Agent")
+ }
+ if userAgent != "" {
+ _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+ if err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("User-Agent", []string{userAgent})
+ }
+ }
+
+	// Process Body, ContentLength, Close, Trailer
+ tw, err := newTransferWriter(r)
+ if err != nil {
+ return err
+ }
+ err = tw.writeHeader(w, trace)
+ if err != nil {
+ return err
+ }
+
+ err = r.Header.writeSubset(w, reqWriteExcludeHeader, trace)
+ if err != nil {
+ return err
+ }
+
+ if extraHeaders != nil {
+ err = extraHeaders.write(w, trace)
+ if err != nil {
+ return err
+ }
+ }
+
+ _, err = io.WriteString(w, "\r\n")
+ if err != nil {
+ return err
+ }
+
+ if trace != nil && trace.WroteHeaders != nil {
+ trace.WroteHeaders()
+ }
+
+ // Flush and wait for 100-continue if expected.
+ if waitForContinue != nil {
+ if bw, ok := w.(*bufio.Writer); ok {
+ err = bw.Flush()
+ if err != nil {
+ return err
+ }
+ }
+ if trace != nil && trace.Wait100Continue != nil {
+ trace.Wait100Continue()
+ }
+ if !waitForContinue() {
+ closed = true
+ r.closeBody()
+ return nil
+ }
+ }
+
+ if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders {
+ if err := bw.Flush(); err != nil {
+ return err
+ }
+ }
+
+ // Write body and trailer
+ closed = true
+ err = tw.writeBody(w)
+ if err != nil {
+ if tw.bodyReadError == err {
+ err = requestBodyReadError{err}
+ }
+ return err
+ }
+
+ if bw != nil {
+ return bw.Flush()
+ }
+ return nil
+}
+
+// requestBodyReadError wraps an error from (*Request).write to indicate
+// that the error came from a Read call on the Request.Body.
+// This error type should not escape the net/http package to users.
+type requestBodyReadError struct{ error }
+
+func idnaASCII(v string) (string, error) {
+ // TODO: Consider removing this check after verifying performance is okay.
+ // Right now punycode verification, length checks, context checks, and the
+ // permissible character tests are all omitted. It also prevents the ToASCII
+ // call from salvaging an invalid IDN, when possible. As a result it may be
+ // possible to have two IDNs that appear identical to the user where the
+ // ASCII-only version causes an error downstream whereas the non-ASCII
+ // version does not.
+ // Note that for correct ASCII IDNs ToASCII will only do considerably more
+ // work, but it will not cause an allocation.
+ if ascii.Is(v) {
+ return v, nil
+ }
+ return idna.Lookup.ToASCII(v)
+}
+
+// removeZone removes IPv6 zone identifier from host.
+// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
+func removeZone(host string) string {
+ if !strings.HasPrefix(host, "[") {
+ return host
+ }
+ i := strings.LastIndex(host, "]")
+ if i < 0 {
+ return host
+ }
+ j := strings.LastIndex(host[:i], "%")
+ if j < 0 {
+ return host
+ }
+ return host[:j] + host[i:]
+}
+
+// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6.
+// "HTTP/1.0" returns (1, 0, true). Note that strings without
+// a minor version, such as "HTTP/2", are not valid.
+func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
+ switch vers {
+ case "HTTP/1.1":
+ return 1, 1, true
+ case "HTTP/1.0":
+ return 1, 0, true
+ }
+ if !strings.HasPrefix(vers, "HTTP/") {
+ return 0, 0, false
+ }
+ if len(vers) != len("HTTP/X.Y") {
+ return 0, 0, false
+ }
+ if vers[6] != '.' {
+ return 0, 0, false
+ }
+ maj, err := strconv.ParseUint(vers[5:6], 10, 0)
+ if err != nil {
+ return 0, 0, false
+ }
+ min, err := strconv.ParseUint(vers[7:8], 10, 0)
+ if err != nil {
+ return 0, 0, false
+ }
+ return int(maj), int(min), true
+}
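+
+// A brief sketch of the accepted and rejected forms:
+//
+//	fmt.Println(ParseHTTPVersion("HTTP/1.1")) // 1 1 true
+//	fmt.Println(ParseHTTPVersion("HTTP/2"))   // 0 0 false: missing minor version
+//	fmt.Println(ParseHTTPVersion("HTTP/2.0")) // 2 0 true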
+
+func validMethod(method string) bool {
+ /*
+ Method = "OPTIONS" ; Section 9.2
+ | "GET" ; Section 9.3
+ | "HEAD" ; Section 9.4
+ | "POST" ; Section 9.5
+ | "PUT" ; Section 9.6
+ | "DELETE" ; Section 9.7
+ | "TRACE" ; Section 9.8
+ | "CONNECT" ; Section 9.9
+ | extension-method
+ extension-method = token
+ token = 1*<any CHAR except CTLs or separators>
+ */
+ return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
+}
+
+// NewRequest wraps NewRequestWithContext using context.Background.
+func NewRequest(method, url string, body io.Reader) (*Request, error) {
+ return NewRequestWithContext(context.Background(), method, url, body)
+}
+
+// NewRequestWithContext returns a new Request given a method, URL, and
+// optional body.
+//
+// If the provided body is also an io.Closer, the returned
+// Request.Body is set to body and will be closed by the Client
+// methods Do, Post, and PostForm, and Transport.RoundTrip.
+//
+// NewRequestWithContext returns a Request suitable for use with
+// Client.Do or Transport.RoundTrip. To create a request for use with
+// testing a Server Handler, either use the NewRequest function in the
+// net/http/httptest package, use ReadRequest, or manually update the
+// Request fields. For an outgoing client request, the context
+// controls the entire lifetime of a request and its response:
+// obtaining a connection, sending the request, and reading the
+// response headers and body. See the Request type's documentation for
+// the difference between inbound and outbound request fields.
+//
+// If body is of type *bytes.Buffer, *bytes.Reader, or
+// *strings.Reader, the returned request's ContentLength is set to its
+// exact value (instead of -1), GetBody is populated (so 307 and 308
+// redirects can replay the body), and Body is set to NoBody if the
+// ContentLength is 0.
+func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) {
+ if method == "" {
+ // We document that "" means "GET" for Request.Method, and people have
+ // relied on that from NewRequest, so keep that working.
+ // We still enforce validMethod for non-empty methods.
+ method = "GET"
+ }
+ if !validMethod(method) {
+ return nil, fmt.Errorf("net/http: invalid method %q", method)
+ }
+ if ctx == nil {
+ return nil, errors.New("net/http: nil Context")
+ }
+ u, err := urlpkg.Parse(url)
+ if err != nil {
+ return nil, err
+ }
+ rc, ok := body.(io.ReadCloser)
+ if !ok && body != nil {
+ rc = io.NopCloser(body)
+ }
+ // The host's colon:port should be normalized. See Issue 14836.
+ u.Host = removeEmptyPort(u.Host)
+ req := &Request{
+ ctx: ctx,
+ Method: method,
+ URL: u,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ Body: rc,
+ Host: u.Host,
+ }
+ if body != nil {
+ switch v := body.(type) {
+ case *bytes.Buffer:
+ req.ContentLength = int64(v.Len())
+ buf := v.Bytes()
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := bytes.NewReader(buf)
+ return io.NopCloser(r), nil
+ }
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ snapshot := *v
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := snapshot
+ return io.NopCloser(&r), nil
+ }
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
+ snapshot := *v
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := snapshot
+ return io.NopCloser(&r), nil
+ }
+ default:
+ // This is where we'd set it to -1 (at least
+ // if body != NoBody) to mean unknown, but
+ // that broke people during the Go 1.8 testing
+ // period. People depend on it being 0 I
+ // guess. Maybe retry later. See Issue 18117.
+ }
+ // For client requests, Request.ContentLength of 0
+ // means either actually 0, or unknown. The only way
+ // to explicitly say that the ContentLength is zero is
+		// to set the Body to nil. But it turns out too much code
+		// depends on NewRequest returning a non-nil Body,
+		// so we use a well-known ReadCloser variable instead
+		// and have the http package also treat that sentinel
+		// variable as meaning explicitly zero.
+ if req.GetBody != nil && req.ContentLength == 0 {
+ req.Body = NoBody
+ req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil }
+ }
+ }
+
+ return req, nil
+}
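+
+// A minimal client-side sketch; the URL, body, and timeout are placeholders:
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+//	defer cancel()
+//	body := strings.NewReader(`{"name":"gopher"}`)
+//	req, err := NewRequestWithContext(ctx, MethodPost, "https://example.com/api", body)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	req.Header.Set("Content-Type", "application/json")
+//	resp, err := DefaultClient.Do(req)
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer resp.Body.Close()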
+
+// BasicAuth returns the username and password provided in the request's
+// Authorization header, if the request uses HTTP Basic Authentication.
+// See RFC 2617, Section 2.
+func (r *Request) BasicAuth() (username, password string, ok bool) {
+ auth := r.Header.Get("Authorization")
+ if auth == "" {
+ return "", "", false
+ }
+ return parseBasicAuth(auth)
+}
+
+// parseBasicAuth parses an HTTP Basic Authentication string.
+// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
+func parseBasicAuth(auth string) (username, password string, ok bool) {
+ const prefix = "Basic "
+ // Case insensitive prefix match. See Issue 22736.
+ if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) {
+ return "", "", false
+ }
+ c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
+ if err != nil {
+ return "", "", false
+ }
+ cs := string(c)
+ username, password, ok = strings.Cut(cs, ":")
+ if !ok {
+ return "", "", false
+ }
+ return username, password, true
+}
+
+// SetBasicAuth sets the request's Authorization header to use HTTP
+// Basic Authentication with the provided username and password.
+//
+// With HTTP Basic Authentication the provided username and password
+// are not encrypted. It should generally only be used in an HTTPS
+// request.
+//
+// The username may not contain a colon. Some protocols may impose
+// additional requirements on pre-escaping the username and
+// password. For instance, when used with OAuth2, both arguments must
+// be URL encoded first with url.QueryEscape.
+func (r *Request) SetBasicAuth(username, password string) {
+ r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
+}
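+
+// A minimal round-trip sketch; the URL is a placeholder and the credentials
+// mirror the RFC 2617 example:
+//
+//	req, _ := NewRequest("GET", "https://example.com/private", nil)
+//	req.SetBasicAuth("Aladdin", "open sesame")
+//	// Reading it back on the receiving side of the same request:
+//	fmt.Println(req.BasicAuth()) // Aladdin open sesame true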
+
+// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts.
+func parseRequestLine(line string) (method, requestURI, proto string, ok bool) {
+ method, rest, ok1 := strings.Cut(line, " ")
+ requestURI, proto, ok2 := strings.Cut(rest, " ")
+ if !ok1 || !ok2 {
+ return "", "", "", false
+ }
+ return method, requestURI, proto, true
+}
+
+var textprotoReaderPool sync.Pool
+
+func newTextprotoReader(br *bufio.Reader) *textproto.Reader {
+ if v := textprotoReaderPool.Get(); v != nil {
+ tr := v.(*textproto.Reader)
+ tr.R = br
+ return tr
+ }
+ return textproto.NewReader(br)
+}
+
+func putTextprotoReader(r *textproto.Reader) {
+ r.R = nil
+ textprotoReaderPool.Put(r)
+}
+
+// ReadRequest reads and parses an incoming request from b.
+//
+// ReadRequest is a low-level function and should only be used for
+// specialized applications; most code should use the Server to read
+// requests and handle them via the Handler interface. ReadRequest
+// only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2.
+func ReadRequest(b *bufio.Reader) (*Request, error) {
+ req, err := readRequest(b)
+ if err != nil {
+ return nil, err
+ }
+
+ delete(req.Header, "Host")
+ return req, err
+}
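+
+// A minimal sketch of parsing a raw request from a reader; the request text
+// is illustrative:
+//
+//	raw := "GET /index.html HTTP/1.1\r\nHost: example.com\r\n\r\n"
+//	req, err := ReadRequest(bufio.NewReader(strings.NewReader(raw)))
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Println(req.Method, req.URL.Path, req.Host) // GET /index.html example.com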
+
+func readRequest(b *bufio.Reader) (req *Request, err error) {
+ tp := newTextprotoReader(b)
+ defer putTextprotoReader(tp)
+
+ req = new(Request)
+
+ // First line: GET /index.html HTTP/1.0
+ var s string
+ if s, err = tp.ReadLine(); err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ }()
+
+ var ok bool
+ req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s)
+ if !ok {
+ return nil, badStringError("malformed HTTP request", s)
+ }
+ if !validMethod(req.Method) {
+ return nil, badStringError("invalid method", req.Method)
+ }
+ rawurl := req.RequestURI
+ if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok {
+ return nil, badStringError("malformed HTTP version", req.Proto)
+ }
+
+ // CONNECT requests are used two different ways, and neither uses a full URL:
+ // The standard use is to tunnel HTTPS through an HTTP proxy.
+ // It looks like "CONNECT www.google.com:443 HTTP/1.1", and the parameter is
+ // just the authority section of a URL. This information should go in req.URL.Host.
+ //
+ // The net/rpc package also uses CONNECT, but there the parameter is a path
+ // that starts with a slash. It can be parsed with the regular URL parser,
+ // and the path will end up in req.URL.Path, where it needs to be in order for
+ // RPC to work.
+ justAuthority := req.Method == "CONNECT" && !strings.HasPrefix(rawurl, "/")
+ if justAuthority {
+ rawurl = "http://" + rawurl
+ }
+
+ if req.URL, err = url.ParseRequestURI(rawurl); err != nil {
+ return nil, err
+ }
+
+ if justAuthority {
+ // Strip the bogus "http://" back off.
+ req.URL.Scheme = ""
+ }
+
+ // Subsequent lines: Key: value.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ return nil, err
+ }
+ req.Header = Header(mimeHeader)
+ if len(req.Header["Host"]) > 1 {
+ return nil, fmt.Errorf("too many Host headers")
+ }
+
+ // RFC 7230, section 5.3: Must treat
+ // GET /index.html HTTP/1.1
+ // Host: www.google.com
+ // and
+ // GET http://www.google.com/index.html HTTP/1.1
+ // Host: doesntmatter
+ // the same. In the second case, any Host line is ignored.
+ req.Host = req.URL.Host
+ if req.Host == "" {
+ req.Host = req.Header.get("Host")
+ }
+
+ fixPragmaCacheControl(req.Header)
+
+ req.Close = shouldClose(req.ProtoMajor, req.ProtoMinor, req.Header, false)
+
+ err = readTransfer(req, b)
+ if err != nil {
+ return nil, err
+ }
+
+ if req.isH2Upgrade() {
+ // Because it's neither chunked, nor declared:
+ req.ContentLength = -1
+
+ // We want to give handlers a chance to hijack the
+ // connection, but we need to prevent the Server from
+ // dealing with the connection further if it's not
+ // hijacked. Set Close to ensure that:
+ req.Close = true
+ }
+ return req, nil
+}
+
+// MaxBytesReader is similar to io.LimitReader but is intended for
+// limiting the size of incoming request bodies. In contrast to
+// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a
+// non-nil error of type *MaxBytesError for a Read beyond the limit,
+// and closes the underlying reader when its Close method is called.
+//
+// MaxBytesReader prevents clients from accidentally or maliciously
+// sending a large request and wasting server resources. If possible,
+// it tells the ResponseWriter to close the connection after the limit
+// has been reached.
+func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
+ if n < 0 { // Treat negative limits as equivalent to 0.
+ n = 0
+ }
+ return &maxBytesReader{w: w, r: r, i: n, n: n}
+}
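+
+// A minimal handler sketch; the handler name and the 1 MB limit are
+// placeholders:
+//
+//	func limitedHandler(w ResponseWriter, r *Request) {
+//		r.Body = MaxBytesReader(w, r.Body, 1<<20)
+//		body, err := io.ReadAll(r.Body)
+//		var maxErr *MaxBytesError
+//		if errors.As(err, &maxErr) {
+//			Error(w, "request body too large", StatusRequestEntityTooLarge)
+//			return
+//		}
+//		if err != nil {
+//			Error(w, err.Error(), StatusBadRequest)
+//			return
+//		}
+//		_ = body // use the bounded body
+//	}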
+
+// MaxBytesError is returned by MaxBytesReader when its read limit is exceeded.
+type MaxBytesError struct {
+ Limit int64
+}
+
+func (e *MaxBytesError) Error() string {
+ // Due to Hyrum's law, this text cannot be changed.
+ return "http: request body too large"
+}
+
+type maxBytesReader struct {
+ w ResponseWriter
+ r io.ReadCloser // underlying reader
+ i int64 // max bytes initially, for MaxBytesError
+ n int64 // max bytes remaining
+ err error // sticky error
+}
+
+func (l *maxBytesReader) Read(p []byte) (n int, err error) {
+ if l.err != nil {
+ return 0, l.err
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+	// If they asked for a 32KB read but only 5 bytes are
+	// remaining, no need to read 32KB. 6 bytes will answer the
+	// question of whether we hit the limit or go past it.
+ // 0 < len(p) < 2^63
+ if int64(len(p))-1 > l.n {
+ p = p[:l.n+1]
+ }
+ n, err = l.r.Read(p)
+
+ if int64(n) <= l.n {
+ l.n -= int64(n)
+ l.err = err
+ return n, err
+ }
+
+ n = int(l.n)
+ l.n = 0
+
+ // The server code and client code both use
+ // maxBytesReader. This "requestTooLarge" check is
+ // only used by the server code. To prevent binaries
+	// that only use the HTTP Client code (such as
+ // cmd/go) from also linking in the HTTP server, don't
+ // use a static type assertion to the server
+ // "*response" type. Check this interface instead:
+ type requestTooLarger interface {
+ requestTooLarge()
+ }
+ if res, ok := l.w.(requestTooLarger); ok {
+ res.requestTooLarge()
+ }
+ l.err = &MaxBytesError{l.i}
+ return n, l.err
+}
+
+func (l *maxBytesReader) Close() error {
+ return l.r.Close()
+}
+
+func copyValues(dst, src url.Values) {
+ for k, vs := range src {
+ dst[k] = append(dst[k], vs...)
+ }
+}
+
+func parsePostForm(r *Request) (vs url.Values, err error) {
+ if r.Body == nil {
+ err = errors.New("missing form body")
+ return
+ }
+ ct := r.Header.Get("Content-Type")
+ // RFC 7231, section 3.1.1.5 - empty type
+ // MAY be treated as application/octet-stream
+ if ct == "" {
+ ct = "application/octet-stream"
+ }
+ ct, _, err = mime.ParseMediaType(ct)
+ switch {
+ case ct == "application/x-www-form-urlencoded":
+ var reader io.Reader = r.Body
+ maxFormSize := int64(1<<63 - 1)
+ if _, ok := r.Body.(*maxBytesReader); !ok {
+ maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
+ reader = io.LimitReader(r.Body, maxFormSize+1)
+ }
+ b, e := io.ReadAll(reader)
+ if e != nil {
+ if err == nil {
+ err = e
+ }
+ break
+ }
+ if int64(len(b)) > maxFormSize {
+ err = errors.New("http: POST too large")
+ return
+ }
+ vs, e = url.ParseQuery(string(b))
+ if err == nil {
+ err = e
+ }
+ case ct == "multipart/form-data":
+ // handled by ParseMultipartForm (which is calling us, or should be)
+ // TODO(bradfitz): there are too many possible
+ // orders to call too many functions here.
+ // Clean this up and write more tests.
+ // request_test.go contains the start of this,
+ // in TestParseMultipartFormOrder and others.
+ }
+ return
+}
+
+// ParseForm populates r.Form and r.PostForm.
+//
+// For all requests, ParseForm parses the raw query from the URL and updates
+// r.Form.
+//
+// For POST, PUT, and PATCH requests, it also reads the request body, parses it
+// as a form and puts the results into both r.PostForm and r.Form. Request body
+// parameters take precedence over URL query string values in r.Form.
+//
+// If the request Body's size has not already been limited by MaxBytesReader,
+// the size is capped at 10MB.
+//
+// For other HTTP methods, or when the Content-Type is not
+// application/x-www-form-urlencoded, the request Body is not read, and
+// r.PostForm is initialized to a non-nil, empty value.
+//
+// ParseMultipartForm calls ParseForm automatically.
+// ParseForm is idempotent.
+func (r *Request) ParseForm() error {
+ var err error
+ if r.PostForm == nil {
+ if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" {
+ r.PostForm, err = parsePostForm(r)
+ }
+ if r.PostForm == nil {
+ r.PostForm = make(url.Values)
+ }
+ }
+ if r.Form == nil {
+ if len(r.PostForm) > 0 {
+ r.Form = make(url.Values)
+ copyValues(r.Form, r.PostForm)
+ }
+ var newValues url.Values
+ if r.URL != nil {
+ var e error
+ newValues, e = url.ParseQuery(r.URL.RawQuery)
+ if err == nil {
+ err = e
+ }
+ }
+ if newValues == nil {
+ newValues = make(url.Values)
+ }
+ if r.Form == nil {
+ r.Form = newValues
+ } else {
+ copyValues(r.Form, newValues)
+ }
+ }
+ return err
+}
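+
+// Illustrative sketch only (not part of the package): the precedence
+// documented above, assuming a POST to /search?q=from-url with body
+// "q=from-body" and Content-Type application/x-www-form-urlencoded.
+// The handler name and values are arbitrary example choices.
+func exampleParseFormHandler(w ResponseWriter, r *Request) {
+ if err := r.ParseForm(); err != nil {
+ Error(w, err.Error(), StatusBadRequest)
+ return
+ }
+ _ = r.Form["q"] // ["from-body", "from-url"]: body values first, then the URL query
+ _ = r.PostForm.Get("q") // "from-body": body parameters only
+}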
+
+// ParseMultipartForm parses a request body as multipart/form-data.
+// The whole request body is parsed and up to a total of maxMemory bytes of
+// its file parts are stored in memory, with the remainder stored on
+// disk in temporary files.
+// ParseMultipartForm calls ParseForm if necessary.
+// If ParseForm returns an error, ParseMultipartForm returns it but also
+// continues parsing the request body.
+// After one call to ParseMultipartForm, subsequent calls have no effect.
+func (r *Request) ParseMultipartForm(maxMemory int64) error {
+ if r.MultipartForm == multipartByReader {
+ return errors.New("http: multipart handled by MultipartReader")
+ }
+ var parseFormErr error
+ if r.Form == nil {
+ // Let errors in ParseForm fall through, and just
+ // return it at the end.
+ parseFormErr = r.ParseForm()
+ }
+ if r.MultipartForm != nil {
+ return nil
+ }
+
+ mr, err := r.multipartReader(false)
+ if err != nil {
+ return err
+ }
+
+ f, err := mr.ReadForm(maxMemory)
+ if err != nil {
+ return err
+ }
+
+ if r.PostForm == nil {
+ r.PostForm = make(url.Values)
+ }
+ for k, v := range f.Value {
+ r.Form[k] = append(r.Form[k], v...)
+ // r.PostForm should also be populated. See Issue 9305.
+ r.PostForm[k] = append(r.PostForm[k], v...)
+ }
+
+ r.MultipartForm = f
+
+ return parseFormErr
+}
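+
+// Illustrative sketch only (not part of the package): an explicit
+// ParseMultipartForm call keeps up to maxMemory bytes of file parts in
+// memory and spills the rest to temporary files, while non-file fields
+// land in r.PostForm and r.Form as described above. The 32MB figure and
+// the field and handler names are arbitrary example choices.
+func exampleMultipartHandler(w ResponseWriter, r *Request) {
+ if err := r.ParseMultipartForm(32 << 20); err != nil {
+ Error(w, err.Error(), StatusBadRequest)
+ return
+ }
+ _ = r.PostFormValue("title") // a non-file field from the multipart body
+ _ = r.MultipartForm.File["attachment"] // []*multipart.FileHeader for a file field
+}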
+
+// FormValue returns the first value for the named component of the query.
+// POST, PUT, and PATCH body parameters take precedence over URL query
+// string values.
+// FormValue calls ParseMultipartForm and ParseForm if necessary and ignores
+// any errors returned by these functions.
+// If key is not present, FormValue returns the empty string.
+// To access multiple values of the same key, call ParseForm and
+// then inspect Request.Form directly.
+func (r *Request) FormValue(key string) string {
+ if r.Form == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.Form[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
+// PostFormValue returns the first value for the named component of the POST,
+// PATCH, or PUT request body. URL query parameters are ignored.
+// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores
+// any errors returned by these functions.
+// If key is not present, PostFormValue returns the empty string.
+func (r *Request) PostFormValue(key string) string {
+ if r.PostForm == nil {
+ r.ParseMultipartForm(defaultMaxMemory)
+ }
+ if vs := r.PostForm[key]; len(vs) > 0 {
+ return vs[0]
+ }
+ return ""
+}
+
+// FormFile returns the first file for the provided form key.
+// FormFile calls ParseMultipartForm and ParseForm if necessary.
+func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, error) {
+ if r.MultipartForm == multipartByReader {
+ return nil, nil, errors.New("http: multipart handled by MultipartReader")
+ }
+ if r.MultipartForm == nil {
+ err := r.ParseMultipartForm(defaultMaxMemory)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ if r.MultipartForm != nil && r.MultipartForm.File != nil {
+ if fhs := r.MultipartForm.File[key]; len(fhs) > 0 {
+ f, err := fhs[0].Open()
+ return f, fhs[0], err
+ }
+ }
+ return nil, nil, ErrMissingFile
+}
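+
+// Illustrative sketch only (not part of the package): FormFile parses the
+// form on demand (with defaultMaxMemory) and returns the first file sent
+// under the given key. The field and handler names are arbitrary example
+// choices.
+func exampleFormFileHandler(w ResponseWriter, r *Request) {
+ f, fh, err := r.FormFile("upload")
+ if err != nil {
+ Error(w, err.Error(), StatusBadRequest)
+ return
+ }
+ defer f.Close()
+ data, err := io.ReadAll(f)
+ if err != nil {
+ Error(w, err.Error(), StatusInternalServerError)
+ return
+ }
+ _ = fh.Filename // client-supplied name (base name only; see Issue 45789)
+ _ = data // the uploaded file's contents
+}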
+
+func (r *Request) expectsContinue() bool {
+ return hasToken(r.Header.get("Expect"), "100-continue")
+}
+
+func (r *Request) wantsHttp10KeepAlive() bool {
+ if r.ProtoMajor != 1 || r.ProtoMinor != 0 {
+ return false
+ }
+ return hasToken(r.Header.get("Connection"), "keep-alive")
+}
+
+func (r *Request) wantsClose() bool {
+ if r.Close {
+ return true
+ }
+ return hasToken(r.Header.get("Connection"), "close")
+}
+
+func (r *Request) closeBody() error {
+ if r.Body == nil {
+ return nil
+ }
+ return r.Body.Close()
+}
+
+func (r *Request) isReplayable() bool {
+ if r.Body == nil || r.Body == NoBody || r.GetBody != nil {
+ switch valueOrDefault(r.Method, "GET") {
+ case "GET", "HEAD", "OPTIONS", "TRACE":
+ return true
+ }
+ // The Idempotency-Key, while non-standard, is widely used to
+ // mean a POST or other request is idempotent. See
+ // https://golang.org/issue/19943#issuecomment-421092421
+ if r.Header.has("Idempotency-Key") || r.Header.has("X-Idempotency-Key") {
+ return true
+ }
+ }
+ return false
+}
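+
+// Illustrative sketch only (not part of the package): a POST whose body can
+// be re-obtained via GetBody and that carries an Idempotency-Key header
+// satisfies isReplayable, so the Transport may retry it after certain
+// connection failures. The URL, body, and key value are arbitrary example
+// choices.
+func exampleReplayablePost() (*Request, error) {
+ req, err := NewRequest("POST", "https://example.com/charge", strings.NewReader(`{"amount":1}`))
+ if err != nil {
+ return nil, err
+ }
+ // NewRequest set req.GetBody for the strings.Reader; the header below
+ // marks the POST as idempotent, so isReplayable reports true.
+ req.Header.Set("Idempotency-Key", "key-123")
+ return req, nil
+}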
+
+// outgoingLength reports the Content-Length of this outgoing (Client) request.
+// It maps 0 into -1 (unknown) when the Body is non-nil.
+func (r *Request) outgoingLength() int64 {
+ if r.Body == nil || r.Body == NoBody {
+ return 0
+ }
+ if r.ContentLength != 0 {
+ return r.ContentLength
+ }
+ return -1
+}
+
+// requestMethodUsuallyLacksBody reports whether the given request
+// method is one that typically does not involve a request body.
+// This is used by the Transport (via
+// transferWriter.shouldSendChunkedRequestBody) to determine whether
+// we try to test-read a byte from a non-nil Request.Body when
+// Request.outgoingLength() returns -1. See the comments in
+// shouldSendChunkedRequestBody.
+func requestMethodUsuallyLacksBody(method string) bool {
+ switch method {
+ case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH":
+ return true
+ }
+ return false
+}
+
+// requiresHTTP1 reports whether this request requires being sent on
+// an HTTP/1 connection.
+func (r *Request) requiresHTTP1() bool {
+ return hasToken(r.Header.Get("Connection"), "upgrade") &&
+ ascii.EqualFold(r.Header.Get("Upgrade"), "websocket")
+}
diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go
new file mode 100644
index 0000000..a32b583
--- /dev/null
+++ b/src/net/http/request_test.go
@@ -0,0 +1,1397 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "mime/multipart"
+ . "net/http"
+ "net/url"
+ "os"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+func TestQuery(t *testing.T) {
+ req := &Request{Method: "GET"}
+ req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar")
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+}
+
+// Issue #25192: Test that ParseForm fails but still parses the form when a URL
+// containing a semicolon is provided.
+func TestParseFormSemicolonSeparator(t *testing.T) {
+ for _, method := range []string{"POST", "PATCH", "PUT", "GET"} {
+ req, _ := NewRequest(method, "http://www.google.com/search?q=foo;q=bar&a=1",
+ strings.NewReader("q"))
+ err := req.ParseForm()
+ if err == nil {
+ t.Fatalf(`for method %s, ParseForm expected an error, got success`, method)
+ }
+ wantForm := url.Values{"a": []string{"1"}}
+ if !reflect.DeepEqual(req.Form, wantForm) {
+ t.Fatalf("for method %s, ParseForm expected req.Form = %v, want %v", method, req.Form, wantForm)
+ }
+ }
+}
+
+func TestParseFormQuery(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&orphan=nope&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&=nokey&orphan&empty=&"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+
+ if q := req.FormValue("q"); q != "foo" {
+ t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
+ }
+ if z := req.FormValue("z"); z != "post" {
+ t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
+ }
+ if bq, found := req.PostForm["q"]; found {
+ t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
+ }
+ if bz := req.PostFormValue("z"); bz != "post" {
+ t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
+ }
+ if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
+ t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
+ }
+ if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
+ t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
+ }
+ if prio := req.FormValue("prio"); prio != "2" {
+ t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
+ }
+ if orphan := req.Form["orphan"]; !reflect.DeepEqual(orphan, []string{"", "nope"}) {
+ t.Errorf(`req.Form["orphan"] = %q, want ["", "nope"]`, orphan)
+ }
+ if empty := req.Form["empty"]; !reflect.DeepEqual(empty, []string{"", "not"}) {
+ t.Errorf(`req.Form["empty"] = %q, want ["", "not"]`, empty)
+ }
+ if nokey := req.Form[""]; !reflect.DeepEqual(nokey, []string{"nokey"}) {
+ t.Errorf(`req.Form[""] = %q, want ["nokey"]`, nokey)
+ }
+}
+
+// Tests that we only parse the form automatically for certain methods.
+func TestParseFormQueryMethods(t *testing.T) {
+ for _, method := range []string{"POST", "PATCH", "PUT", "FOO"} {
+ req, _ := NewRequest(method, "http://www.google.com/search",
+ strings.NewReader("foo=bar"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ want := "bar"
+ if method == "FOO" {
+ want = ""
+ }
+ if got := req.FormValue("foo"); got != want {
+ t.Errorf(`for method %s, FormValue("foo") = %q; want %q`, method, got, want)
+ }
+ }
+}
+
+func TestParseFormUnknownContentType(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ wantErr string
+ contentType Header
+ }{
+ {"text", "", Header{"Content-Type": {"text/plain"}}},
+ // Empty content type is legal - may be treated as
+ // application/octet-stream (RFC 7231, section 3.1.1.5)
+ {"empty", "", Header{}},
+ {"boundary", "mime: invalid media parameter", Header{"Content-Type": {"text/plain; boundary="}}},
+ {"unknown", "", Header{"Content-Type": {"application/unknown"}}},
+ } {
+ t.Run(test.name,
+ func(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: test.contentType,
+ Body: io.NopCloser(strings.NewReader("body")),
+ }
+ err := req.ParseForm()
+ switch {
+ case err == nil && test.wantErr != "":
+ t.Errorf("unexpected success; want error %q", test.wantErr)
+ case err != nil && test.wantErr == "":
+ t.Errorf("want success, got error: %v", err)
+ case test.wantErr != "" && test.wantErr != fmt.Sprint(err):
+ t.Errorf("got error %q; want %q", err, test.wantErr)
+ }
+ },
+ )
+ }
+}
+
+func TestParseFormInitializeOnError(t *testing.T) {
+ nilBody, _ := NewRequest("POST", "http://www.google.com/search?q=foo", nil)
+ tests := []*Request{
+ nilBody,
+ {Method: "GET", URL: nil},
+ }
+ for i, req := range tests {
+ err := req.ParseForm()
+ if req.Form == nil {
+ t.Errorf("%d. Form not initialized, error %v", i, err)
+ }
+ if req.PostForm == nil {
+ t.Errorf("%d. PostForm not initialized, error %v", i, err)
+ }
+ }
+}
+
+func TestMultipartReader(t *testing.T) {
+ tests := []struct {
+ shouldError bool
+ contentType string
+ }{
+ {false, `multipart/form-data; boundary="foo123"`},
+ {false, `multipart/mixed; boundary="foo123"`},
+ {true, `text/plain`},
+ }
+
+ for i, test := range tests {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {test.contentType}},
+ Body: io.NopCloser(new(bytes.Buffer)),
+ }
+ multipart, err := req.MultipartReader()
+ if test.shouldError {
+ if err == nil || multipart != nil {
+ t.Errorf("test %d: unexpectedly got nil-error (%v) or non-nil-multipart (%v)", i, err, multipart)
+ }
+ continue
+ }
+ if err != nil || multipart == nil {
+ t.Errorf("test %d: unexpectedly got error (%v) or nil-multipart (%v)", i, err, multipart)
+ }
+ }
+}
+
+// Issue 9305: ParseMultipartForm should populate PostForm too
+func TestParseMultipartFormPopulatesPostForm(t *testing.T) {
+ postData :=
+ `--xxx
+Content-Disposition: form-data; name="field1"
+
+value1
+--xxx
+Content-Disposition: form-data; name="field2"
+
+value2
+--xxx
+Content-Disposition: form-data; name="file"; filename="file"
+Content-Type: application/octet-stream
+Content-Transfer-Encoding: binary
+
+binary data
+--xxx--
+`
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary=xxx`}},
+ Body: io.NopCloser(strings.NewReader(postData)),
+ }
+
+ initialFormItems := map[string]string{
+ "language": "Go",
+ "name": "gopher",
+ "skill": "go-ing",
+ "field2": "initial-value2",
+ }
+
+ req.Form = make(url.Values)
+ for k, v := range initialFormItems {
+ req.Form.Add(k, v)
+ }
+
+ err := req.ParseMultipartForm(10000)
+ if err != nil {
+ t.Fatalf("unexpected multipart error %v", err)
+ }
+
+ wantForm := url.Values{
+ "language": []string{"Go"},
+ "name": []string{"gopher"},
+ "skill": []string{"go-ing"},
+ "field1": []string{"value1"},
+ "field2": []string{"initial-value2", "value2"},
+ }
+ if !reflect.DeepEqual(req.Form, wantForm) {
+ t.Fatalf("req.Form = %v, want %v", req.Form, wantForm)
+ }
+
+ wantPostForm := url.Values{
+ "field1": []string{"value1"},
+ "field2": []string{"value2"},
+ }
+ if !reflect.DeepEqual(req.PostForm, wantPostForm) {
+ t.Fatalf("req.PostForm = %v, want %v", req.PostForm, wantPostForm)
+ }
+}
+
+func TestParseMultipartForm(t *testing.T) {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
+ Body: io.NopCloser(new(bytes.Buffer)),
+ }
+ err := req.ParseMultipartForm(25)
+ if err == nil {
+ t.Error("expected multipart EOF, got nil")
+ }
+
+ req.Header = Header{"Content-Type": {"text/plain"}}
+ err = req.ParseMultipartForm(25)
+ if err != ErrNotMultipart {
+ t.Error("expected ErrNotMultipart for text/plain")
+ }
+}
+
+// Issue 45789: multipart form should not include directory path in filename
+func TestParseMultipartFormFilename(t *testing.T) {
+ postData :=
+ `--xxx
+Content-Disposition: form-data; name="file"; filename="../usr/foobar.txt/"
+Content-Type: text/plain
+
+--xxx--
+`
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/form-data; boundary=xxx`}},
+ Body: io.NopCloser(strings.NewReader(postData)),
+ }
+ _, hdr, err := req.FormFile("file")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if hdr.Filename != "foobar.txt" {
+ t.Errorf("expected only the last element of the path, got %q", hdr.Filename)
+ }
+}
+
+// Issue #40430: Test that when maxMemory for ParseMultipartForm, combined with
+// the payload size and the internal 10MiB leeway buffer, overflows int64, we
+// correctly return an error.
+func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) {
+ run(t, testMaxInt64ForMultipartFormMaxMemoryOverflow)
+}
+func testMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T, mode testMode) {
+ payloadSize := 1 << 10
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ // The combination of:
+ // MaxInt64 + payloadSize + (internal spare of 10MiB)
+ // triggers the overflow. See issue https://golang.org/issue/40430/
+ if err := req.ParseMultipartForm(math.MaxInt64); err != nil {
+ Error(rw, err.Error(), StatusBadRequest)
+ return
+ }
+ })).ts
+ fBuf := new(bytes.Buffer)
+ mw := multipart.NewWriter(fBuf)
+ mf, err := mw.CreateFormFile("file", "myfile.txt")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := mf.Write(bytes.Repeat([]byte("abc"), payloadSize)); err != nil {
+ t.Fatal(err)
+ }
+ if err := mw.Close(); err != nil {
+ t.Fatal(err)
+ }
+ req, err := NewRequest("POST", cst.URL, fBuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ res, err := cst.Client().Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if g, w := res.StatusCode, StatusOK; g != w {
+ t.Fatalf("Status code mismatch: got %d, want %d", g, w)
+ }
+}
+
+func TestRequestRedirect(t *testing.T) { run(t, testRequestRedirect) }
+func testRequestRedirect(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ switch r.URL.Path {
+ case "/":
+ w.Header().Set("Location", "/foo/")
+ w.WriteHeader(StatusSeeOther)
+ case "/foo/":
+ fmt.Fprintf(w, "foo")
+ default:
+ w.WriteHeader(StatusBadRequest)
+ }
+ }))
+
+ var end = regexp.MustCompile("/foo/$")
+ r, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Body.Close()
+ url := r.Request.URL.String()
+ if r.StatusCode != 200 || !end.MatchString(url) {
+ t.Fatalf("Get got status %d at %q, want 200 matching /foo/$", r.StatusCode, url)
+ }
+}
+
+func TestSetBasicAuth(t *testing.T) {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.SetBasicAuth("Aladdin", "open sesame")
+ if g, e := r.Header.Get("Authorization"), "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="; g != e {
+ t.Errorf("got header %q, want %q", g, e)
+ }
+}
+
+func TestMultipartRequest(t *testing.T) {
+ // Test that we can read the values and files of a
+ // multipart request with FormValue and FormFile,
+ // and that ParseMultipartForm can be called multiple times.
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm first call:", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ validateTestMultipartContents(t, req, false)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatal("ParseMultipartForm second call:", err)
+ }
+ validateTestMultipartContents(t, req, false)
+}
+
+// Issue #25192: Test that ParseMultipartForm fails but still parses the
+// multi-part form when a URL containing a semicolon is provided.
+func TestParseMultipartFormSemicolonSeparator(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ req.URL = &url.URL{RawQuery: "q=foo;q=bar"}
+ if err := req.ParseMultipartForm(25); err == nil {
+ t.Fatal("ParseMultipartForm expected error due to invalid semicolon, got nil")
+ }
+ defer req.MultipartForm.RemoveAll()
+ validateTestMultipartContents(t, req, false)
+}
+
+func TestMultipartRequestAuto(t *testing.T) {
+ // Test that FormValue and FormFile automatically invoke
+ // ParseMultipartForm and return the right values.
+ req := newTestMultipartRequest(t)
+ defer func() {
+ if req.MultipartForm != nil {
+ req.MultipartForm.RemoveAll()
+ }
+ }()
+ validateTestMultipartContents(t, req, true)
+}
+
+func TestMissingFileMultipartRequest(t *testing.T) {
+ // Test that FormFile returns an error if
+ // the named file is missing.
+ req := newTestMultipartRequest(t)
+ testMissingFile(t, req)
+}
+
+// Test that FormValue invokes ParseMultipartForm.
+func TestFormValueCallsParseMultipartForm(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/", strings.NewReader("z=post"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormValue("z")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormValue")
+ }
+}
+
+// Test that FormFile invokes ParseMultipartForm.
+func TestFormFileCallsParseMultipartForm(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if req.Form != nil {
+ t.Fatal("Unexpected request Form, want nil")
+ }
+ req.FormFile("")
+ if req.Form == nil {
+ t.Fatal("ParseMultipartForm not called by FormFile")
+ }
+}
+
+// Test that ParseMultipartForm errors if called
+// after MultipartReader on the same request.
+func TestParseMultipartFormOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ if err := req.ParseMultipartForm(1024); err == nil {
+ t.Fatal("expected an error from ParseMultipartForm after call to MultipartReader")
+ }
+}
+
+// Test that MultipartReader errors if called
+// after ParseMultipartForm on the same request.
+func TestMultipartReaderOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if err := req.ParseMultipartForm(25); err != nil {
+ t.Fatalf("ParseMultipartForm: %v", err)
+ }
+ defer req.MultipartForm.RemoveAll()
+ if _, err := req.MultipartReader(); err == nil {
+ t.Fatal("expected an error from MultipartReader after call to ParseMultipartForm")
+ }
+}
+
+// Test that FormFile errors if called after
+// MultipartReader on the same request.
+func TestFormFileOrder(t *testing.T) {
+ req := newTestMultipartRequest(t)
+ if _, err := req.MultipartReader(); err != nil {
+ t.Fatalf("MultipartReader: %v", err)
+ }
+ if _, _, err := req.FormFile(""); err == nil {
+ t.Fatal("expected an error from FormFile after call to MultipartReader")
+ }
+}
+
+var readRequestErrorTests = []struct {
+ in string
+ err string
+
+ header Header
+}{
+ 0: {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", "", Header{"Header": {"foo"}}},
+ 1: {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF.Error(), nil},
+ 2: {"", io.EOF.Error(), nil},
+ 3: {
+ in: "HEAD / HTTP/1.1\r\n\r\n",
+ header: Header{},
+ },
+
+ // Multiple Content-Length values should either be
+ // deduplicated if they are identical or rejected otherwise.
+ // See Issue 16490.
+ 4: {
+ in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 0\r\n\r\nGopher hey\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 5: {
+ in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 6\r\n\r\nGopher\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 6: {
+ in: "PUT / HTTP/1.1\r\nContent-Length: 6 \r\nContent-Length: 6\r\nContent-Length:6\r\n\r\nGopher\r\n",
+ err: "",
+ header: Header{"Content-Length": {"6"}},
+ },
+ 7: {
+ in: "PUT / HTTP/1.1\r\nContent-Length: 1\r\nContent-Length: 6 \r\n\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 8: {
+ in: "POST / HTTP/1.1\r\nContent-Length:\r\nContent-Length: 3\r\n\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 9: {
+ in: "HEAD / HTTP/1.1\r\nContent-Length:0\r\nContent-Length: 0\r\n\r\n",
+ header: Header{"Content-Length": {"0"}},
+ },
+ 10: {
+ in: "HEAD / HTTP/1.1\r\nHost: foo\r\nHost: bar\r\n\r\n\r\n\r\n",
+ err: "too many Host headers",
+ },
+}
+
+func TestReadRequestErrors(t *testing.T) {
+ for i, tt := range readRequestErrorTests {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in)))
+ if err == nil {
+ if tt.err != "" {
+ t.Errorf("#%d: got nil err; want %q", i, tt.err)
+ }
+
+ if !reflect.DeepEqual(tt.header, req.Header) {
+ t.Errorf("#%d: gotHeader: %q wantHeader: %q", i, req.Header, tt.header)
+ }
+ continue
+ }
+
+ if tt.err == "" || !strings.Contains(err.Error(), tt.err) {
+ t.Errorf("%d: got error = %v; want %v", i, err, tt.err)
+ }
+ }
+}
+
+var newRequestHostTests = []struct {
+ in, out string
+}{
+ {"http://www.example.com/", "www.example.com"},
+ {"http://www.example.com:8080/", "www.example.com:8080"},
+
+ {"http://192.168.0.1/", "192.168.0.1"},
+ {"http://192.168.0.1:8080/", "192.168.0.1:8080"},
+ {"http://192.168.0.1:/", "192.168.0.1"},
+
+ {"http://[fe80::1]/", "[fe80::1]"},
+ {"http://[fe80::1]:8080/", "[fe80::1]:8080"},
+ {"http://[fe80::1%25en0]/", "[fe80::1%en0]"},
+ {"http://[fe80::1%25en0]:8080/", "[fe80::1%en0]:8080"},
+ {"http://[fe80::1%25en0]:/", "[fe80::1%en0]"},
+}
+
+func TestNewRequestHost(t *testing.T) {
+ for i, tt := range newRequestHostTests {
+ req, err := NewRequest("GET", tt.in, nil)
+ if err != nil {
+ t.Errorf("#%v: %v", i, err)
+ continue
+ }
+ if req.Host != tt.out {
+ t.Errorf("got %q; want %q", req.Host, tt.out)
+ }
+ }
+}
+
+func TestRequestInvalidMethod(t *testing.T) {
+ _, err := NewRequest("bad method", "http://foo.com/", nil)
+ if err == nil {
+ t.Error("expected error from NewRequest with invalid method")
+ }
+ req, err := NewRequest("GET", "http://foo.example/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Method = "bad method"
+ _, err = DefaultClient.Do(req)
+ if err == nil || !strings.Contains(err.Error(), "invalid method") {
+ t.Errorf("Transport error = %v; want invalid method", err)
+ }
+
+ req, err = NewRequest("", "http://foo.com/", nil)
+ if err != nil {
+ t.Errorf("NewRequest(empty method) = %v; want nil", err)
+ } else if req.Method != "GET" {
+ t.Errorf("NewRequest(empty method) has method %q; want GET", req.Method)
+ }
+}
+
+func TestNewRequestContentLength(t *testing.T) {
+ readByte := func(r io.Reader) io.Reader {
+ var b [1]byte
+ r.Read(b[:])
+ return r
+ }
+ tests := []struct {
+ r io.Reader
+ want int64
+ }{
+ {bytes.NewReader([]byte("123")), 3},
+ {bytes.NewBuffer([]byte("1234")), 4},
+ {strings.NewReader("12345"), 5},
+ {strings.NewReader(""), 0},
+ {NoBody, 0},
+
+ // Not detected. During Go 1.8 we tried to make these set to -1, but
+ // due to Issue 18117, we keep these returning 0, even though they're
+ // unknown.
+ {struct{ io.Reader }{strings.NewReader("xyz")}, 0},
+ {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
+ {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
+ }
+ for i, tt := range tests {
+ req, err := NewRequest("POST", "http://localhost/", tt.r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != tt.want {
+ t.Errorf("test[%d]: ContentLength(%T) = %d; want %d", i, tt.r, req.ContentLength, tt.want)
+ }
+ }
+}
+
+var parseHTTPVersionTests = []struct {
+ vers string
+ major, minor int
+ ok bool
+}{
+ {"HTTP/0.0", 0, 0, true},
+ {"HTTP/0.9", 0, 9, true},
+ {"HTTP/1.0", 1, 0, true},
+ {"HTTP/1.1", 1, 1, true},
+
+ {"HTTP", 0, 0, false},
+ {"HTTP/one.one", 0, 0, false},
+ {"HTTP/1.1/", 0, 0, false},
+ {"HTTP/-1,0", 0, 0, false},
+ {"HTTP/0,-1", 0, 0, false},
+ {"HTTP/", 0, 0, false},
+ {"HTTP/1,1", 0, 0, false},
+ {"HTTP/+1.1", 0, 0, false},
+ {"HTTP/1.+1", 0, 0, false},
+ {"HTTP/0000000001.1", 0, 0, false},
+ {"HTTP/1.0000000001", 0, 0, false},
+ {"HTTP/3.14", 0, 0, false},
+ {"HTTP/12.3", 0, 0, false},
+}
+
+func TestParseHTTPVersion(t *testing.T) {
+ for _, tt := range parseHTTPVersionTests {
+ major, minor, ok := ParseHTTPVersion(tt.vers)
+ if ok != tt.ok || major != tt.major || minor != tt.minor {
+ type version struct {
+ major, minor int
+ ok bool
+ }
+ t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok})
+ }
+ }
+}
+
+type getBasicAuthTest struct {
+ username, password string
+ ok bool
+}
+
+type basicAuthCredentialsTest struct {
+ username, password string
+}
+
+var getBasicAuthTests = []struct {
+ username, password string
+ ok bool
+}{
+ {"Aladdin", "open sesame", true},
+ {"Aladdin", "open:sesame", true},
+ {"", "", true},
+}
+
+func TestGetBasicAuth(t *testing.T) {
+ for _, tt := range getBasicAuthTests {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.SetBasicAuth(tt.username, tt.password)
+ username, password, ok := r.BasicAuth()
+ if ok != tt.ok || username != tt.username || password != tt.password {
+ t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+ getBasicAuthTest{tt.username, tt.password, tt.ok})
+ }
+ }
+ // Unauthenticated request.
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ username, password, ok := r.BasicAuth()
+ if ok {
+ t.Errorf("expected false from BasicAuth when the request is unauthenticated")
+ }
+ want := basicAuthCredentialsTest{"", ""}
+ if username != want.username || password != want.password {
+ t.Errorf("expected credentials: %#v when the request is unauthenticated, got %#v",
+ want, basicAuthCredentialsTest{username, password})
+ }
+}
+
+var parseBasicAuthTests = []struct {
+ header, username, password string
+ ok bool
+}{
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+
+ // Case doesn't matter:
+ {"BASIC " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+ {"basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true},
+ {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true},
+ {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+ {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
+ {"Basic ", "", "", false},
+ {"Basic Aladdin:open sesame", "", "", false},
+ {`Digest username="Aladdin"`, "", "", false},
+}
+
+func TestParseBasicAuth(t *testing.T) {
+ for _, tt := range parseBasicAuthTests {
+ r, _ := NewRequest("GET", "http://example.com/", nil)
+ r.Header.Set("Authorization", tt.header)
+ username, password, ok := r.BasicAuth()
+ if ok != tt.ok || username != tt.username || password != tt.password {
+ t.Errorf("BasicAuth() = %#v, want %#v", getBasicAuthTest{username, password, ok},
+ getBasicAuthTest{tt.username, tt.password, tt.ok})
+ }
+ }
+}
+
+type logWrites struct {
+ t *testing.T
+ dst *[]string
+}
+
+func (l logWrites) WriteByte(c byte) error {
+ l.t.Fatalf("unexpected WriteByte call")
+ return nil
+}
+
+func (l logWrites) Write(p []byte) (n int, err error) {
+ *l.dst = append(*l.dst, string(p))
+ return len(p), nil
+}
+
+func TestRequestWriteBufferedWriter(t *testing.T) {
+ got := []string{}
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET / HTTP/1.1\r\n",
+ "Host: foo.com\r\n",
+ "User-Agent: " + DefaultUserAgent + "\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
+func TestRequestBadHostHeader(t *testing.T) {
+ got := []string{}
+ req, err := NewRequest("GET", "http://foo/after", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Host = "foo.com\nnewline"
+ req.URL.Host = "foo.com\nnewline"
+ req.Write(logWrites{t, &got})
+ want := []string{
+ "GET /after HTTP/1.1\r\n",
+ "Host: \r\n",
+ "User-Agent: " + DefaultUserAgent + "\r\n",
+ "\r\n",
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Writes = %q\n Want = %q", got, want)
+ }
+}
+
+func TestStarRequest(t *testing.T) {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("M-SEARCH * HTTP/1.1\r\n\r\n")))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if req.ContentLength != 0 {
+ t.Errorf("ContentLength = %d; want 0", req.ContentLength)
+ }
+ if req.Body == nil {
+ t.Errorf("Body = nil; want non-nil")
+ }
+
+ // Request.Write has Client semantics for Body/ContentLength,
+ // where ContentLength 0 means unknown if Body is non-nil, and
+ // thus chunking will happen unless we change semantics and
+ // signal that we want to serialize it as exactly zero. The
+ // only way to do that for outbound requests is with a nil
+ // Body:
+ clientReq := *req
+ clientReq.Body = nil
+
+ var out strings.Builder
+ if err := clientReq.Write(&out); err != nil {
+ t.Fatal(err)
+ }
+
+ if strings.Contains(out.String(), "chunked") {
+ t.Error("wrote chunked request; want no body")
+ }
+ back, err := ReadRequest(bufio.NewReader(strings.NewReader(out.String())))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Ignore the Headers (the User-Agent breaks the deep equal,
+ // but we don't care about it)
+ req.Header = nil
+ back.Header = nil
+ if !reflect.DeepEqual(req, back) {
+ t.Errorf("Original request doesn't match Request read back.")
+ t.Logf("Original: %#v", req)
+ t.Logf("Original.URL: %#v", req.URL)
+ t.Logf("Wrote: %s", out.String())
+ t.Logf("Read back (doesn't match Original): %#v", back)
+ }
+}
+
+type responseWriterJustWriter struct {
+ io.Writer
+}
+
+func (responseWriterJustWriter) Header() Header { panic("should not be called") }
+func (responseWriterJustWriter) WriteHeader(int) { panic("should not be called") }
+
+// delayedEOFReader never returns (n > 0, io.EOF), instead putting
+// off the io.EOF until a subsequent Read call.
+type delayedEOFReader struct {
+ r io.Reader
+}
+
+func (dr delayedEOFReader) Read(p []byte) (n int, err error) {
+ n, err = dr.r.Read(p)
+ if n > 0 && err == io.EOF {
+ err = nil
+ }
+ return
+}
+
+func TestIssue10884_MaxBytesEOF(t *testing.T) {
+ dst := io.Discard
+ _, err := io.Copy(dst, MaxBytesReader(
+ responseWriterJustWriter{dst},
+ io.NopCloser(delayedEOFReader{strings.NewReader("12345")}),
+ 5))
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Issue 14981: MaxBytesReader's return error wasn't sticky. It
+// doesn't technically need to be, but people expected it to be.
+func TestMaxBytesReaderStickyError(t *testing.T) {
+ isSticky := func(r io.Reader) error {
+ var log bytes.Buffer
+ buf := make([]byte, 1000)
+ var firstErr error
+ for {
+ n, err := r.Read(buf)
+ fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err)
+ if err == nil {
+ continue
+ }
+ if firstErr == nil {
+ firstErr = err
+ continue
+ }
+ if !reflect.DeepEqual(err, firstErr) {
+ return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes())
+ }
+ t.Logf("Got log: %s", log.Bytes())
+ return nil
+ }
+ }
+ tests := [...]struct {
+ readable int
+ limit int64
+ }{
+ 0: {99, 100},
+ 1: {100, 100},
+ 2: {101, 100},
+ }
+ for i, tt := range tests {
+ rc := MaxBytesReader(nil, io.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit)
+ if err := isSticky(rc); err != nil {
+ t.Errorf("%d. error: %v", i, err)
+ }
+ }
+}
+
+// Issue 45101: maxBytesReader's Read panicked when n < -1. This test
+// also ensures that Read treats negative limits as equivalent to 0.
+func TestMaxBytesReaderDifferentLimits(t *testing.T) {
+ const testStr = "1234"
+ tests := [...]struct {
+ limit int64
+ lenP int
+ wantN int
+ wantErr bool
+ }{
+ 0: {
+ limit: -123,
+ lenP: 0,
+ wantN: 0,
+ wantErr: false, // Ensure we won't return an error when the limit is negative, but we don't need to read.
+ },
+ 1: {
+ limit: -100,
+ lenP: 32 * 1024,
+ wantN: 0,
+ wantErr: true,
+ },
+ 2: {
+ limit: -2,
+ lenP: 1,
+ wantN: 0,
+ wantErr: true,
+ },
+ 3: {
+ limit: -1,
+ lenP: 2,
+ wantN: 0,
+ wantErr: true,
+ },
+ 4: {
+ limit: 0,
+ lenP: 3,
+ wantN: 0,
+ wantErr: true,
+ },
+ 5: {
+ limit: 1,
+ lenP: 4,
+ wantN: 1,
+ wantErr: true,
+ },
+ 6: {
+ limit: 2,
+ lenP: 5,
+ wantN: 2,
+ wantErr: true,
+ },
+ 7: {
+ limit: 3,
+ lenP: 2,
+ wantN: 2,
+ wantErr: false,
+ },
+ 8: {
+ limit: int64(len(testStr)),
+ lenP: len(testStr),
+ wantN: len(testStr),
+ wantErr: false,
+ },
+ 9: {
+ limit: 100,
+ lenP: 6,
+ wantN: len(testStr),
+ wantErr: false,
+ },
+ 10: { /* Issue 54408 */
+ limit: int64(1<<63 - 1),
+ lenP: len(testStr),
+ wantN: len(testStr),
+ wantErr: false,
+ },
+ }
+ for i, tt := range tests {
+ rc := MaxBytesReader(nil, io.NopCloser(strings.NewReader(testStr)), tt.limit)
+
+ n, err := rc.Read(make([]byte, tt.lenP))
+
+ if n != tt.wantN {
+ t.Errorf("%d. n: %d, want n: %d", i, n, tt.wantN)
+ }
+
+ if (err != nil) != tt.wantErr {
+ t.Errorf("%d. error: %v", i, err)
+ }
+ }
+}
+
+func TestWithContextNilURL(t *testing.T) {
+ req, err := NewRequest("POST", "https://golang.org/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Issue 20601
+ req.URL = nil
+ reqCopy := req.WithContext(context.Background())
+ if reqCopy.URL != nil {
+ t.Error("expected nil URL in cloned request")
+ }
+}
+
+// Ensure that Request.Clone creates a deep copy of TransferEncoding.
+// See issue 41907.
+func TestRequestCloneTransferEncoding(t *testing.T) {
+ body := strings.NewReader("body")
+ req, _ := NewRequest("POST", "https://example.org/", body)
+ req.TransferEncoding = []string{
+ "encoding1",
+ }
+
+ clonedReq := req.Clone(context.Background())
+ // modify original after deep copy
+ req.TransferEncoding[0] = "encoding2"
+
+ if req.TransferEncoding[0] != "encoding2" {
+ t.Error("expected req.TransferEncoding to be changed")
+ }
+ if clonedReq.TransferEncoding[0] != "encoding1" {
+ t.Error("expected clonedReq.TransferEncoding to be unchanged")
+ }
+}
+
+// Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression)
+func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) }
+func testNoPanicWithBasicAuth(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ u.User = url.UserPassword("foo", "bar")
+ req := &Request{
+ URL: u,
+ Method: "GET",
+ }
+ if _, err := cst.c.Do(req); err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+}
+
+// verify that NewRequest sets Request.GetBody and that it works
+func TestNewRequestGetBody(t *testing.T) {
+ tests := []struct {
+ r io.Reader
+ }{
+ {r: strings.NewReader("hello")},
+ {r: bytes.NewReader([]byte("hello"))},
+ {r: bytes.NewBuffer([]byte("hello"))},
+ }
+ for i, tt := range tests {
+ req, err := NewRequest("POST", "http://foo.tld/", tt.r)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if req.Body == nil {
+ t.Errorf("test[%d]: Body = nil", i)
+ continue
+ }
+ if req.GetBody == nil {
+ t.Errorf("test[%d]: GetBody = nil", i)
+ continue
+ }
+ slurp1, err := io.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("test[%d]: ReadAll(Body) = %v", i, err)
+ }
+ newBody, err := req.GetBody()
+ if err != nil {
+ t.Errorf("test[%d]: GetBody = %v", i, err)
+ }
+ slurp2, err := io.ReadAll(newBody)
+ if err != nil {
+ t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err)
+ }
+ if string(slurp1) != string(slurp2) {
+ t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2)
+ }
+ }
+}
+
+func testMissingFile(t *testing.T, req *Request) {
+ f, fh, err := req.FormFile("missing")
+ if f != nil {
+ t.Errorf("FormFile file = %v, want nil", f)
+ }
+ if fh != nil {
+ t.Errorf("FormFile file header = %v, want nil", fh)
+ }
+ if err != ErrMissingFile {
+ t.Errorf("FormFile err = %q, want ErrMissingFile", err)
+ }
+}
+
+func newTestMultipartRequest(t *testing.T) *Request {
+ b := strings.NewReader(strings.ReplaceAll(message, "\n", "\r\n"))
+ req, err := NewRequest("POST", "/", b)
+ if err != nil {
+ t.Fatal("NewRequest:", err)
+ }
+ ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary)
+ req.Header.Set("Content-type", ctype)
+ return req
+}
+
+func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) {
+ if g, e := req.FormValue("texta"), textaValue; g != e {
+ t.Errorf("texta value = %q, want %q", g, e)
+ }
+ if g, e := req.FormValue("textb"), textbValue; g != e {
+ t.Errorf("textb value = %q, want %q", g, e)
+ }
+ if g := req.FormValue("missing"); g != "" {
+ t.Errorf("missing value = %q, want empty string", g)
+ }
+
+ assertMem := func(n string, fd multipart.File) {
+ if _, ok := fd.(*os.File); ok {
+ t.Error(n, " is *os.File, should not be")
+ }
+ }
+ fda := testMultipartFile(t, req, "filea", "filea.txt", fileaContents)
+ defer fda.Close()
+ assertMem("filea", fda)
+ fdb := testMultipartFile(t, req, "fileb", "fileb.txt", filebContents)
+ defer fdb.Close()
+ if allMem {
+ assertMem("fileb", fdb)
+ } else {
+ if _, ok := fdb.(*os.File); !ok {
+ t.Errorf("fileb has unexpected underlying type %T", fdb)
+ }
+ }
+
+ testMissingFile(t, req)
+}
+
+func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File {
+ f, fh, err := req.FormFile(key)
+ if err != nil {
+ t.Fatalf("FormFile(%q): %q", key, err)
+ }
+ if fh.Filename != expectFilename {
+ t.Errorf("filename = %q, want %q", fh.Filename, expectFilename)
+ }
+ var b strings.Builder
+ _, err = io.Copy(&b, f)
+ if err != nil {
+ t.Fatal("copying contents:", err)
+ }
+ if g := b.String(); g != expectContent {
+ t.Errorf("contents = %q, want %q", g, expectContent)
+ }
+ return f
+}
+
+// Issue 53181: verify that Request.Cookie returns the correct Cookie.
+// It returns ErrNoCookie instead of the first cookie when the name is "".
+func TestRequestCookie(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ value string
+ expectedErr error
+ }{
+ {
+ name: "foo",
+ value: "bar",
+ expectedErr: nil,
+ },
+ {
+ name: "",
+ expectedErr: ErrNoCookie,
+ },
+ } {
+ req, err := NewRequest("GET", "http://example.com/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.AddCookie(&Cookie{Name: tt.name, Value: tt.value})
+ c, err := req.Cookie(tt.name)
+ if err != tt.expectedErr {
+ t.Errorf("got %v, want %v", err, tt.expectedErr)
+ }
+
+ // skip if error occurred.
+ if err != nil {
+ continue
+ }
+ if c.Value != tt.value {
+ t.Errorf("got %v, want %v", c.Value, tt.value)
+ }
+ if c.Name != tt.name {
+ t.Errorf("got %s, want %v", tt.name, c.Name)
+ }
+ }
+}
+
+const (
+ fileaContents = "This is a test file."
+ filebContents = "Another test file."
+ textaValue = "foo"
+ textbValue = "bar"
+ boundary = `MyBoundary`
+)
+
+const message = `
+--MyBoundary
+Content-Disposition: form-data; name="filea"; filename="filea.txt"
+Content-Type: text/plain
+
+` + fileaContents + `
+--MyBoundary
+Content-Disposition: form-data; name="fileb"; filename="fileb.txt"
+Content-Type: text/plain
+
+` + filebContents + `
+--MyBoundary
+Content-Disposition: form-data; name="texta"
+
+` + textaValue + `
+--MyBoundary
+Content-Disposition: form-data; name="textb"
+
+` + textbValue + `
+--MyBoundary--
+`
+
+func benchmarkReadRequest(b *testing.B, request string) {
+ request = request + "\n" // final \n
+ request = strings.ReplaceAll(request, "\n", "\r\n") // expand \n to \r\n
+ b.SetBytes(int64(len(request)))
+ r := bufio.NewReader(&infiniteReader{buf: []byte(request)})
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := ReadRequest(r)
+ if err != nil {
+ b.Fatalf("failed to read request: %v", err)
+ }
+ }
+}
+
+// infiniteReader satisfies Read requests as if the contents of buf
+// loop indefinitely.
+type infiniteReader struct {
+ buf []byte
+ offset int
+}
+
+func (r *infiniteReader) Read(b []byte) (int, error) {
+ n := copy(b, r.buf[r.offset:])
+ r.offset = (r.offset + n) % len(r.buf)
+ return n, nil
+}
+
+func BenchmarkReadRequestChrome(b *testing.B) {
+ // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Connection: keep-alive
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false
+`)
+}
+
+func BenchmarkReadRequestCurl(b *testing.B) {
+ // curl http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+User-Agent: curl/7.27.0
+Host: localhost:8080
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestApachebench(b *testing.B) {
+ // ab -n 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.0
+Host: localhost:8080
+User-Agent: ApacheBench/2.3
+Accept: */*
+`)
+}
+
+func BenchmarkReadRequestSiege(b *testing.B) {
+ // siege -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+Accept: */*
+Accept-Encoding: gzip
+User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70)
+Connection: keep-alive
+`)
+}
+
+func BenchmarkReadRequestWrk(b *testing.B) {
+ // wrk -t 1 -r 1 -c 1 http://localhost:8080/
+ benchmarkReadRequest(b, `GET / HTTP/1.1
+Host: localhost:8080
+`)
+}
+
+func BenchmarkFileAndServer_1KB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<10)
+}
+
+func BenchmarkFileAndServer_16MB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<24)
+}
+
+func BenchmarkFileAndServer_64MB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<26)
+}
+
+func benchmarkFileAndServer(b *testing.B, n int64) {
+ f, err := os.CreateTemp(os.TempDir(), "go-bench-http-file-and-server")
+ if err != nil {
+ b.Fatalf("Failed to create temp file: %v", err)
+ }
+
+ defer func() {
+ f.Close()
+ os.RemoveAll(f.Name())
+ }()
+
+ if _, err := io.CopyN(f, rand.Reader, n); err != nil {
+ b.Fatalf("Failed to copy %d bytes: %v", n, err)
+ }
+
+ run(b, func(b *testing.B, mode testMode) {
+ runFileAndServerBenchmarks(b, mode, f, n)
+ }, []testMode{http1Mode, https1Mode, http2Mode})
+}
+
+func runFileAndServerBenchmarks(b *testing.B, mode testMode, f *os.File, n int64) {
+ handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer req.Body.Close()
+ nc, err := io.Copy(io.Discard, req.Body)
+ if err != nil {
+ panic(err)
+ }
+
+ if nc != n {
+ panic(fmt.Errorf("Copied %d Wanted %d bytes", nc, n))
+ }
+ })
+
+ cst := newClientServerTest(b, mode, handler).ts
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Perform some setup.
+ b.StopTimer()
+ if _, err := f.Seek(0, 0); err != nil {
+ b.Fatalf("Failed to seek back to file: %v", err)
+ }
+
+ b.StartTimer()
+ req, err := NewRequest("PUT", cst.URL, io.NopCloser(f))
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ req.ContentLength = n
+ // Prevent mime sniffing by setting the Content-Type.
+ req.Header.Set("Content-Type", "application/octet-stream")
+ res, err := cst.Client().Do(req)
+ if err != nil {
+ b.Fatalf("Failed to make request to backend: %v", err)
+ }
+
+ res.Body.Close()
+ b.SetBytes(n)
+ }
+}
+
+func TestErrNotSupported(t *testing.T) {
+ if !errors.Is(ErrNotSupported, errors.ErrUnsupported) {
+ t.Error("errors.Is(ErrNotSupported, errors.ErrUnsupported) failed")
+ }
+}
diff --git a/src/net/http/requestwrite_test.go b/src/net/http/requestwrite_test.go
new file mode 100644
index 0000000..380ae9d
--- /dev/null
+++ b/src/net/http/requestwrite_test.go
@@ -0,0 +1,977 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/url"
+ "strings"
+ "testing"
+ "testing/iotest"
+ "time"
+)
+
+type reqWriteTest struct {
+ Req Request
+ Body any // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ // Any of these three may be empty to skip that test.
+ WantWrite string // Request.Write
+ WantProxy string // Request.WriteProxy
+
+ WantError error // wanted error from Request.Write
+}
+
+var reqWriteTests = []reqWriteTest{
+ // HTTP/1.1 => chunked coding; no body; no trailer
+ 0: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.techcrunch.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
+ "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"},
+ "Accept-Encoding": {"gzip,deflate"},
+ "Accept-Language": {"en-us,en;q=0.5"},
+ "Keep-Alive": {"300"},
+ "Proxy-Connection": {"keep-alive"},
+ "User-Agent": {"Fake"},
+ },
+ Body: nil,
+ Close: false,
+ Host: "www.techcrunch.com",
+ Form: map[string][]string{},
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+
+ WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" +
+ "Host: www.techcrunch.com\r\n" +
+ "User-Agent: Fake\r\n" +
+ "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" +
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" +
+ "Accept-Encoding: gzip,deflate\r\n" +
+ "Accept-Language: en-us,en;q=0.5\r\n" +
+ "Keep-Alive: 300\r\n" +
+ "Proxy-Connection: keep-alive\r\n\r\n",
+ },
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ 1: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+ // HTTP/1.1 POST => chunked coding; body; empty trailer
+ 2: {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // HTTP/1.1 POST with Content-Length, no chunking
+ 3: {
+ Req: Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Close: true,
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Connection: close\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // HTTP/1.1 POST with Content-Length in headers
+ 4: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("http://example.com/"),
+ Host: "example.com",
+ Header: Header{
+ "Content-Length": []string{"10"}, // ignored
+ },
+ ContentLength: 6,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+
+ WantProxy: "POST http://example.com/ HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 6\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+
+ // default to HTTP/1.1
+ 5: {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/search"),
+ Host: "www.google.com",
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 0 byte body.
+ 6: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) },
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n0\r\n\r\n",
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n0\r\n\r\n",
+ },
+
+ // Request with a 0 ContentLength and a nil body.
+ 7: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return nil },
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+
+ // Request with a 0 ContentLength and a 1 byte body.
+ 8: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) },
+
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("x") + chunk(""),
+ },
+
+ // Request with a ContentLength of 10 but a 5 byte body.
+ 9: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 10, // but we're going to send only 5 bytes
+ },
+ Body: []byte("12345"),
+ WantError: errors.New("http: ContentLength=10 with Body length 5"),
+ },
+
+ // Request with a ContentLength of 4 but an 8 byte body.
+ 10: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 4, // but we're going to try to send 8 bytes
+ },
+ Body: []byte("12345678"),
+ WantError: errors.New("http: ContentLength=4 with Body length 8"),
+ },
+
+ // Request with a 5 ContentLength and nil body.
+ 11: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 5, // but we'll omit the body
+ },
+ WantError: errors.New("http: Request.ContentLength=5 with nil Body"),
+ },
+
+ // Request with a 0 ContentLength and a body with 1 byte content and an error.
+ 12: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := iotest.ErrReader(err)
+ return io.NopCloser(io.MultiReader(strings.NewReader("x"), errReader))
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
+ // Request with a 0 ContentLength and a body without content and an error.
+ 13: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser {
+ err := errors.New("Custom reader error")
+ errReader := iotest.ErrReader(err)
+ return io.NopCloser(errReader)
+ },
+
+ WantError: errors.New("Custom reader error"),
+ },
+
+ // A 1.0 request with a custom header key. Unlike DumpRequest,
+ // Request.Write normalizes the version to HTTP/1.1 and adds the
+ // empty Host header and the default User-Agent.
+ 14: {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantWrite: "GET /foo HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ // If no Request.Host and no Request.URL.Host, we send
+ // an empty Host header, and don't use
+ // Request.Header["Host"]. This is just testing that
+ // we don't change Go 1.0 behavior.
+ 15: {
+ Req: Request{
+ Method: "GET",
+ Host: "",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "Host": []string{"bad.example.com"},
+ },
+ },
+
+ WantWrite: "GET /search HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go-http-client/1.1\r\n\r\n",
+ },
+
+ // Opaque test #1 from golang.org/issue/4860
+ 16: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Opaque: "/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n\r\n",
+ },
+
+ // Opaque test #2 from golang.org/issue/4860
+ 17: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "x.google.com",
+ Opaque: "//y.google.com/%2F/%2F/",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ },
+
+ WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" +
+ "Host: x.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n\r\n",
+ },
+
+ // Testing custom case in header keys. Issue 5022.
+ 18: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ },
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{
+ "ALL-CAPS": {"x"},
+ },
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "ALL-CAPS: x\r\n" +
+ "\r\n",
+ },
+
+ // Request with host header field; IPv6 address with zone identifier
+ 19: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Host: "[fe80::1%en0]",
+ },
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: [fe80::1]\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "\r\n",
+ },
+
+ // Request with optional host header field; IPv6 address with zone identifier
+ 20: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Host: "www.example.com",
+ },
+ Host: "[fe80::1%en0]:8080",
+ },
+
+ WantWrite: "GET / HTTP/1.1\r\n" +
+ "Host: [fe80::1]:8080\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "\r\n",
+ },
+
+ // CONNECT without Opaque
+ 21: {
+ Req: Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Scheme: "https", // of proxy.com
+ Host: "proxy.com",
+ },
+ },
+ // What we used to do, locking that behavior in:
+ WantWrite: "CONNECT proxy.com HTTP/1.1\r\n" +
+ "Host: proxy.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "\r\n",
+ },
+
+ // CONNECT with Opaque
+ 22: {
+ Req: Request{
+ Method: "CONNECT",
+ URL: &url.URL{
+ Scheme: "https", // of proxy.com
+ Host: "proxy.com",
+ Opaque: "backend:443",
+ },
+ },
+ WantWrite: "CONNECT backend:443 HTTP/1.1\r\n" +
+ "Host: proxy.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "\r\n",
+ },
+
+ // Verify that a nil header value doesn't get written.
+ 23: {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ Header: Header{
+ "X-Foo": []string{"X-Bar"},
+ "X-Idempotency-Key": nil,
+ },
+ },
+
+ WantWrite: "GET /foo HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+ 24: {
+ Req: Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ Header: Header{
+ "X-Foo": []string{"X-Bar"},
+ "X-Idempotency-Key": []string{},
+ },
+ },
+
+ WantWrite: "GET /foo HTTP/1.1\r\n" +
+ "Host: \r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ 25: {
+ Req: Request{
+ Method: "GET",
+ URL: &url.URL{
+ Host: "www.example.com",
+ RawQuery: "new\nline", // or any CTL
+ },
+ },
+ WantError: errors.New("net/http: can't write control character in Request.URL"),
+ },
+
+ 26: { // Request with nil body and PATCH method. Issue #40978
+ Req: Request{
+ Method: "PATCH",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+ Body: nil,
+ WantWrite: "PATCH / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n\r\n",
+ WantProxy: "PATCH / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n\r\n",
+ },
+}
+
+func TestRequestWrite(t *testing.T) {
+ for i := range reqWriteTests {
+ tt := &reqWriteTests[i]
+
+ setBody := func() {
+ if tt.Body == nil {
+ return
+ }
+ switch b := tt.Body.(type) {
+ case []byte:
+ tt.Req.Body = io.NopCloser(bytes.NewReader(b))
+ case func() io.ReadCloser:
+ tt.Req.Body = b()
+ }
+ }
+ setBody()
+ if tt.Req.Header == nil {
+ tt.Req.Header = make(Header)
+ }
+
+ var braw strings.Builder
+ err := tt.Req.Write(&braw)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e {
+ t.Errorf("writing #%d, err = %q, want %q", i, g, e)
+ continue
+ }
+ if err != nil {
+ continue
+ }
+
+ if tt.WantWrite != "" {
+ sraw := braw.String()
+ if sraw != tt.WantWrite {
+ t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw)
+ continue
+ }
+ }
+
+ if tt.WantProxy != "" {
+ setBody()
+ var praw strings.Builder
+ err = tt.Req.WriteProxy(&praw)
+ if err != nil {
+ t.Errorf("WriteProxy #%d: %s", i, err)
+ continue
+ }
+ sraw := praw.String()
+ if sraw != tt.WantProxy {
+ t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw)
+ continue
+ }
+ }
+ }
+}
+
+func TestRequestWriteTransport(t *testing.T) {
+ t.Parallel()
+
+ matchSubstr := func(substr string) func(string) error {
+ return func(written string) error {
+ if !strings.Contains(written, substr) {
+ return fmt.Errorf("expected substring %q in request: %s", substr, written)
+ }
+ return nil
+ }
+ }
+
+ noContentLengthOrTransferEncoding := func(req string) error {
+ if strings.Contains(req, "Content-Length: ") {
+ return fmt.Errorf("unexpected Content-Length in request: %s", req)
+ }
+ if strings.Contains(req, "Transfer-Encoding: ") {
+ return fmt.Errorf("unexpected Transfer-Encoding in request: %s", req)
+ }
+ return nil
+ }
+
+ all := func(checks ...func(string) error) func(string) error {
+ return func(req string) error {
+ for _, c := range checks {
+ if err := c(req); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }
+
+ type testCase struct {
+ method string
+ clen int64 // ContentLength
+ body io.ReadCloser
+ want func(string) error
+
+ // optional:
+ init func(*testCase)
+ afterReqRead func()
+ }
+
+ tests := []testCase{
+ {
+ method: "GET",
+ want: noContentLengthOrTransferEncoding,
+ },
+ {
+ method: "GET",
+ body: io.NopCloser(strings.NewReader("")),
+ want: noContentLengthOrTransferEncoding,
+ },
+ {
+ method: "GET",
+ clen: -1,
+ body: io.NopCloser(strings.NewReader("")),
+ want: noContentLengthOrTransferEncoding,
+ },
+ // A GET with a body, with explicit content length:
+ {
+ method: "GET",
+ clen: 7,
+ body: io.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Content-Length: 7"),
+ matchSubstr("foobody")),
+ },
+ // A GET with a body, sniffing the leading "f" from "foobody".
+ {
+ method: "GET",
+ clen: -1,
+ body: io.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Transfer-Encoding: chunked"),
+ matchSubstr("\r\n1\r\nf\r\n"),
+ matchSubstr("oobody")),
+ },
+ // But a POST request is expected to have a body, so
+ // no sniffing happens:
+ {
+ method: "POST",
+ clen: -1,
+ body: io.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Transfer-Encoding: chunked"),
+ matchSubstr("foobody")),
+ },
+ {
+ method: "POST",
+ clen: -1,
+ body: io.NopCloser(strings.NewReader("")),
+ want: all(matchSubstr("Transfer-Encoding: chunked")),
+ },
+ // Verify that a blocking Request.Body doesn't block forever.
+ {
+ method: "GET",
+ clen: -1,
+ init: func(tt *testCase) {
+ pr, pw := io.Pipe()
+ tt.afterReqRead = func() {
+ pw.Close()
+ }
+ tt.body = io.NopCloser(pr)
+ },
+ want: matchSubstr("Transfer-Encoding: chunked"),
+ },
+ }
+
+ for i, tt := range tests {
+ if tt.init != nil {
+ tt.init(&tt)
+ }
+ req := &Request{
+ Method: tt.method,
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "example.com",
+ },
+ Header: make(Header),
+ ContentLength: tt.clen,
+ Body: tt.body,
+ }
+ got, err := dumpRequestOut(req, tt.afterReqRead)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if err := tt.want(string(got)); err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ }
+ }
+}
+
+type closeChecker struct {
+ io.Reader
+ closed bool
+}
+
+func (rc *closeChecker) Close() error {
+ rc.closed = true
+ return nil
+}
+
+// TestRequestWriteClosesBody tests that Request.Write closes its request.Body.
+// It also indirectly tests that NewRequest doesn't wrap an existing Closer
+// inside a NopCloser, and that the body is serialized correctly.
+func TestRequestWriteClosesBody(t *testing.T) {
+ rc := &closeChecker{Reader: strings.NewReader("my body")}
+ req, err := NewRequest("POST", "http://foo.com/", rc)
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := new(strings.Builder)
+ if err := req.Write(buf); err != nil {
+ t.Error(err)
+ }
+ if !rc.closed {
+ t.Error("body not closed after write")
+ }
+ expected := "POST / HTTP/1.1\r\n" +
+ "Host: foo.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("my body") +
+ chunk("")
+ if buf.String() != expected {
+ t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected)
+ }
+}
+
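+// chunk encodes s as a single HTTP/1.1 chunk: the data length in hex, CRLF,
+// the data itself, CRLF. For example, chunk("my body") is "7\r\nmy body\r\n",
+// and the terminating chunk("") is "0\r\n\r\n".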
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+type writerFunc func([]byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
+
+// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
+func TestRequestWriteError(t *testing.T) {
+ failAfter, writeCount := 0, 0
+ errFail := errors.New("fake write failure")
+
+ // w is the io.Writer to write the request to. It
+ // fails exactly once on its Nth Write call, as controlled by
+ // failAfter. It also tracks the number of calls in
+ // writeCount.
+ w := struct {
+ io.ByteWriter // to avoid being wrapped by a bufio.Writer
+ io.Writer
+ }{
+ nil,
+ writerFunc(func(p []byte) (n int, err error) {
+ writeCount++
+ if failAfter == 0 {
+ err = errFail
+ }
+ failAfter--
+ return len(p), err
+ }),
+ }
+
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ const writeCalls = 4 // number of Write calls in current implementation
+ sawGood := false
+ for n := 0; n <= writeCalls+2; n++ {
+ failAfter = n
+ writeCount = 0
+ err := req.Write(w)
+ var wantErr error
+ if n < writeCalls {
+ wantErr = errFail
+ }
+ if err != wantErr {
+ t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
+ continue
+ }
+ if err == nil {
+ sawGood = true
+ if writeCount != writeCalls {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+ }
+ if writeCount > writeCalls || writeCount > n+1 {
+ t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
+ }
+ }
+ if !sawGood {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+}
+
+// dumpRequestOut is a modified copy of net/http/httputil.DumpRequestOut.
+// Unlike the original, this version doesn't mutate the req.Body and
+// try to restore it. It always dumps the whole body.
+// And it doesn't support https.
+func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) {
+
+ // Use the actual Transport code to record what we would send
+ // on the wire, but not using TCP. Use a Transport with a
+ // custom dialer that returns a fake net.Conn that waits
+ // for the full input (and records it), and then responds
+ // with a dummy response.
+ var buf bytes.Buffer // records the output
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ req, err := ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ if onReadHeaders != nil {
+ onReadHeaders()
+ }
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
+ }()
+
+ _, err := t.RoundTrip(req)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ r.r = <-r.c
+ }
+ return r.r.Read(p)
+}
+
+// dumpConn is a net.Conn that writes to Writer and reads from Reader.
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
diff --git a/src/net/http/response.go b/src/net/http/response.go
new file mode 100644
index 0000000..755c696
--- /dev/null
+++ b/src/net/http/response.go
@@ -0,0 +1,371 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP Response reading and parsing.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "io"
+ "net/textproto"
+ "net/url"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+var respExcludeHeader = map[string]bool{
+ "Content-Length": true,
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// Response represents the response from an HTTP request.
+//
+// The Client and Transport return Responses from servers once
+// the response headers have been received. The response body
+// is streamed on demand as the Body field is read.
+type Response struct {
+ Status string // e.g. "200 OK"
+ StatusCode int // e.g. 200
+ Proto string // e.g. "HTTP/1.0"
+ ProtoMajor int // e.g. 1
+ ProtoMinor int // e.g. 0
+
+ // Header maps header keys to values. If the response had multiple
+ // headers with the same key, they may be concatenated, with comma
+ // delimiters. (RFC 7230, section 3.2.2 requires that multiple headers
+ // be semantically equivalent to a comma-delimited sequence.) When
+ // Header values are duplicated by other fields in this struct (e.g.,
+ // ContentLength, TransferEncoding, Trailer), the field values are
+ // authoritative.
+ //
+ // Keys in the map are canonicalized (see CanonicalHeaderKey).
+ Header Header
+
+ // Body represents the response body.
+ //
+ // The response body is streamed on demand as the Body field
+ // is read. If the network connection fails or the server
+ // terminates the response, Body.Read calls return an error.
+ //
+ // The http Client and Transport guarantee that Body is always
+ // non-nil, even on responses without a body or responses with
+ // a zero-length body. It is the caller's responsibility to
+ // close Body. The default HTTP client's Transport may not
+ // reuse HTTP/1.x "keep-alive" TCP connections if the Body is
+ // not read to completion and closed.
+ //
+ // The Body is automatically dechunked if the server replied
+ // with a "chunked" Transfer-Encoding.
+ //
+ // As of Go 1.12, the Body will also implement io.Writer
+ // on a successful "101 Switching Protocols" response,
+ // as used by WebSockets and HTTP/2's "h2c" mode.
+ Body io.ReadCloser
+
+ // ContentLength records the length of the associated content. The
+ // value -1 indicates that the length is unknown. Unless Request.Method
+ // is "HEAD", values >= 0 indicate that the given number of bytes may
+ // be read from Body.
+ ContentLength int64
+
+ // TransferEncoding contains the transfer encodings from outer-most to
+ // inner-most. A nil value means that the "identity" encoding is used.
+ TransferEncoding []string
+
+ // Close records whether the header directed that the connection be
+ // closed after reading Body. The value is advice for clients: neither
+ // ReadResponse nor Response.Write ever closes a connection.
+ Close bool
+
+ // Uncompressed reports whether the response was sent compressed but
+ // was decompressed by the http package. When true, reading from
+ // Body yields the uncompressed content instead of the compressed
+ // content actually sent by the server, ContentLength is set to -1,
+ // and the "Content-Length" and "Content-Encoding" fields are deleted
+ // from the response Header. To get the original response from
+ // the server, set Transport.DisableCompression to true.
+ Uncompressed bool
+
+ // Trailer maps trailer keys to values in the same
+ // format as Header.
+ //
+ // The Trailer initially contains only nil values, one for
+ // each key specified in the server's "Trailer" header
+ // value. Those values are not added to Header.
+ //
+ // Trailer must not be accessed concurrently with Read calls
+ // on the Body.
+ //
+ // After Body.Read has returned io.EOF, Trailer will contain
+ // any trailer values sent by the server.
+ Trailer Header
+
+ // Request is the request that was sent to obtain this Response.
+ // Request's Body is nil (having already been consumed).
+ // This is only populated for Client requests.
+ Request *Request
+
+ // TLS contains information about the TLS connection on which the
+ // response was received. It is nil for unencrypted responses.
+ // The pointer is shared between responses and should not be
+ // modified.
+ TLS *tls.ConnectionState
+}
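+
+// exampleConsumeResponseBody is a minimal usage sketch (illustrative only) of
+// the Body contract documented above: callers must close Body, and should
+// drain it so a keep-alive connection can be reused. It uses the package's
+// Get helper; the URL is whatever the caller supplies.
+func exampleConsumeResponseBody(url string) error {
+ resp, err := Get(url)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ // Drain the body so the underlying connection can be reused.
+ _, err = io.Copy(io.Discard, resp.Body)
+ return err
+}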
+
+// Cookies parses and returns the cookies set in the Set-Cookie headers.
+func (r *Response) Cookies() []*Cookie {
+ return readSetCookies(r.Header)
+}
+
+// ErrNoLocation is returned by Response's Location method
+// when no Location header is present.
+var ErrNoLocation = errors.New("http: no Location header in response")
+
+// Location returns the URL of the response's "Location" header,
+// if present. Relative redirects are resolved relative to
+// the Response's Request. ErrNoLocation is returned if no
+// Location header is present.
+func (r *Response) Location() (*url.URL, error) {
+ lv := r.Header.Get("Location")
+ if lv == "" {
+ return nil, ErrNoLocation
+ }
+ if r.Request != nil && r.Request.URL != nil {
+ return r.Request.URL.Parse(lv)
+ }
+ return url.Parse(lv)
+}
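+
+// exampleLocation is a minimal sketch (illustrative only) of the resolution
+// rule documented above: a relative Location header is resolved against the
+// originating request's URL.
+func exampleLocation() {
+ req, _ := NewRequest("GET", "http://bar.com/baz", nil)
+ resp := &Response{
+ Header: Header{"Location": {"/foo"}},
+ Request: req,
+ }
+ if u, err := resp.Location(); err == nil {
+ fmt.Println(u) // prints http://bar.com/foo
+ }
+}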
+
+// ReadResponse reads and returns an HTTP response from r.
+// The req parameter optionally specifies the Request that corresponds
+// to this Response. If nil, a GET request is assumed.
+// Clients must call resp.Body.Close when finished reading resp.Body.
+// After that call, clients can inspect resp.Trailer to find key/value
+// pairs included in the response trailer.
+func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
+ tp := textproto.NewReader(r)
+ resp := &Response{
+ Request: req,
+ }
+
+ // Parse the first line of the response.
+ line, err := tp.ReadLine()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+ proto, status, ok := strings.Cut(line, " ")
+ if !ok {
+ return nil, badStringError("malformed HTTP response", line)
+ }
+ resp.Proto = proto
+ resp.Status = strings.TrimLeft(status, " ")
+
+ statusCode, _, _ := strings.Cut(resp.Status, " ")
+ if len(statusCode) != 3 {
+ return nil, badStringError("malformed HTTP status code", statusCode)
+ }
+ resp.StatusCode, err = strconv.Atoi(statusCode)
+ if err != nil || resp.StatusCode < 0 {
+ return nil, badStringError("malformed HTTP status code", statusCode)
+ }
+ if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
+ return nil, badStringError("malformed HTTP version", resp.Proto)
+ }
+
+ // Parse the response headers.
+ mimeHeader, err := tp.ReadMIMEHeader()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return nil, err
+ }
+ resp.Header = Header(mimeHeader)
+
+ fixPragmaCacheControl(resp.Header)
+
+ err = readTransfer(resp, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return resp, nil
+}
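+
+// exampleReadResponse is a minimal sketch (illustrative only) of the
+// ReadResponse contract above: parse a raw HTTP/1.1 response, read the body,
+// then close it.
+func exampleReadResponse() error {
+ raw := "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(raw)), nil)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("%d %q\n", resp.StatusCode, body) // prints: 200 "hello"
+ return nil
+}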
+
+// RFC 7234, section 5.4: Should treat
+//
+// Pragma: no-cache
+//
+// like
+//
+// Cache-Control: no-cache
+func fixPragmaCacheControl(header Header) {
+ if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" {
+ if _, presentcc := header["Cache-Control"]; !presentcc {
+ header["Cache-Control"] = []string{"no-cache"}
+ }
+ }
+}
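+
+// examplePragmaFixup is a tiny sketch (illustrative only) of the header
+// normalization performed above.
+func examplePragmaFixup() {
+ h := Header{"Pragma": {"no-cache"}}
+ fixPragmaCacheControl(h)
+ fmt.Println(h.Get("Cache-Control")) // prints: no-cache
+}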
+
+// ProtoAtLeast reports whether the HTTP protocol used
+// in the response is at least major.minor.
+func (r *Response) ProtoAtLeast(major, minor int) bool {
+ return r.ProtoMajor > major ||
+ r.ProtoMajor == major && r.ProtoMinor >= minor
+}
+
+// Write writes r to w in the HTTP/1.x server response format,
+// including the status line, headers, body, and optional trailer.
+//
+// This method consults the following fields of the response r:
+//
+// StatusCode
+// ProtoMajor
+// ProtoMinor
+// Request.Method
+// TransferEncoding
+// Trailer
+// Body
+// ContentLength
+// Header, values for non-canonical keys will have unpredictable behavior
+//
+// The Response Body is closed after it is sent.
+func (r *Response) Write(w io.Writer) error {
+ // Status line
+ text := r.Status
+ if text == "" {
+ text = StatusText(r.StatusCode)
+ if text == "" {
+ text = "status code " + strconv.Itoa(r.StatusCode)
+ }
+ } else {
+ // Just to reduce stutter, if the user set r.Status to "200 OK" and StatusCode to 200.
+ // Not important.
+ text = strings.TrimPrefix(text, strconv.Itoa(r.StatusCode)+" ")
+ }
+
+ if _, err := fmt.Fprintf(w, "HTTP/%d.%d %03d %s\r\n", r.ProtoMajor, r.ProtoMinor, r.StatusCode, text); err != nil {
+ return err
+ }
+
+ // Clone it, so we can modify r1 as needed.
+ r1 := new(Response)
+ *r1 = *r
+ if r1.ContentLength == 0 && r1.Body != nil {
+ // Is it actually 0 length? Or just unknown?
+ var buf [1]byte
+ n, err := r1.Body.Read(buf[:])
+ if err != nil && err != io.EOF {
+ return err
+ }
+ if n == 0 {
+ // Reset it to a known zero reader, in case underlying one
+ // is unhappy being read repeatedly.
+ r1.Body = NoBody
+ } else {
+ r1.ContentLength = -1
+ r1.Body = struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
+ r.Body,
+ }
+ }
+ }
+ // If we're sending a non-chunked HTTP/1.1 response without a
+ // content-length, the only way to do that is the old HTTP/1.0
+ // way, by noting the EOF with a connection close, so we need
+ // to set Close.
+ if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) && !r1.Uncompressed {
+ r1.Close = true
+ }
+
+ // Process Body,ContentLength,Close,Trailer
+ tw, err := newTransferWriter(r1)
+ if err != nil {
+ return err
+ }
+ err = tw.writeHeader(w, nil)
+ if err != nil {
+ return err
+ }
+
+ // Rest of header
+ err = r.Header.WriteSubset(w, respExcludeHeader)
+ if err != nil {
+ return err
+ }
+
+ // A Content-Length header may have already been sent by the transfer
+ // writer for POST/PUT requests, even if zero length. See Issue 8180.
+ contentLengthAlreadySent := tw.shouldSendContentLength()
+ if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) {
+ if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
+ return err
+ }
+ }
+
+ // End-of-header
+ if _, err := io.WriteString(w, "\r\n"); err != nil {
+ return err
+ }
+
+ // Write body and trailer
+ err = tw.writeBody(w)
+ if err != nil {
+ return err
+ }
+
+ // Success
+ return nil
+}
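+
+// exampleWriteResponse is a minimal sketch (illustrative only) of using
+// Response.Write as documented above to serialize a response into a buffer.
+func exampleWriteResponse() error {
+ resp := &Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{"Content-Type": {"text/plain"}},
+ ContentLength: 5,
+ Body: io.NopCloser(strings.NewReader("hello")),
+ }
+ var buf bytes.Buffer
+ if err := resp.Write(&buf); err != nil {
+ return err
+ }
+ fmt.Print(buf.String()) // status line and headers, then "hello"
+ return nil
+}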
+
+func (r *Response) closeBody() {
+ if r.Body != nil {
+ r.Body.Close()
+ }
+}
+
+// bodyIsWritable reports whether the Body supports writing. The
+// Transport returns Writable bodies for 101 Switching Protocols
+// responses.
+// The Transport uses this method to determine whether a persistent
+// connection is done being managed from its perspective. Once we
+// return a writable response body to a user, the net/http package is
+// done managing that connection.
+func (r *Response) bodyIsWritable() bool {
+ _, ok := r.Body.(io.Writer)
+ return ok
+}
+
+// isProtocolSwitch reports whether the response code and header
+// indicate a successful protocol upgrade response.
+func (r *Response) isProtocolSwitch() bool {
+ return isProtocolSwitchResponse(r.StatusCode, r.Header)
+}
+
+// isProtocolSwitchResponse reports whether the response code and
+// response header indicate a successful protocol upgrade response.
+func isProtocolSwitchResponse(code int, h Header) bool {
+ return code == StatusSwitchingProtocols && isProtocolSwitchHeader(h)
+}
+
+// isProtocolSwitchHeader reports whether the request or response header
+// is for a protocol switch.
+func isProtocolSwitchHeader(h Header) bool {
+ return h.Get("Upgrade") != "" &&
+ httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade")
+}
diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go
new file mode 100644
index 0000000..19fb48f
--- /dev/null
+++ b/src/net/http/response_test.go
@@ -0,0 +1,999 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "crypto/rand"
+ "fmt"
+ "go/token"
+ "io"
+ "net/http/internal"
+ "net/url"
+ "reflect"
+ "regexp"
+ "strings"
+ "testing"
+)
+
+type respTest struct {
+ Raw string
+ Resp Response
+ Body string
+}
+
+func dummyReq(method string) *Request {
+ return &Request{Method: method}
+}
+
+func dummyReq11(method string) *Request {
+ return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
+}
+
+var respTests = []respTest{
+ // Unchunked response without Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 response without Content-Length or
+ // Connection headers.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Unchunked HTTP/1.1 204 response without Content-Length.
+ {
+ "HTTP/1.1 204 No Content\r\n" +
+ "\r\n" +
+ "Body should not be read!\n",
+
+ Response{
+ Status: "204 No Content",
+ StatusCode: 204,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: Header{},
+ Request: dummyReq("GET"),
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Unchunked response with Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 10\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"},
+ "Content-Length": {"10"},
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response without Content-Length.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n\r\n" +
+ "09\r\n" +
+ "continued\r\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: false,
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\ncontinued",
+ },
+
+ // Trailer header but no TransferEncoding
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Trailer: Content-MD5, Content-Sources\r\n" +
+ "Content-Length: 10\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Connection": {"close"},
+ "Content-Length": {"10"},
+ "Trailer": []string{"Content-MD5, Content-Sources"},
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response with Content-Length.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Content-Length: 10\r\n" +
+ "\r\n" +
+ "0a\r\n" +
+ "Body here\n\r\n" +
+ "0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: false,
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ "Body here\n",
+ },
+
+ // Chunked response in response to a HEAD request
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Content-Length in response to a HEAD request with HTTP/1.1
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{"Content-Length": {"256"}},
+ TransferEncoding: nil,
+ Close: false,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // No Content-Length or Chunked in response to a HEAD request
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("HEAD"),
+ Header: Header{},
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // explicit Content-Length of 0.
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"0"},
+ },
+ Close: false,
+ ContentLength: 0,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, but with a trailing space.
+ // (permitted by RFC 7230, section 3.1.2)
+ {
+ "HTTP/1.0 303 \r\n\r\n",
+ Response{
+ Status: "303 ",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // Status line without a Reason-Phrase, and no trailing space.
+ // (not permitted by RFC 7230, but we'll accept it anyway)
+ {
+ "HTTP/1.0 303\r\n\r\n",
+ Response{
+ Status: "303",
+ StatusCode: 303,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "",
+ },
+
+ // golang.org/issue/4767: don't special-case multipart/byteranges responses
+ {
+ `HTTP/1.1 206 Partial Content
+Connection: close
+Content-Type: multipart/byteranges; boundary=18a75608c8f47cef
+
+some body`,
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"},
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "some body",
+ },
+
+ // Unchunked response without Content-Length, Request is nil
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: Header{
+ "Connection": {"close"}, // TODO(rsc): Delete?
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // 206 Partial Content. golang.org/issue/8923
+ {
+ "HTTP/1.1 206 Partial Content\r\n" +
+ "Content-Type: text/plain; charset=utf-8\r\n" +
+ "Accept-Ranges: bytes\r\n" +
+ "Content-Range: bytes 0-5/1862\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "foobar",
+
+ Response{
+ Status: "206 Partial Content",
+ StatusCode: 206,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Accept-Ranges": []string{"bytes"},
+ "Content-Length": []string{"6"},
+ "Content-Type": []string{"text/plain; charset=utf-8"},
+ "Content-Range": []string{"bytes 0-5/1862"},
+ },
+ ContentLength: 6,
+ },
+
+ "foobar",
+ },
+
+ // Both keep-alive and close, on the same Connection line. (Issue 8840)
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "Connection: keep-alive, close\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{
+ "Content-Length": {"256"},
+ },
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Both keep-alive and close, on different Connection lines. (Issue 8840)
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 256\r\n" +
+ "Connection: keep-alive\r\n" +
+ "Connection: close\r\n" +
+ "\r\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("HEAD"),
+ Header: Header{
+ "Content-Length": {"256"},
+ },
+ TransferEncoding: nil,
+ Close: true,
+ ContentLength: 256,
+ },
+
+ "",
+ },
+
+ // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding.
+ // Without a Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: bogus\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding.
+ // With a Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: bogus\r\n" +
+ "Content-Length: 10\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"10"},
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
+
+ {
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Encoding: gzip\r\n" +
+ "Content-Length: 23\r\n" +
+ "Connection: keep-alive\r\n" +
+ "Keep-Alive: timeout=7200\r\n\r\n" +
+ "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00",
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"23"},
+ "Content-Encoding": {"gzip"},
+ "Connection": {"keep-alive"},
+ "Keep-Alive": {"timeout=7200"},
+ },
+ Close: false,
+ ContentLength: 23,
+ },
+ "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00",
+ },
+
+ // Issue 19989: two spaces between HTTP version and status.
+ {
+ "HTTP/1.0 401 Unauthorized\r\n" +
+ "Content-type: text/html\r\n" +
+ "WWW-Authenticate: Basic realm=\"\"\r\n\r\n" +
+ "Your Authentication failed.\r\n",
+ Response{
+ Status: "401 Unauthorized",
+ StatusCode: 401,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Type": {"text/html"},
+ "Www-Authenticate": {`Basic realm=""`},
+ },
+ Close: true,
+ ContentLength: -1,
+ },
+ "Your Authentication failed.\r\n",
+ },
+}
+
+// TestReadResponse tests successful calls to ReadResponse and inspects the
+// returned Response. For error cases, see TestReadResponseErrors below.
+func TestReadResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ rbody := resp.Body
+ resp.Body = nil
+ diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp)
+ var bout strings.Builder
+ if rbody != nil {
+ _, err = io.Copy(&bout, rbody)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ rbody.Close()
+ }
+ body := bout.String()
+ if body != tt.Body {
+ t.Errorf("#%d: Body = %q want %q", i, body, tt.Body)
+ }
+ }
+}
+
+func TestWriteResponse(t *testing.T) {
+ for i, tt := range respTests {
+ resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ err = resp.Write(io.Discard)
+ if err != nil {
+ t.Errorf("#%d: %v", i, err)
+ continue
+ }
+ }
+}
+
+var readResponseCloseInMiddleTests = []struct {
+ chunked, compressed bool
+}{
+ {false, false},
+ {true, false},
+ {true, true},
+}
+
+type readerAndCloser struct {
+ io.Reader
+ io.Closer
+}
+
+// TestReadResponseCloseInMiddle tests that closing a body after
+// reading only part of its contents advances the read to the end of
+// the request, right up until the next request.
+func TestReadResponseCloseInMiddle(t *testing.T) {
+ t.Parallel()
+ for _, test := range readResponseCloseInMiddleTests {
+ fatalf := func(format string, args ...any) {
+ args = append([]any{test.chunked, test.compressed}, args...)
+ t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...)
+ }
+ checkErr := func(err error, msg string) {
+ if err == nil {
+ return
+ }
+ fatalf(msg+": %v", err)
+ }
+ var buf bytes.Buffer
+ buf.WriteString("HTTP/1.1 200 OK\r\n")
+ if test.chunked {
+ buf.WriteString("Transfer-Encoding: chunked\r\n")
+ } else {
+ buf.WriteString("Content-Length: 1000000\r\n")
+ }
+ var wr io.Writer = &buf
+ if test.chunked {
+ wr = internal.NewChunkedWriter(wr)
+ }
+ if test.compressed {
+ buf.WriteString("Content-Encoding: gzip\r\n")
+ wr = gzip.NewWriter(wr)
+ }
+ buf.WriteString("\r\n")
+
+ chunk := bytes.Repeat([]byte{'x'}, 1000)
+ for i := 0; i < 1000; i++ {
+ if test.compressed {
+ // Otherwise this compresses too well.
+ _, err := io.ReadFull(rand.Reader, chunk)
+ checkErr(err, "rand.Reader ReadFull")
+ }
+ wr.Write(chunk)
+ }
+ if test.compressed {
+ err := wr.(*gzip.Writer).Close()
+ checkErr(err, "compressor close")
+ }
+ if test.chunked {
+ buf.WriteString("0\r\n\r\n")
+ }
+ buf.WriteString("Next Request Here")
+
+ bufr := bufio.NewReader(&buf)
+ resp, err := ReadResponse(bufr, dummyReq("GET"))
+ checkErr(err, "ReadResponse")
+ expectedLength := int64(-1)
+ if !test.chunked {
+ expectedLength = 1000000
+ }
+ if resp.ContentLength != expectedLength {
+ fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength)
+ }
+ if resp.Body == nil {
+ fatalf("nil body")
+ }
+ if test.compressed {
+ gzReader, err := gzip.NewReader(resp.Body)
+ checkErr(err, "gzip.NewReader")
+ resp.Body = &readerAndCloser{gzReader, resp.Body}
+ }
+
+ rbuf := make([]byte, 2500)
+ n, err := io.ReadFull(resp.Body, rbuf)
+ checkErr(err, "2500 byte ReadFull")
+ if n != 2500 {
+ fatalf("ReadFull only read %d bytes", n)
+ }
+ if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) {
+ fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf))
+ }
+ resp.Body.Close()
+
+ rest, err := io.ReadAll(bufr)
+ checkErr(err, "ReadAll on remainder")
+ if e, g := "Next Request Here", string(rest); e != g {
+ g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
+ return fmt.Sprintf("x(repeated x%d)", len(match))
+ })
+ fatalf("remainder = %q, expected %q", g, e)
+ }
+ }
+}
+
+func diff(t *testing.T, prefix string, have, want any) {
+ t.Helper()
+ hv := reflect.ValueOf(have).Elem()
+ wv := reflect.ValueOf(want).Elem()
+ if hv.Type() != wv.Type() {
+ t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type())
+ }
+ for i := 0; i < hv.NumField(); i++ {
+ name := hv.Type().Field(i).Name
+ if !token.IsExported(name) {
+ continue
+ }
+ hf := hv.Field(i).Interface()
+ wf := wv.Field(i).Interface()
+ if !reflect.DeepEqual(hf, wf) {
+ t.Errorf("%s: %s = %v want %v", prefix, name, hf, wf)
+ }
+ }
+}
+
+type responseLocationTest struct {
+ location string // Response's Location header or ""
+ requrl string // Response.Request.URL or ""
+ want string
+ wantErr error
+}
+
+var responseLocationTests = []responseLocationTest{
+ {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil},
+ {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil},
+ {"", "http://bar.com/baz", "", ErrNoLocation},
+ {"/bar", "", "/bar", nil},
+}
+
+func TestLocationResponse(t *testing.T) {
+ for i, tt := range responseLocationTests {
+ res := new(Response)
+ res.Header = make(Header)
+ res.Header.Set("Location", tt.location)
+ if tt.requrl != "" {
+ res.Request = &Request{}
+ var err error
+ res.Request.URL, err = url.Parse(tt.requrl)
+ if err != nil {
+ t.Fatalf("bad test URL %q: %v", tt.requrl, err)
+ }
+ }
+
+ got, err := res.Location()
+ if tt.wantErr != nil {
+ if err == nil {
+ t.Errorf("%d. err=nil; want %q", i, tt.wantErr)
+ continue
+ }
+ if g, e := err.Error(), tt.wantErr.Error(); g != e {
+ t.Errorf("%d. err=%q; want %q", i, g, e)
+ continue
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("%d. err=%q", i, err)
+ continue
+ }
+ if g, e := got.String(), tt.want; g != e {
+ t.Errorf("%d. Location=%q; want %q", i, g, e)
+ }
+ }
+}
+
+func TestResponseStatusStutter(t *testing.T) {
+ r := &Response{
+ Status: "123 some status",
+ StatusCode: 123,
+ ProtoMajor: 1,
+ ProtoMinor: 3,
+ }
+ var buf strings.Builder
+ r.Write(&buf)
+ if strings.Contains(buf.String(), "123 123") {
+ t.Errorf("stutter in status: %s", buf.String())
+ }
+}
+
+func TestResponseContentLengthShortBody(t *testing.T) {
+ const shortBody = "Short body, not 123 bytes."
+ br := bufio.NewReader(strings.NewReader("HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 123\r\n" +
+ "\r\n" +
+ shortBody))
+ res, err := ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.ContentLength != 123 {
+ t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
+ }
+ var buf strings.Builder
+ n, err := io.Copy(&buf, res.Body)
+ if n != int64(len(shortBody)) {
+ t.Errorf("Copied %d bytes; want %d, len(%q)", n, len(shortBody), shortBody)
+ }
+ if buf.String() != shortBody {
+ t.Errorf("Read body %q; want %q", buf.String(), shortBody)
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("io.Copy error = %#v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+// Test various ReadResponse error cases (it also tests some success cases,
+// but it's mostly about errors). It does not test anything involving the
+// bodies, only the return value from ReadResponse itself.
+func TestReadResponseErrors(t *testing.T) {
+ type testCase struct {
+ name string // optional, defaults to in
+ in string
+ wantErr any // nil, err value, or string substring
+ }
+
+ status := func(s string, wantErr any) testCase {
+ if wantErr == true {
+ wantErr = "malformed HTTP status code"
+ }
+ return testCase{
+ name: fmt.Sprintf("status %q", s),
+ in: "HTTP/1.1 " + s + "\r\nFoo: bar\r\n\r\n",
+ wantErr: wantErr,
+ }
+ }
+
+ version := func(s string, wantErr any) testCase {
+ if wantErr == true {
+ wantErr = "malformed HTTP version"
+ }
+ return testCase{
+ name: fmt.Sprintf("version %q", s),
+ in: s + " 200 OK\r\n\r\n",
+ wantErr: wantErr,
+ }
+ }
+
+ contentLength := func(status, body string, wantErr any) testCase {
+ return testCase{
+ name: fmt.Sprintf("status %q %q", status, body),
+ in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body),
+ wantErr: wantErr,
+ }
+ }
+
+ errMultiCL := "message cannot contain multiple Content-Length headers"
+
+ tests := []testCase{
+ {"", "", io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1", "malformed HTTP response"},
+ {"", "HTTP/2.0", "malformed HTTP response"},
+ status("20X Unknown", true),
+ status("abcd Unknown", true),
+ status("二百/两百 OK", true),
+ status(" Unknown", true),
+ status("c8 OK", true),
+ status("0x12d Moved Permanently", true),
+ status("200 OK", nil),
+ status("000 OK", nil),
+ status("001 OK", nil),
+ status("404 NOTFOUND", nil),
+ status("20 OK", true),
+ status("00 OK", true),
+ status("-10 OK", true),
+ status("1000 OK", true),
+ status("999 Done", nil),
+ status("-1 OK", true),
+ status("-200 OK", true),
+ version("HTTP/1.2", nil),
+ version("HTTP/2.0", nil),
+ version("HTTP/1.100000000002", true),
+ version("HTTP/1.-1", true),
+ version("HTTP/A.B", true),
+ version("HTTP/1", true),
+ version("http/1.1", true),
+
+ contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL),
+ contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil),
+ contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL),
+ contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil),
+ contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil),
+ contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL),
+
+ // multiple content-length headers for 204 and 304 should still be checked
+ contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL),
+ contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil),
+ contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL),
+ contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil),
+
+ // golang.org/issue/22464
+ {"leading space in header", "HTTP/1.1 200 OK\r\n Content-type: text/html\r\nFoo: bar\r\n\r\n", "malformed MIME"},
+ {"leading tab in header", "HTTP/1.1 200 OK\r\n\tContent-type: text/html\r\nFoo: bar\r\n\r\n", "malformed MIME"},
+ }
+
+ for i, tt := range tests {
+ br := bufio.NewReader(strings.NewReader(tt.in))
+ _, rerr := ReadResponse(br, nil)
+ if err := matchErr(rerr, tt.wantErr); err != nil {
+ name := tt.name
+ if name == "" {
+ name = fmt.Sprintf("%d. input %q", i, tt.in)
+ }
+ t.Errorf("%s: %v", name, err)
+ }
+ }
+}
+
+// wantErr can be nil, an error value to match exactly, or type string to
+// match a substring.
+func matchErr(err error, wantErr any) error {
+ if err == nil {
+ if wantErr == nil {
+ return nil
+ }
+ if sub, ok := wantErr.(string); ok {
+ return fmt.Errorf("unexpected success; want error with substring %q", sub)
+ }
+ return fmt.Errorf("unexpected success; want error %v", wantErr)
+ }
+ if wantErr == nil {
+ return fmt.Errorf("%v; want success", err)
+ }
+ if sub, ok := wantErr.(string); ok {
+ if strings.Contains(err.Error(), sub) {
+ return nil
+ }
+ return fmt.Errorf("error = %v; want an error with substring %q", err, sub)
+ }
+ if err == wantErr {
+ return nil
+ }
+ return fmt.Errorf("%v; want %v", err, wantErr)
+}
+
+// A response should write out only a single Connection: close header. Tests #19499.
+func TestResponseWritesOnlySingleConnectionClose(t *testing.T) {
+ const connectionCloseHeader = "Connection: close"
+
+ res, err := ReadResponse(bufio.NewReader(strings.NewReader("HTTP/1.0 200 OK\r\n\r\nAAAA")), nil)
+ if err != nil {
+ t.Fatalf("ReadResponse failed %v", err)
+ }
+
+ var buf1 bytes.Buffer
+ if err = res.Write(&buf1); err != nil {
+ t.Fatalf("Write failed %v", err)
+ }
+ if res, err = ReadResponse(bufio.NewReader(&buf1), nil); err != nil {
+ t.Fatalf("ReadResponse failed %v", err)
+ }
+
+ var buf2 strings.Builder
+ if err = res.Write(&buf2); err != nil {
+ t.Fatalf("Write failed %v", err)
+ }
+ if count := strings.Count(buf2.String(), connectionCloseHeader); count != 1 {
+ t.Errorf("Found %d %q header", count, connectionCloseHeader)
+ }
+}
diff --git a/src/net/http/responsecontroller.go b/src/net/http/responsecontroller.go
new file mode 100644
index 0000000..92276ff
--- /dev/null
+++ b/src/net/http/responsecontroller.go
@@ -0,0 +1,147 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "time"
+)
+
+// A ResponseController is used by an HTTP handler to control the response.
+//
+// A ResponseController may not be used after the Handler.ServeHTTP method has returned.
+type ResponseController struct {
+ rw ResponseWriter
+}
+
+// NewResponseController creates a ResponseController for a request.
+//
+// The ResponseWriter should be the original value passed to the Handler.ServeHTTP method,
+// or have an Unwrap method returning the original ResponseWriter.
+//
+// If the ResponseWriter implements any of the following methods, the ResponseController
+// will call them as appropriate:
+//
+// Flush()
+// FlushError() error // alternative Flush returning an error
+// Hijack() (net.Conn, *bufio.ReadWriter, error)
+// SetReadDeadline(deadline time.Time) error
+// SetWriteDeadline(deadline time.Time) error
+// EnableFullDuplex() error
+//
+// If the ResponseWriter does not support a method, ResponseController returns
+// an error matching ErrNotSupported.
+func NewResponseController(rw ResponseWriter) *ResponseController {
+ return &ResponseController{rw}
+}
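+
+// exampleStreamingHandler is a minimal handler sketch (illustrative only) of
+// the intended use described above: flush buffered data and extend the write
+// deadline between chunks of a streamed response.
+func exampleStreamingHandler(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ w.Write([]byte("part one\n"))
+ if err := ctl.Flush(); err != nil {
+ return // the underlying ResponseWriter does not support flushing
+ }
+ // Give the slower second part of the response extra time.
+ if err := ctl.SetWriteDeadline(time.Now().Add(30 * time.Second)); err != nil {
+ return
+ }
+ w.Write([]byte("part two\n"))
+}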
+
+type rwUnwrapper interface {
+ Unwrap() ResponseWriter
+}
+
+// Flush flushes buffered data to the client.
+func (c *ResponseController) Flush() error {
+ rw := c.rw
+ for {
+ switch t := rw.(type) {
+ case interface{ FlushError() error }:
+ return t.FlushError()
+ case Flusher:
+ t.Flush()
+ return nil
+ case rwUnwrapper:
+ rw = t.Unwrap()
+ default:
+ return errNotSupported()
+ }
+ }
+}
+
+// Hijack lets the caller take over the connection.
+// See the Hijacker interface for details.
+func (c *ResponseController) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ rw := c.rw
+ for {
+ switch t := rw.(type) {
+ case Hijacker:
+ return t.Hijack()
+ case rwUnwrapper:
+ rw = t.Unwrap()
+ default:
+ return nil, nil, errNotSupported()
+ }
+ }
+}
+
+// SetReadDeadline sets the deadline for reading the entire request, including the body.
+// Reads from the request body after the deadline has been exceeded will return an error.
+// A zero value means no deadline.
+//
+// Setting the read deadline after it has been exceeded will not extend it.
+func (c *ResponseController) SetReadDeadline(deadline time.Time) error {
+ rw := c.rw
+ for {
+ switch t := rw.(type) {
+ case interface{ SetReadDeadline(time.Time) error }:
+ return t.SetReadDeadline(deadline)
+ case rwUnwrapper:
+ rw = t.Unwrap()
+ default:
+ return errNotSupported()
+ }
+ }
+}
+
+// SetWriteDeadline sets the deadline for writing the response.
+// Writes to the response body after the deadline has been exceeded will not block,
+// but may succeed if the data has been buffered.
+// A zero value means no deadline.
+//
+// Setting the write deadline after it has been exceeded will not extend it.
+func (c *ResponseController) SetWriteDeadline(deadline time.Time) error {
+ rw := c.rw
+ for {
+ switch t := rw.(type) {
+ case interface{ SetWriteDeadline(time.Time) error }:
+ return t.SetWriteDeadline(deadline)
+ case rwUnwrapper:
+ rw = t.Unwrap()
+ default:
+ return errNotSupported()
+ }
+ }
+}
+
+// EnableFullDuplex indicates that the request handler will interleave reads from Request.Body
+// with writes to the ResponseWriter.
+//
+// For HTTP/1 requests, the Go HTTP server by default consumes any unread portion of
+// the request body before beginning to write the response, preventing handlers from
+// concurrently reading from the request and writing the response.
+// Calling EnableFullDuplex disables this behavior and permits handlers to continue to read
+// from the request while concurrently writing the response.
+//
+// For HTTP/2 requests, the Go HTTP server always permits concurrent reads and responses.
+func (c *ResponseController) EnableFullDuplex() error {
+ rw := c.rw
+ for {
+ switch t := rw.(type) {
+ case interface{ EnableFullDuplex() error }:
+ return t.EnableFullDuplex()
+ case rwUnwrapper:
+ rw = t.Unwrap()
+ default:
+ return errNotSupported()
+ }
+ }
+}
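+
+// exampleFullDuplexEcho is a handler sketch (illustrative only) of the
+// behavior described above: after EnableFullDuplex, an HTTP/1 handler may
+// interleave reads from the request body with writes to the response.
+func exampleFullDuplexEcho(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ if err := ctl.EnableFullDuplex(); err != nil {
+ return // full-duplex is not supported by this ResponseWriter
+ }
+ buf := make([]byte, 4096)
+ for {
+ n, err := r.Body.Read(buf)
+ if n > 0 {
+ w.Write(buf[:n])
+ ctl.Flush()
+ }
+ if err != nil {
+ return
+ }
+ }
+}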
+
+// errNotSupported returns an error that Is ErrNotSupported,
+// but is not == to it.
+func errNotSupported() error {
+ return fmt.Errorf("%w", ErrNotSupported)
+}
diff --git a/src/net/http/responsecontroller_test.go b/src/net/http/responsecontroller_test.go
new file mode 100644
index 0000000..5828f37
--- /dev/null
+++ b/src/net/http/responsecontroller_test.go
@@ -0,0 +1,324 @@
+package http_test
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ . "net/http"
+ "os"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
+func testResponseControllerFlush(t *testing.T, mode testMode) {
+ continuec := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ w.Write([]byte("one"))
+ if err := ctl.Flush(); err != nil {
+ t.Errorf("ctl.Flush() = %v, want nil", err)
+ return
+ }
+ <-continuec
+ w.Write([]byte("two"))
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("unexpected connection error: %v", err)
+ }
+ defer res.Body.Close()
+
+ buf := make([]byte, 16)
+ n, err := res.Body.Read(buf)
+ close(continuec)
+ if err != nil || string(buf[:n]) != "one" {
+ t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
+ }
+
+ got, err := io.ReadAll(res.Body)
+ if err != nil || string(got) != "two" {
+ t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
+ }
+}
+
+func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
+func testResponseControllerHijack(t *testing.T, mode testMode) {
+ const header = "X-Header"
+ const value = "set"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ c, _, err := ctl.Hijack()
+ if mode == http2Mode {
+ if err == nil {
+ t.Errorf("ctl.Hijack = nil, want error")
+ }
+ w.Header().Set(header, value)
+ return
+ }
+ if err != nil {
+ t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
+ return
+ }
+ fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := res.Header.Get(header), value; got != want {
+ t.Errorf("response header %q = %q, want %q", header, got, want)
+ }
+}
+
+func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
+ run(t, testResponseControllerSetPastWriteDeadline)
+}
+func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ w.Write([]byte("one"))
+ if err := ctl.Flush(); err != nil {
+ t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
+ }
+ if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil {
+ t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+ }
+
+ w.Write([]byte("two"))
+ if err := ctl.Flush(); err == nil {
+ t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
+ }
+ // Connection errors are sticky, so resetting the deadline does not permit
+ // making more progress. We might want to change this in the future, but verify
+ // the current behavior for now. If we do change this, we'll want to make sure
+ // to do so only for writing the response body, not headers.
+ if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
+ t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+ }
+ w.Write([]byte("three"))
+ if err := ctl.Flush(); err == nil {
+ t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
+ }
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("unexpected connection error: %v", err)
+ }
+ defer res.Body.Close()
+ b, _ := io.ReadAll(res.Body)
+ if string(b) != "one" {
+ t.Errorf("unexpected body: %q", string(b))
+ }
+}
+
+func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
+ run(t, testResponseControllerSetFutureWriteDeadline)
+}
+func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
+ errc := make(chan error, 1)
+ startwritec := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctl := NewResponseController(w)
+ w.WriteHeader(200)
+ if err := ctl.Flush(); err != nil {
+ t.Errorf("ctl.Flush() = %v, want nil", err)
+ }
+ <-startwritec // don't set the deadline until the client reads response headers
+ if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+ t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+ }
+ _, err := io.Copy(w, neverEnding('a'))
+ errc <- err
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ close(startwritec)
+ if err != nil {
+ t.Fatalf("unexpected connection error: %v", err)
+ }
+ defer res.Body.Close()
+ _, err = io.Copy(io.Discard, res.Body)
+ if err == nil {
+ t.Errorf("client reading from truncated request body: got nil error, want non-nil")
+ }
+ err = <-errc // io.Copy error
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
+ }
+}
+
+func TestResponseControllerSetPastReadDeadline(t *testing.T) {
+ run(t, testResponseControllerSetPastReadDeadline)
+}
+func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
+ readc := make(chan struct{})
+ donec := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(donec)
+ ctl := NewResponseController(w)
+ b := make([]byte, 3)
+ n, err := io.ReadFull(r.Body, b)
+ b = b[:n]
+ if err != nil || string(b) != "one" {
+ t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
+ return
+ }
+ if err := ctl.SetReadDeadline(time.Now()); err != nil {
+ t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+ return
+ }
+ b, err = io.ReadAll(r.Body)
+ if err == nil || string(b) != "" {
+ t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
+ }
+ close(readc)
+ // Connection errors are sticky, so resetting the deadline does not permit
+ // making more progress. We might want to change this in the future, but verify
+ // the current behavior for now.
+ if err := ctl.SetReadDeadline(time.Time{}); err != nil {
+ t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+ return
+ }
+ b, err = io.ReadAll(r.Body)
+ if err == nil {
+ t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
+ }
+ }))
+
+ pr, pw := io.Pipe()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer pw.Close()
+ pw.Write([]byte("one"))
+ select {
+ case <-readc:
+ case <-donec:
+ select {
+ case <-readc:
+ default:
+ t.Errorf("server handler unexpectedly exited without closing readc")
+ return
+ }
+ }
+ pw.Write([]byte("two"))
+ }()
+ defer wg.Wait()
+ res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
+ if err == nil {
+ defer res.Body.Close()
+ }
+}
+
+func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
+ run(t, testResponseControllerSetFutureReadDeadline)
+}
+func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
+ respBody := "response body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+ ctl := NewResponseController(w)
+ if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
+ t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+ }
+ _, err := io.Copy(io.Discard, req.Body)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
+ }
+ w.Write([]byte(respBody))
+ }))
+ pr, pw := io.Pipe()
+ res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ got, err := io.ReadAll(res.Body)
+ if string(got) != respBody || err != nil {
+ t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
+ }
+ pw.Close()
+}
+
+type wrapWriter struct {
+ ResponseWriter
+}
+
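+// Unwrap returns the wrapped ResponseWriter. NewResponseController follows
+// Unwrap methods like this one to reach the underlying writer's Flush and
+// deadline support.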
+func (w wrapWriter) Unwrap() ResponseWriter {
+ return w.ResponseWriter
+}
+
+func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
+func testWrappedResponseController(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w = wrapWriter{w}
+ ctl := NewResponseController(w)
+ if err := ctl.Flush(); err != nil {
+ t.Errorf("ctl.Flush() = %v, want nil", err)
+ }
+ if err := ctl.SetReadDeadline(time.Time{}); err != nil {
+ t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
+ }
+ if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
+ t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
+ }
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("unexpected connection error: %v", err)
+ }
+ io.Copy(io.Discard, res.Body)
+ defer res.Body.Close()
+}
+
+func TestResponseControllerEnableFullDuplex(t *testing.T) {
+ run(t, testResponseControllerEnableFullDuplex)
+}
+func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+ ctl := NewResponseController(w)
+ if err := ctl.EnableFullDuplex(); err != nil {
+ // TODO: Drop test for HTTP/2 when x/net is updated to support
+ // EnableFullDuplex. Since HTTP/2 supports full duplex by default,
+ // the rest of the test is fine; it's just the EnableFullDuplex call
+ // that fails.
+ if mode != http2Mode {
+ t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err)
+ }
+ }
+ w.WriteHeader(200)
+ ctl.Flush()
+ for {
+ var buf [1]byte
+ n, err := req.Body.Read(buf[:])
+ if n != 1 || err != nil {
+ break
+ }
+ w.Write(buf[:])
+ ctl.Flush()
+ }
+ }))
+ pr, pw := io.Pipe()
+ res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ for i := byte(0); i < 10; i++ {
+ if _, err := pw.Write([]byte{i}); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ var buf [1]byte
+ if n, err := res.Body.Read(buf[:]); n != 1 || err != nil {
+ t.Fatalf("Read: %v, %v", n, err)
+ }
+ if buf[0] != i {
+ t.Fatalf("read byte %v, want %v", buf[0], i)
+ }
+ }
+ pw.Close()
+}
diff --git a/src/net/http/responsewrite_test.go b/src/net/http/responsewrite_test.go
new file mode 100644
index 0000000..226ad72
--- /dev/null
+++ b/src/net/http/responsewrite_test.go
@@ -0,0 +1,290 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "io"
+ "strings"
+ "testing"
+)
+
+type respWriteTest struct {
+ Resp Response
+ Raw string
+}
+
+func TestResponseWrite(t *testing.T) {
+ respWriteTests := []respWriteTest{
+ // HTTP/1.0, identity coding; no trailer
+ {
+ Response{
+ StatusCode: 503,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: 6,
+ },
+
+ "HTTP/1.0 503 Service Unavailable\r\n" +
+ "Content-Length: 6\r\n\r\n" +
+ "abcdef",
+ },
+ // Unchunked response without Content-Length.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ },
+ "HTTP/1.0 200 OK\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and Connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: true,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting Connection: close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\n" +
+ "abcdef",
+ },
+ // HTTP/1.1 response with unknown length and not setting Connection: close, but
+ // setting chunked.
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: -1,
+ TransferEncoding: []string{"chunked"},
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and nil body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: nil,
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\n",
+ },
+ // HTTP/1.1 response 0 content-length, and non-nil non-empty body
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq11("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("foo")),
+ ContentLength: 0,
+ Close: false,
+ },
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "\r\nfoo",
+ },
+ // HTTP/1.1, chunked coding; empty trailer; close
+ {
+ Response{
+ StatusCode: 200,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ ContentLength: 6,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 200 OK\r\n" +
+ "Connection: close\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ "6\r\nabcdef\r\n0\r\n\r\n",
+ },
+
+ // Header value with a newline character (Issue 914).
+ // Also tests removal of leading and trailing whitespace.
+ {
+ Response{
+ StatusCode: 204,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Foo": []string{" Bar\nBaz "},
+ },
+ Body: nil,
+ ContentLength: 0,
+ TransferEncoding: []string{"chunked"},
+ Close: true,
+ },
+
+ "HTTP/1.1 204 No Content\r\n" +
+ "Connection: close\r\n" +
+ "Foo: Bar Baz\r\n" +
+ "\r\n",
+ },
+
+ // Want a single Content-Length header. Fixing issue 8180 where
+ // there were two.
+ {
+ Response{
+ StatusCode: StatusOK,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: &Request{Method: "POST"},
+ Header: Header{},
+ ContentLength: 0,
+ TransferEncoding: nil,
+ Body: nil,
+ },
+ "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n",
+ },
+
+ // When a response to a POST has Content-Length: -1, make sure we don't
+ // write the Content-Length as -1.
+ {
+ Response{
+ StatusCode: StatusOK,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Request: &Request{Method: "POST"},
+ Header: Header{},
+ ContentLength: -1,
+ Body: io.NopCloser(strings.NewReader("abcdef")),
+ },
+ "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nabcdef",
+ },
+
+ // Status code under 100 should be zero-padded to
+ // three digits. Still bogus, but less bogus. (be
+ // consistent with generating three digits, since the
+ // Transport requires it)
+ {
+ Response{
+ StatusCode: 7,
+ Status: "license to violate specs",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: nil,
+ },
+
+ "HTTP/1.0 007 license to violate specs\r\nContent-Length: 0\r\n\r\n",
+ },
+
+ // No stutter. Status code in 1xx range response should
+ // not include a Content-Length header. See issue #16942.
+ {
+ Response{
+ StatusCode: 123,
+ Status: "123 Sesame Street",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: nil,
+ },
+
+ "HTTP/1.0 123 Sesame Street\r\n\r\n",
+ },
+
+ // Status code 204 (No content) response should not include a
+ // Content-Length header. See issue #16942.
+ {
+ Response{
+ StatusCode: 204,
+ Status: "No Content",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: nil,
+ },
+
+ "HTTP/1.0 204 No Content\r\n\r\n",
+ },
+ }
+
+ for i := range respWriteTests {
+ tt := &respWriteTests[i]
+ var braw strings.Builder
+ err := tt.Resp.Write(&braw)
+ if err != nil {
+ t.Errorf("error writing #%d: %s", i, err)
+ continue
+ }
+ sraw := braw.String()
+ if sraw != tt.Raw {
+ t.Errorf("Test %d, expecting:\n%q\nGot:\n%q\n", i, tt.Raw, sraw)
+ continue
+ }
+ }
+}
diff --git a/src/net/http/roundtrip.go b/src/net/http/roundtrip.go
new file mode 100644
index 0000000..49ea1a7
--- /dev/null
+++ b/src/net/http/roundtrip.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js
+
+package http
+
+// RoundTrip implements the RoundTripper interface.
+//
+// For higher-level HTTP client support (such as handling of cookies
+// and redirects), see Get, Post, and the Client type.
+//
+// Like the RoundTripper interface, the error types returned
+// by RoundTrip are unspecified.
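+//
+// A minimal sketch of direct use (example.com is a placeholder URL; callers
+// must close the response body, as with Client.Do):
+//
+//	req, err := NewRequest("GET", "http://example.com/", nil)
+//	if err != nil {
+//		// handle error
+//	}
+//	resp, err := DefaultTransport.RoundTrip(req)
+//	if err != nil {
+//		// handle error
+//	}
+//	defer resp.Body.Close()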
+func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+ return t.roundTrip(req)
+}
diff --git a/src/net/http/roundtrip_js.go b/src/net/http/roundtrip_js.go
new file mode 100644
index 0000000..9f9f0cb
--- /dev/null
+++ b/src/net/http/roundtrip_js.go
@@ -0,0 +1,360 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js && wasm
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "syscall/js"
+)
+
+var uint8Array = js.Global().Get("Uint8Array")
+
+// jsFetchMode is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API mode setting.
+// Valid values are: "cors", "no-cors", "same-origin", "navigate"
+// The default is "same-origin".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchMode = "js.fetch:mode"
+
+// jsFetchCreds is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API credentials setting.
+// Valid values are: "omit", "same-origin", "include"
+// The default is "same-origin".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchCreds = "js.fetch:credentials"
+
+// jsFetchRedirect is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API redirect setting.
+// Valid values are: "follow", "error", "manual"
+// The default is "follow".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchRedirect = "js.fetch:redirect"
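+
+// A minimal sketch of how a caller selects these Fetch options (the URL is a
+// placeholder; the js.fetch:* keys are read and removed from the headers by
+// RoundTrip below before the request is sent):
+//
+//	req, err := NewRequest("GET", "https://example.com/api", nil)
+//	if err != nil {
+//		// handle error
+//	}
+//	req.Header.Set("js.fetch:mode", "cors")
+//	req.Header.Set("js.fetch:credentials", "include")
+//	req.Header.Set("js.fetch:redirect", "follow")
+//	resp, err := DefaultClient.Do(req)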
+
+// jsFetchMissing will be true if the Fetch API is not present in
+// the browser globals.
+var jsFetchMissing = js.Global().Get("fetch").IsUndefined()
+
+// jsFetchDisabled controls whether the use of Fetch API is disabled.
+// It's set to true when we detect we're running in Node.js, so that
+// RoundTrip ends up talking over the same fake network the HTTP servers
+// currently use in various tests and examples. See go.dev/issue/57613.
+//
+// TODO(go.dev/issue/60810): See if it's viable to test the Fetch API
+// code path.
+var jsFetchDisabled = js.Global().Get("process").Type() == js.TypeObject &&
+ strings.HasPrefix(js.Global().Get("process").Get("argv0").String(), "node")
+
+// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API.
+func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+ // The Transport has a documented contract that states that if the DialContext or
+ // DialTLSContext functions are set, they will be used to set up the connections.
+ // If they aren't set then the documented contract is to use Dial or DialTLS, even
+ // though they are deprecated. Therefore, if any of these are set, we should obey
+ // the contract and dial using the regular round-trip instead. Otherwise, we'll try
+ // to fall back on the Fetch API, unless it's not available.
+ if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing || jsFetchDisabled {
+ return t.roundTrip(req)
+ }
+
+ ac := js.Global().Get("AbortController")
+ if !ac.IsUndefined() {
+ // Some browsers that support WASM don't necessarily support
+ // the AbortController. See
+ // https://developer.mozilla.org/en-US/docs/Web/API/AbortController#Browser_compatibility.
+ ac = ac.New()
+ }
+
+ opt := js.Global().Get("Object").New()
+ // See https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch
+ // for options available.
+ opt.Set("method", req.Method)
+ opt.Set("credentials", "same-origin")
+ if h := req.Header.Get(jsFetchCreds); h != "" {
+ opt.Set("credentials", h)
+ req.Header.Del(jsFetchCreds)
+ }
+ if h := req.Header.Get(jsFetchMode); h != "" {
+ opt.Set("mode", h)
+ req.Header.Del(jsFetchMode)
+ }
+ if h := req.Header.Get(jsFetchRedirect); h != "" {
+ opt.Set("redirect", h)
+ req.Header.Del(jsFetchRedirect)
+ }
+ if !ac.IsUndefined() {
+ opt.Set("signal", ac.Get("signal"))
+ }
+ headers := js.Global().Get("Headers").New()
+ for key, values := range req.Header {
+ for _, value := range values {
+ headers.Call("append", key, value)
+ }
+ }
+ opt.Set("headers", headers)
+
+ if req.Body != nil {
+ // TODO(johanbrandhorst): Stream request body when possible.
+ // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue.
+ // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue.
+ // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue.
+ // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API
+ // and browser support.
+ // NOTE(haruyama480): Ensure HTTP/1 fallback exists.
+ // See https://go.dev/issue/61889 for discussion.
+ body, err := io.ReadAll(req.Body)
+ if err != nil {
+ req.Body.Close() // RoundTrip must always close the body, including on errors.
+ return nil, err
+ }
+ req.Body.Close()
+ if len(body) != 0 {
+ buf := uint8Array.New(len(body))
+ js.CopyBytesToJS(buf, body)
+ opt.Set("body", buf)
+ }
+ }
+
+ fetchPromise := js.Global().Call("fetch", req.URL.String(), opt)
+ var (
+ respCh = make(chan *Response, 1)
+ errCh = make(chan error, 1)
+ success, failure js.Func
+ )
+ success = js.FuncOf(func(this js.Value, args []js.Value) any {
+ success.Release()
+ failure.Release()
+
+ result := args[0]
+ header := Header{}
+ // https://developer.mozilla.org/en-US/docs/Web/API/Headers/entries
+ headersIt := result.Get("headers").Call("entries")
+ for {
+ n := headersIt.Call("next")
+ if n.Get("done").Bool() {
+ break
+ }
+ pair := n.Get("value")
+ key, value := pair.Index(0).String(), pair.Index(1).String()
+ ck := CanonicalHeaderKey(key)
+ header[ck] = append(header[ck], value)
+ }
+
+ contentLength := int64(0)
+ clHeader := header.Get("Content-Length")
+ switch {
+ case clHeader != "":
+ cl, err := strconv.ParseInt(clHeader, 10, 64)
+ if err != nil {
+ errCh <- fmt.Errorf("net/http: ill-formed Content-Length header: %v", err)
+ return nil
+ }
+ if cl < 0 {
+ // Content-Length values less than 0 are invalid.
+ // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13
+ errCh <- fmt.Errorf("net/http: invalid Content-Length header: %q", clHeader)
+ return nil
+ }
+ contentLength = cl
+ default:
+ // If the response length is not declared, set it to -1.
+ contentLength = -1
+ }
+
+ b := result.Get("body")
+ var body io.ReadCloser
+ // The body is undefined when the browser does not support streaming response bodies (Firefox),
+ // and null in certain error cases, e.g. when the request is blocked because of CORS settings.
+ if !b.IsUndefined() && !b.IsNull() {
+ body = &streamReader{stream: b.Call("getReader")}
+ } else {
+ // Fall back to using ArrayBuffer
+ // https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer
+ body = &arrayReader{arrayPromise: result.Call("arrayBuffer")}
+ }
+
+ code := result.Get("status").Int()
+ respCh <- &Response{
+ Status: fmt.Sprintf("%d %s", code, StatusText(code)),
+ StatusCode: code,
+ Header: header,
+ ContentLength: contentLength,
+ Body: body,
+ Request: req,
+ }
+
+ return nil
+ })
+ failure = js.FuncOf(func(this js.Value, args []js.Value) any {
+ success.Release()
+ failure.Release()
+
+ err := args[0]
+ // The error is a JS Error type
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Error
+ // We can use the toString() method to get a string representation of the error.
+ errMsg := err.Call("toString").String()
+ // Errors can optionally contain a cause.
+ if cause := err.Get("cause"); !cause.IsUndefined() {
+ // The exact type of the cause is not defined,
+ // but if it's another error, we can call toString() on it too.
+ if !cause.Get("toString").IsUndefined() {
+ errMsg += ": " + cause.Call("toString").String()
+ } else if cause.Type() == js.TypeString {
+ errMsg += ": " + cause.String()
+ }
+ }
+ errCh <- fmt.Errorf("net/http: fetch() failed: %s", errMsg)
+ return nil
+ })
+
+ fetchPromise.Call("then", success, failure)
+ select {
+ case <-req.Context().Done():
+ if !ac.IsUndefined() {
+ // Abort the Fetch request.
+ ac.Call("abort")
+ }
+ return nil, req.Context().Err()
+ case resp := <-respCh:
+ return resp, nil
+ case err := <-errCh:
+ return nil, err
+ }
+}
+
+var errClosed = errors.New("net/http: reader is closed")
+
+// streamReader implements an io.ReadCloser wrapper for ReadableStream.
+// See https://fetch.spec.whatwg.org/#readablestream for more information.
+type streamReader struct {
+ pending []byte
+ stream js.Value
+ err error // sticky read error
+}
+
+func (r *streamReader) Read(p []byte) (n int, err error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+ if len(r.pending) == 0 {
+ var (
+ bCh = make(chan []byte, 1)
+ errCh = make(chan error, 1)
+ )
+ success := js.FuncOf(func(this js.Value, args []js.Value) any {
+ result := args[0]
+ if result.Get("done").Bool() {
+ errCh <- io.EOF
+ return nil
+ }
+ value := make([]byte, result.Get("value").Get("byteLength").Int())
+ js.CopyBytesToGo(value, result.Get("value"))
+ bCh <- value
+ return nil
+ })
+ defer success.Release()
+ failure := js.FuncOf(func(this js.Value, args []js.Value) any {
+ // Assumes it's a TypeError. See
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
+ // for more information on this type. See
+ // https://streams.spec.whatwg.org/#byob-reader-read for the spec on
+ // the read method.
+ errCh <- errors.New(args[0].Get("message").String())
+ return nil
+ })
+ defer failure.Release()
+ r.stream.Call("read").Call("then", success, failure)
+ select {
+ case b := <-bCh:
+ r.pending = b
+ case err := <-errCh:
+ r.err = err
+ return 0, err
+ }
+ }
+ n = copy(p, r.pending)
+ r.pending = r.pending[n:]
+ return n, nil
+}
+
+func (r *streamReader) Close() error {
+ // This ignores any error returned from the cancel method. So far, no concrete situation
+ // has come up where reporting the error would be meaningful, and most callers ignore the
+ // error from resp.Body.Close() anyway. If reporting it becomes useful, it can be
+ // implemented and tested when that need comes up.
+ r.stream.Call("cancel")
+ if r.err == nil {
+ r.err = errClosed
+ }
+ return nil
+}
+
+// arrayReader implements an io.ReadCloser wrapper for ArrayBuffer.
+// https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer.
+type arrayReader struct {
+ arrayPromise js.Value
+ pending []byte
+ read bool
+ err error // sticky read error
+}
+
+func (r *arrayReader) Read(p []byte) (n int, err error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+ if !r.read {
+ r.read = true
+ var (
+ bCh = make(chan []byte, 1)
+ errCh = make(chan error, 1)
+ )
+ success := js.FuncOf(func(this js.Value, args []js.Value) any {
+ // Wrap the input ArrayBuffer with a Uint8Array
+ uint8arrayWrapper := uint8Array.New(args[0])
+ value := make([]byte, uint8arrayWrapper.Get("byteLength").Int())
+ js.CopyBytesToGo(value, uint8arrayWrapper)
+ bCh <- value
+ return nil
+ })
+ defer success.Release()
+ failure := js.FuncOf(func(this js.Value, args []js.Value) any {
+ // Assumes it's a TypeError. See
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
+ // for more information on this type.
+ // See https://fetch.spec.whatwg.org/#concept-body-consume-body for reasons this might error.
+ errCh <- errors.New(args[0].Get("message").String())
+ return nil
+ })
+ defer failure.Release()
+ r.arrayPromise.Call("then", success, failure)
+ select {
+ case b := <-bCh:
+ r.pending = b
+ case err := <-errCh:
+ return 0, err
+ }
+ }
+ if len(r.pending) == 0 {
+ return 0, io.EOF
+ }
+ n = copy(p, r.pending)
+ r.pending = r.pending[n:]
+ return n, nil
+}
+
+func (r *arrayReader) Close() error {
+ if r.err == nil {
+ r.err = errClosed
+ }
+ return nil
+}
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
new file mode 100644
index 0000000..bb380cf
--- /dev/null
+++ b/src/net/http/serve_test.go
@@ -0,0 +1,6870 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// End-to-end serving tests
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "compress/zlib"
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "log"
+ "math/rand"
+ "mime/multipart"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/httputil"
+ "net/http/internal"
+ "net/http/internal/testcert"
+ "net/url"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "reflect"
+ "regexp"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "testing"
+ "time"
+)
+
+type dummyAddr string
+type oneConnListener struct {
+ conn net.Conn
+}
+
+func (l *oneConnListener) Accept() (c net.Conn, err error) {
+ c = l.conn
+ if c == nil {
+ err = io.EOF
+ return
+ }
+ err = nil
+ l.conn = nil
+ return
+}
+
+func (l *oneConnListener) Close() error {
+ return nil
+}
+
+func (l *oneConnListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func (a dummyAddr) Network() string {
+ return string(a)
+}
+
+func (a dummyAddr) String() string {
+ return string(a)
+}
+
+type noopConn struct{}
+
+func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
+func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
+func (noopConn) SetDeadline(t time.Time) error { return nil }
+func (noopConn) SetReadDeadline(t time.Time) error { return nil }
+func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type rwTestConn struct {
+ io.Reader
+ io.Writer
+ noopConn
+
+ closeFunc func() error // called if non-nil
+ closec chan bool // else, if non-nil, send value to it on close
+}
+
+func (c *rwTestConn) Close() error {
+ if c.closeFunc != nil {
+ return c.closeFunc()
+ }
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+type testConn struct {
+ readMu sync.Mutex // for TestHandlerBodyClose
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+ closec chan bool // if non-nil, send value to it on close
+ noopConn
+}
+
+func (c *testConn) Read(b []byte) (int, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
+ return c.readBuf.Read(b)
+}
+
+func (c *testConn) Write(b []byte) (int, error) {
+ return c.writeBuf.Write(b)
+}
+
+func (c *testConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+// reqBytes treats req as a request (with \n delimiters) and returns it with \r\n delimiters,
+// ending in \r\n\r\n
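+//
+// For example (a sketch of the transformation):
+//
+//	reqBytes("GET / HTTP/1.1\nHost: test")
+//	// => []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n")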
+func reqBytes(req string) []byte {
+ return []byte(strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n") + "\r\n\r\n")
+}
+
+type handlerTest struct {
+ logbuf bytes.Buffer
+ handler Handler
+}
+
+func newHandlerTest(h Handler) handlerTest {
+ return handlerTest{handler: h}
+}
+
+func (ht *handlerTest) rawResponse(req string) string {
+ reqb := reqBytes(req)
+ var output strings.Builder
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(reqb),
+ Writer: &output,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ srv := &Server{
+ ErrorLog: log.New(&ht.logbuf, "", 0),
+ Handler: ht.handler,
+ }
+ go srv.Serve(ln)
+ <-conn.closec
+ return output.String()
+}
+
+func TestConsumingBodyOnNextConn(t *testing.T) {
+ t.Parallel()
+ defer afterTest(t)
+ conn := new(testConn)
+ for i := 0; i < 2; i++ {
+ conn.readBuf.Write([]byte(
+ "POST / HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 11\r\n" +
+ "\r\n" +
+ "foo=1&bar=1"))
+ }
+
+ reqNum := 0
+ ch := make(chan *Request)
+ servech := make(chan error)
+ listener := &oneConnListener{conn}
+ handler := func(res ResponseWriter, req *Request) {
+ reqNum++
+ ch <- req
+ }
+
+ go func() {
+ servech <- Serve(listener, HandlerFunc(handler))
+ }()
+
+ var req *Request
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #1's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ req = <-ch
+ if req == nil {
+ t.Fatal("Got nil first request.")
+ }
+ if req.Method != "POST" {
+ t.Errorf("For request #2's method, got %q; expected %q",
+ req.Method, "POST")
+ }
+
+ if serveerr := <-servech; serveerr != io.EOF {
+ t.Errorf("Serve returned %q; expected EOF", serveerr)
+ }
+}
+
+type stringHandler string
+
+func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Result", string(s))
+}
+
+var handlers = []struct {
+ pattern string
+ msg string
+}{
+ {"/", "Default"},
+ {"/someDir/", "someDir"},
+ {"/#/", "hash"},
+ {"someHost.com/someDir/", "someHost.com/someDir"},
+}
+
+var vtests = []struct {
+ url string
+ expected string
+}{
+ {"http://localhost/someDir/apage", "someDir"},
+ {"http://localhost/%23/apage", "hash"},
+ {"http://localhost/otherDir/apage", "Default"},
+ {"http://someHost.com/someDir/apage", "someHost.com/someDir"},
+ {"http://otherHost.com/someDir/apage", "someDir"},
+ {"http://otherHost.com/aDir/apage", "Default"},
+ // redirections for trees
+ {"http://localhost/someDir", "/someDir/"},
+ {"http://localhost/%23", "/%23/"},
+ {"http://someHost.com/someDir", "/someDir/"},
+}
+
+func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
+func testHostHandlers(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ for _, h := range handlers {
+ mux.Handle(h.pattern, stringHandler(h.msg))
+ }
+ ts := newClientServerTest(t, mode, mux).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ cc := httputil.NewClientConn(conn, nil)
+ for _, vt := range vtests {
+ var r *Response
+ var req Request
+ if req.URL, err = url.Parse(vt.url); err != nil {
+ t.Errorf("cannot parse url: %v", err)
+ continue
+ }
+ if err := cc.Write(&req); err != nil {
+ t.Errorf("writing request: %v", err)
+ continue
+ }
+ r, err := cc.Read(&req)
+ if err != nil {
+ t.Errorf("reading response: %v", err)
+ continue
+ }
+ switch r.StatusCode {
+ case StatusOK:
+ s := r.Header.Get("Result")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ case StatusMovedPermanently:
+ s := r.Header.Get("Location")
+ if s != vt.expected {
+ t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
+ }
+ default:
+ t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
+ }
+ }
+}
+
+var serveMuxRegister = []struct {
+ pattern string
+ h Handler
+}{
+ {"/dir/", serve(200)},
+ {"/search", serve(201)},
+ {"codesearch.google.com/search", serve(202)},
+ {"codesearch.google.com/", serve(203)},
+ {"example.com/", HandlerFunc(checkQueryStringHandler)},
+}
+
+// serve returns a handler that sends a response with the given code.
+func serve(code int) HandlerFunc {
+ return func(w ResponseWriter, r *Request) {
+ w.WriteHeader(code)
+ }
+}
+
+// checkQueryStringHandler checks if r.URL.RawQuery has the same value
+// as the URL excluding the scheme and the query string and sends 200
+// response code if it is, 500 otherwise.
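+// For example, a request for "http://example.com/test/?example.com/test/" has
+// RawQuery "example.com/test/" and reconstructs to "http://example.com/test/",
+// so the handler responds 200 (see serveMuxTests2 below).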
+func checkQueryStringHandler(w ResponseWriter, r *Request) {
+ u := *r.URL
+ u.Scheme = "http"
+ u.Host = r.Host
+ u.RawQuery = ""
+ if "http://"+r.URL.RawQuery == u.String() {
+ w.WriteHeader(200)
+ } else {
+ w.WriteHeader(500)
+ }
+}
+
+var serveMuxTests = []struct {
+ method string
+ host string
+ path string
+ code int
+ pattern string
+}{
+ {"GET", "google.com", "/", 404, ""},
+ {"GET", "google.com", "/dir", 301, "/dir/"},
+ {"GET", "google.com", "/dir/", 200, "/dir/"},
+ {"GET", "google.com", "/dir/file", 200, "/dir/"},
+ {"GET", "google.com", "/search", 201, "/search"},
+ {"GET", "google.com", "/search/", 404, ""},
+ {"GET", "google.com", "/search/foo", 404, ""},
+ {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
+ {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
+ {"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
+ {"GET", "images.google.com", "/search", 201, "/search"},
+ {"GET", "images.google.com", "/search/", 404, ""},
+ {"GET", "images.google.com", "/search/foo", 404, ""},
+ {"GET", "google.com", "/../search", 301, "/search"},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/..", 301, ""},
+ {"GET", "google.com", "/dir/./file", 301, "/dir/"},
+
+ // The /foo -> /foo/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ {"CONNECT", "google.com", "/dir", 301, "/dir/"},
+ {"CONNECT", "google.com", "/../search", 404, ""},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
+ {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
+}
+
+func TestServeMuxHandler(t *testing.T) {
+ setParallel(t)
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests {
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: &url.URL{
+ Path: tt.path,
+ },
+ }
+ h, pattern := mux.Handler(r)
+ rr := httptest.NewRecorder()
+ h.ServeHTTP(rr, r)
+ if pattern != tt.pattern || rr.Code != tt.code {
+ t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
+ }
+ }
+}
+
+// Issue 24297
+func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
+ setParallel(t)
+ defer func() {
+ if err := recover(); err == nil {
+ t.Error("expected call to mux.HandleFunc to panic")
+ }
+ }()
+ mux := NewServeMux()
+ mux.HandleFunc("/", nil)
+}
+
+var serveMuxTests2 = []struct {
+ method string
+ host string
+ url string
+ code int
+ redirOk bool
+}{
+ {"GET", "google.com", "/", 404, false},
+ {"GET", "example.com", "/test/?example.com/test/", 200, false},
+ {"GET", "example.com", "test/?example.com/test/", 200, true},
+}
+
+// TestServeMuxHandlerRedirects tests that automatic redirects generated by
+// mux.Handler() don't clear the request's query string.
+func TestServeMuxHandlerRedirects(t *testing.T) {
+ setParallel(t)
+ mux := NewServeMux()
+ for _, e := range serveMuxRegister {
+ mux.Handle(e.pattern, e.h)
+ }
+
+ for _, tt := range serveMuxTests2 {
+ tries := 1 // expect at most 1 redirection if redirOk is true.
+ turl := tt.url
+ for {
+ u, e := url.Parse(turl)
+ if e != nil {
+ t.Fatal(e)
+ }
+ r := &Request{
+ Method: tt.method,
+ Host: tt.host,
+ URL: u,
+ }
+ h, _ := mux.Handler(r)
+ rr := httptest.NewRecorder()
+ h.ServeHTTP(rr, r)
+ if rr.Code != 301 {
+ if rr.Code != tt.code {
+ t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
+ }
+ break
+ }
+ if !tt.redirOk {
+ t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
+ break
+ }
+ turl = rr.HeaderMap.Get("Location")
+ tries--
+ }
+ if tries < 0 {
+ t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
+ }
+ }
+}
+
+// Tests for https://golang.org/issue/900
+func TestMuxRedirectLeadingSlashes(t *testing.T) {
+ setParallel(t)
+ paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
+ for _, path := range paths {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
+ if err != nil {
+ t.Errorf("%s", err)
+ }
+ mux := NewServeMux()
+ resp := httptest.NewRecorder()
+
+ mux.ServeHTTP(resp, req)
+
+ if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
+ t.Errorf("Expected Location header set to %q; got %q", expected, loc)
+ return
+ }
+
+ if code, expected := resp.Code, StatusMovedPermanently; code != expected {
+ t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
+ return
+ }
+ }
+}
+
+// Test that the special cased "/route" redirect
+// implicitly created by a registered "/route/"
+// properly sets the query string in the redirect URL.
+// See Issue 17841.
+func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
+ run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
+}
+func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
+ writeBackQuery := func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%s", r.URL.RawQuery)
+ }
+
+ mux := NewServeMux()
+ mux.HandleFunc("/testOne", writeBackQuery)
+ mux.HandleFunc("/testTwo/", writeBackQuery)
+ mux.HandleFunc("/testThree", writeBackQuery)
+ mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
+ })
+
+ ts := newClientServerTest(t, mode, mux).ts
+
+ tests := [...]struct {
+ path string
+ method string
+ want string
+ statusOk bool
+ }{
+ 0: {"/testOne?this=that", "GET", "this=that", true},
+ 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
+ 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
+ 3: {"/testTwo?", "GET", "", true},
+ 4: {"/testThree?foo", "GET", "foo", true},
+ 5: {"/testThree/?foo", "GET", "foo:bar", true},
+ 6: {"/testThree?foo", "CONNECT", "foo", true},
+ 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
+
+ // canonicalization or not
+ 8: {"/testOne/foo/..?foo", "GET", "foo", true},
+ 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
+ }
+
+ for i, tt := range tests {
+ req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
+ res, err := ts.Client().Do(req)
+ if err != nil {
+ continue
+ }
+ slurp, _ := io.ReadAll(res.Body)
+ res.Body.Close()
+ if !tt.statusOk {
+ if got, want := res.StatusCode, 404; got != want {
+ t.Errorf("#%d: Status = %d; want = %d", i, got, want)
+ }
+ }
+ if got, want := string(slurp), tt.want; got != want {
+ t.Errorf("#%d: Body = %q; want = %q", i, got, want)
+ }
+ }
+}
+
+func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
+ setParallel(t)
+
+ mux := NewServeMux()
+ mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
+ mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
+ mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
+ mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
+ mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
+ mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
+
+ tests := []struct {
+ method string
+ url string
+ code int
+ loc string
+ want string
+ }{
+ {"GET", "http://example.com/", 404, "", ""},
+ {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
+ {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
+ {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
+ {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
+ {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
+ {"CONNECT", "http://example.com/", 404, "", ""},
+ {"CONNECT", "http://example.com:3000/", 404, "", ""},
+ {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
+ {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
+ {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
+ {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
+ {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
+ }
+
+ for i, tt := range tests {
+ req, _ := NewRequest(tt.method, tt.url, nil)
+ w := httptest.NewRecorder()
+ mux.ServeHTTP(w, req)
+
+ if got, want := w.Code, tt.code; got != want {
+ t.Errorf("#%d: Status = %d; want = %d", i, got, want)
+ }
+
+ if tt.code == 301 {
+ if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
+ t.Errorf("#%d: Location = %q; want = %q", i, got, want)
+ }
+ } else {
+ if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
+ t.Errorf("#%d: Result = %q; want = %q", i, got, want)
+ }
+ }
+ }
+}
+
+func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
+func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ newClientServerTest(t, mode, mux)
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
+}
+
+func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
+func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
+func benchmarkServeMux(b *testing.B, runHandler bool) {
+ type test struct {
+ path string
+ code int
+ req *Request
+ }
+
+ // Build example handlers and requests
+ var tests []test
+ endpoints := []string{"search", "dir", "file", "change", "count", "s"}
+ for _, e := range endpoints {
+ for i := 200; i < 230; i++ {
+ p := fmt.Sprintf("/%s/%d/", e, i)
+ tests = append(tests, test{
+ path: p,
+ code: i,
+ req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
+ })
+ }
+ }
+ mux := NewServeMux()
+ for _, tt := range tests {
+ mux.Handle(tt.path, serve(tt.code))
+ }
+
+ rw := httptest.NewRecorder()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for _, tt := range tests {
+ *rw = httptest.ResponseRecorder{}
+ h, pattern := mux.Handler(tt.req)
+ if runHandler {
+ h.ServeHTTP(rw, tt.req)
+ if pattern != tt.path || rw.Code != tt.code {
+ b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
+ }
+ }
+ }
+ }
+}
+
+func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
+func testServerTimeouts(t *testing.T, mode testMode) {
+ // Try three times, with increasing timeouts.
+ tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
+ for i, timeout := range tries {
+ err := testServerTimeoutsWithTimeout(t, timeout, mode)
+ if err == nil {
+ return
+ }
+ t.Logf("failed at %v: %v", timeout, err)
+ if i != len(tries)-1 {
+ t.Logf("retrying at %v ...", tries[i+1])
+ }
+ }
+ t.Fatal("all attempts failed")
+}
+
+func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
+ reqNum := 0
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ reqNum++
+ fmt.Fprintf(res, "req=%d", reqNum)
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = timeout
+ ts.Config.WriteTimeout = timeout
+ }).ts
+
+ // Hit the HTTP server successfully.
+ c := ts.Client()
+ r, err := c.Get(ts.URL)
+ if err != nil {
+ return fmt.Errorf("http Get #1: %v", err)
+ }
+ got, err := io.ReadAll(r.Body)
+ expected := "req=1"
+ if string(got) != expected || err != nil {
+ return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
+ string(got), err, expected)
+ }
+
+ // Slow client that should timeout.
+ t1 := time.Now()
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ return fmt.Errorf("Dial: %v", err)
+ }
+ buf := make([]byte, 1)
+ n, err := conn.Read(buf)
+ conn.Close()
+ latency := time.Since(t1)
+ if n != 0 || err != io.EOF {
+ return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
+ }
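+ // The EOF should not arrive much before the configured timeout; allow 1/5 slack.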
+ minLatency := timeout / 5 * 4
+ if latency < minLatency {
+ return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
+ }
+
+ // Hit the HTTP server successfully again, verifying that the
+ // previous slow connection didn't run our handler. (that we
+ // get "req=2", not "req=3")
+ r, err = c.Get(ts.URL)
+ if err != nil {
+ return fmt.Errorf("http Get #2: %v", err)
+ }
+ got, err = io.ReadAll(r.Body)
+ r.Body.Close()
+ expected = "req=2"
+ if string(got) != expected || err != nil {
+ return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
+ }
+
+ if !testing.Short() {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ return fmt.Errorf("long Dial: %v", err)
+ }
+ defer conn.Close()
+ go io.Copy(io.Discard, conn)
+ for i := 0; i < 5; i++ {
+ _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ return fmt.Errorf("on write %d: %v", i, err)
+ }
+ time.Sleep(timeout / 2)
+ }
+ }
+ return nil
+}
+
+func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
+func testServerReadTimeout(t *testing.T, mode testMode) {
+ respBody := "response body"
+ for timeout := 5 * time.Millisecond; ; timeout *= 2 {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ _, err := io.Copy(io.Discard, req.Body)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
+ }
+ res.Write([]byte(respBody))
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = -1 // don't time out while reading headers
+ ts.Config.ReadTimeout = timeout
+ })
+ pr, pw := io.Pipe()
+ res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
+ if err != nil {
+ t.Logf("Get error, retrying: %v", err)
+ cst.close()
+ continue
+ }
+ defer res.Body.Close()
+ got, err := io.ReadAll(res.Body)
+ if string(got) != respBody || err != nil {
+ t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
+ }
+ pw.Close()
+ break
+ }
+}
+
+func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
+func testServerWriteTimeout(t *testing.T, mode testMode) {
+ for timeout := 5 * time.Millisecond; ; timeout *= 2 {
+ errc := make(chan error, 2)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ errc <- nil
+ _, err := io.Copy(res, neverEnding('a'))
+ errc <- err
+ }), func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = timeout
+ })
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ // Probably caused by the write timeout expiring before the handler runs.
+ t.Logf("Get error, retrying: %v", err)
+ cst.close()
+ continue
+ }
+ defer res.Body.Close()
+ _, err = io.Copy(io.Discard, res.Body)
+ if err == nil {
+ t.Errorf("client reading from truncated request body: got nil error, want non-nil")
+ }
+ select {
+ case <-errc:
+ err = <-errc // io.Copy error
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
+ }
+ return
+ default:
+ // The write timeout expired before the handler started.
+ t.Logf("handler didn't run, retrying")
+ cst.close()
+ }
+ }
+}
+
+// Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437)
+func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
+ run(t, testWriteDeadlineExtendedOnNewRequest)
+}
+func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
+ func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ },
+ ).ts
+
+ c := ts.Client()
+
+ for i := 1; i <= 3; i++ {
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("http2 Get #%d: %v", i, err)
+ }
+ r.Body.Close()
+ time.Sleep(ts.Config.WriteTimeout / 2)
+ }
+}
+
+// tryTimeouts runs testFunc with increasing timeouts. Test passes on first success,
+// and fails if all timeouts fail.
+func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
+ tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
+ for i, timeout := range tries {
+ err := testFunc(timeout)
+ if err == nil {
+ return
+ }
+ t.Logf("failed at %v: %v", timeout, err)
+ if i != len(tries)-1 {
+ t.Logf("retrying at %v ...", tries[i+1])
+ }
+ }
+ t.Fatal("all attempts failed")
+}
+
+// Test that the HTTP/2 server RSTs stream on slow write.
+func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ setParallel(t)
+ run(t, func(t *testing.T, mode testMode) {
+ tryTimeouts(t, func(timeout time.Duration) error {
+ return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
+ })
+ })
+}
+
+func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
+ reqNum := 0
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ reqNum++
+ if reqNum == 1 {
+ return // first request succeeds
+ }
+ time.Sleep(timeout) // second request times out
+ }), func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = timeout / 2
+ }).ts
+
+ c := ts.Client()
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ return fmt.Errorf("NewRequest: %v", err)
+ }
+ r, err := c.Do(req)
+ if err != nil {
+ return fmt.Errorf("Get #1: %v", err)
+ }
+ r.Body.Close()
+
+ req, err = NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ return fmt.Errorf("NewRequest: %v", err)
+ }
+ r, err = c.Do(req)
+ if err == nil {
+ r.Body.Close()
+ return fmt.Errorf("Get #2 expected error, got nil")
+ }
+ if mode == http2Mode {
+ expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3
+ if !strings.Contains(err.Error(), expected) {
+ return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
+ }
+ }
+ return nil
+}
+
+// Test that the HTTP/2 server does not send RST when WriteDeadline not set.
+func TestNoWriteDeadline(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ setParallel(t)
+ defer afterTest(t)
+ run(t, func(t *testing.T, mode testMode) {
+ tryTimeouts(t, func(timeout time.Duration) error {
+ return testNoWriteDeadline(t, mode, timeout)
+ })
+ })
+}
+
+func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
+ reqNum := 0
+ ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
+ reqNum++
+ if reqNum == 1 {
+ return // first request succeeds
+ }
+ time.Sleep(timeout) // second request times out
+ })).ts
+
+ c := ts.Client()
+
+ for i := 0; i < 2; i++ {
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ return fmt.Errorf("NewRequest: %v", err)
+ }
+ r, err := c.Do(req)
+ if err != nil {
+ return fmt.Errorf("Get #%d: %v", i, err)
+ }
+ r.Body.Close()
+ }
+ return nil
+}
+
+// golang.org/issue/4741 -- setting only a write timeout that triggers
+// shouldn't cause a handler to block forever on reads (of the next HTTP
+// request) that will never happen.
+func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
+func testOnlyWriteTimeout(t *testing.T, mode testMode) {
+ var (
+ mu sync.RWMutex
+ conn net.Conn
+ )
+ var afterTimeoutErrc = make(chan error, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+ buf := make([]byte, 512<<10)
+ _, err := w.Write(buf)
+ if err != nil {
+ t.Errorf("handler Write error: %v", err)
+ return
+ }
+ mu.RLock()
+ defer mu.RUnlock()
+ if conn == nil {
+ t.Error("no established connection found")
+ return
+ }
+ conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
+ _, err = w.Write(buf)
+ afterTimeoutErrc <- err
+ }), func(ts *httptest.Server) {
+ ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
+ }).ts
+
+ c := ts.Client()
+
+ err := func() error {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ return err
+ }
+ _, err = io.Copy(io.Discard, res.Body)
+ res.Body.Close()
+ return err
+ }()
+ if err == nil {
+ t.Errorf("expected an error copying body from Get request")
+ }
+
+ if err := <-afterTimeoutErrc; err == nil {
+ t.Error("expected write error after timeout")
+ }
+}
+
+// trackLastConnListener tracks the last net.Conn that was accepted.
+type trackLastConnListener struct {
+ net.Listener
+
+ mu *sync.RWMutex
+ last *net.Conn // destination
+}
+
+func (l trackLastConnListener) Accept() (c net.Conn, err error) {
+ c, err = l.Listener.Accept()
+ if err == nil {
+ l.mu.Lock()
+ *l.last = c
+ l.mu.Unlock()
+ }
+ return
+}
+
+// TestIdentityResponse verifies that a handler can unset chunked encoding by
+// setting an explicit Content-Length (and an identity Transfer-Encoding).
+func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
+func testIdentityResponse(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56019")
+ }
+
+ handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Header().Set("Content-Length", "3")
+ rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
+ switch {
+ case req.FormValue("overwrite") == "1":
+ _, err := rw.Write([]byte("foo TOO LONG"))
+ if err != ErrContentLength {
+ t.Errorf("expected ErrContentLength; got %v", err)
+ }
+ case req.FormValue("underwrite") == "1":
+ rw.Header().Set("Content-Length", "500")
+ rw.Write([]byte("too short"))
+ default:
+ rw.Write([]byte("foo"))
+ }
+ })
+
+ ts := newClientServerTest(t, mode, handler).ts
+ c := ts.Client()
+
+ // Note: this relies on the assumption (which is true) that
+ // Get sends HTTP/1.1 or greater requests. Otherwise the
+ // server wouldn't have the choice to send back chunked
+ // responses.
+ for _, te := range []string{"", "identity"} {
+ url := ts.URL + "/?te=" + te
+ res, err := c.Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ if cl, expected := res.ContentLength, int64(3); cl != expected {
+ t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
+ }
+ if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
+ t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
+ }
+ if tl, expected := len(res.TransferEncoding), 0; tl != expected {
+ t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
+ url, expected, tl, res.TransferEncoding)
+ }
+ res.Body.Close()
+ }
+
+ // Verify that ErrContentLength is returned
+ url := ts.URL + "/?overwrite=1"
+ res, err := c.Get(url)
+ if err != nil {
+ t.Fatalf("error with Get of %s: %v", url, err)
+ }
+ res.Body.Close()
+
+ if mode != http1Mode {
+ return
+ }
+
+ // Verify that the connection is closed when the declared Content-Length
+ // is larger than what the handler wrote.
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+
+ // The ReadAll will hang for a failing test.
+ got, _ := io.ReadAll(conn)
+ expectedSuffix := "\r\n\r\ntoo short"
+ if !strings.HasSuffix(string(got), expectedSuffix) {
+ t.Errorf("Expected output to end with %q; got response body %q",
+ expectedSuffix, string(got))
+ }
+}
+
+func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
+ setParallel(t)
+ s := newClientServerTest(t, http1Mode, h).ts
+
+ conn, err := net.Dial("tcp", s.Listener.Addr().String())
+ if err != nil {
+ t.Fatal("dial error:", err)
+ }
+ defer conn.Close()
+
+ _, err = fmt.Fprint(conn, req)
+ if err != nil {
+ t.Fatal("print error:", err)
+ }
+
+ r := bufio.NewReader(conn)
+ res, err := ReadResponse(r, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse error:", err)
+ }
+
+ _, err = io.ReadAll(r)
+ if err != nil {
+ t.Fatal("read error:", err)
+ }
+
+ if !res.Close {
+ t.Errorf("Response.Close = false; want true")
+ }
+}
+
+func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
+ setParallel(t)
+ ts := newClientServerTest(t, http1Mode, handler).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ br := bufio.NewReader(conn)
+ for i := 0; i < 2; i++ {
+ if _, err := io.WriteString(conn, req); err != nil {
+ t.Fatal(err)
+ }
+ res, err := ReadResponse(br, nil)
+ if err != nil {
+ t.Fatalf("res %d: %v", i+1, err)
+ }
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatalf("res %d body copy: %v", i+1, err)
+ }
+ res.Body.Close()
+ }
+}
+
+// TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive.
+func TestServeHTTP10Close(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ ServeFile(w, r, "testdata/file")
+ }))
+}
+
+// TestClientCanClose verifies that clients can also force a connection to close.
+func TestClientCanClose(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing.
+ }))
+}
+
+// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
+// even for HTTP/1.1 requests.
+func TestHandlersCanSetConnectionClose11(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestHandlersCanSetConnectionClose10(t *testing.T) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ }))
+}
+
+func TestHTTP2UpgradeClosesConnection(t *testing.T) {
+ testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Nothing. (if not hijacked, the server should close the connection
+ // afterwards)
+ }))
+}
+
+func send204(w ResponseWriter, r *Request) { w.WriteHeader(204) }
+func send304(w ResponseWriter, r *Request) { w.WriteHeader(304) }
+
+// Issue 15647: 204 responses can't have bodies, so HTTP/1.0 keep-alive conns should stay open.
+func TestHTTP10KeepAlive204Response(t *testing.T) {
+ testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
+}
+
+func TestHTTP11KeepAlive204Response(t *testing.T) {
+ testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
+}
+
+func TestHTTP10KeepAlive304Response(t *testing.T) {
+ testTCPConnectionStaysOpen(t,
+ "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
+ HandlerFunc(send304))
+}
+
+// Issue 15703
+func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
+func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush() // force chunked encoding
+ w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
+ }))
+ type data struct {
+ Addr string
+ }
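+ // Make two requests and record the RemoteAddr the handler reported for
+ // each; if the connection is reused, both requests arrive from the same
+ // client address and port.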
+ var addrs [2]data
+ for i := range addrs {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
+ t.Fatal(err)
+ }
+ if addrs[i].Addr == "" {
+ t.Fatal("no address")
+ }
+ res.Body.Close()
+ }
+ if addrs[0] != addrs[1] {
+ t.Fatalf("connection not reused")
+ }
+}
+
+func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
+func testSetsRemoteAddr(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%s", r.RemoteAddr)
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("ReadAll error: %v", err)
+ }
+ ip := string(body)
+ if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
+ t.Fatalf("Expected local addr; got %q", ip)
+ }
+}
+
+type blockingRemoteAddrListener struct {
+ net.Listener
+ conns chan<- net.Conn
+}
+
+func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
+ c, err := l.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+ brac := &blockingRemoteAddrConn{
+ Conn: c,
+ addrs: make(chan net.Addr, 1),
+ }
+ l.conns <- brac
+ return brac, nil
+}
+
+type blockingRemoteAddrConn struct {
+ net.Conn
+ addrs chan net.Addr
+}
+
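+// RemoteAddr blocks until the test supplies an address on c.addrs,
+// simulating a RemoteAddr implementation that is slow to respond.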
+func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
+ return <-c.addrs
+}
+
+// Issue 12943
+func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
+ run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
+}
+func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
+ conns := make(chan net.Conn)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
+ }), func(ts *httptest.Server) {
+ ts.Listener = &blockingRemoteAddrListener{
+ Listener: ts.Listener,
+ conns: conns,
+ }
+ }).ts
+
+ c := ts.Client()
+ // Force separate connection for each:
+ c.Transport.(*Transport).DisableKeepAlives = true
+
+ fetch := func(num int, response chan<- string) {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("Request %d: %v", num, err)
+ response <- ""
+ return
+ }
+ defer resp.Body.Close()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("Request %d: %v", num, err)
+ response <- ""
+ return
+ }
+ response <- string(body)
+ }
+
+ // Start a request. The server will block on getting conn.RemoteAddr.
+ response1c := make(chan string, 1)
+ go fetch(1, response1c)
+
+ // Wait for the server to accept it; grab the connection.
+ conn1 := <-conns
+
+ // Start another request and grab its connection
+ response2c := make(chan string, 1)
+ go fetch(2, response2c)
+ conn2 := <-conns
+
+ // Send a response on connection 2.
+ conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
+ IP: net.ParseIP("12.12.12.12"), Port: 12}
+
+ // ... and see it
+ response2 := <-response2c
+ if g, e := response2, "RA:12.12.12.12:12"; g != e {
+ t.Fatalf("response 2 addr = %q; want %q", g, e)
+ }
+
+ // Finish the first response.
+ conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
+ IP: net.ParseIP("21.21.21.21"), Port: 21}
+
+ // ... and see it
+ response1 := <-response1c
+ if g, e := response1, "RA:21.21.21.21:21"; g != e {
+ t.Fatalf("response 1 addr = %q; want %q", g, e)
+ }
+}
+
+// TestHeadResponses verifies that all MIME type sniffing and Content-Length
+// counting of GET requests also happens on HEAD requests.
+func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
+func testHeadResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("<html>"))
+ if err != nil {
+ t.Errorf("ResponseWriter.Write: %v", err)
+ }
+
+ // Also exercise the ReaderFrom path
+ _, err = io.Copy(w, strings.NewReader("789a"))
+ if err != nil {
+ t.Errorf("Copy(ResponseWriter, ...): %v", err)
+ }
+ }))
+ res, err := cst.c.Head(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
+ t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
+ }
+ if v := res.ContentLength; v != 10 {
+ t.Errorf("Content-Length: %d; want 10", v)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestTLSHandshakeTimeout(t *testing.T) {
+ run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
+}
+func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
+ errc := make(chanWriter, 10) // but only expecting 1
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
+ func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = 250 * time.Millisecond
+ ts.Config.ErrorLog = log.New(errc, "", 0)
+ },
+ ).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+
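+ // We never send a ClientHello, so the server's handshake read should hit
+ // the 250ms ReadTimeout and the connection should be closed without any
+ // bytes being written back.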
+ var buf [1]byte
+ n, err := conn.Read(buf[:])
+ if err == nil || n != 0 {
+ t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
+ }
+
+ v := <-errc
+ if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
+ t.Errorf("expected a TLS handshake timeout error; got %q", v)
+ }
+}
+
+func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
+func testTLSServer(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.TLS != nil {
+ w.Header().Set("X-TLS-Set", "true")
+ if r.TLS.HandshakeComplete {
+ w.Header().Set("X-TLS-HandshakeComplete", "true")
+ }
+ }
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(io.Discard, "", 0)
+ }).ts
+
+ // Connect an idle TCP connection to this server before we run
+ // our real tests. This idle connection used to block forever
+ // in the TLS handshake, preventing future connections from
+ // being accepted. Keeping it here may catch any future accidental
+ // blocking in newConn.
+ idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer idleConn.Close()
+
+ if !strings.HasPrefix(ts.URL, "https://") {
+ t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
+ return
+ }
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if res == nil {
+ t.Errorf("got nil Response")
+ return
+ }
+ defer res.Body.Close()
+ if res.Header.Get("X-TLS-Set") != "true" {
+ t.Errorf("expected X-TLS-Set response header")
+ return
+ }
+ if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
+ t.Errorf("expected X-TLS-HandshakeComplete header")
+ }
+}
+
+func TestServeTLS(t *testing.T) {
+ CondSkipHTTP2(t)
+ // Not parallel: uses global test hooks.
+ defer afterTest(t)
+ defer SetTestHookServerServe(nil)
+
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tlsConf := &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ }
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ addr := ln.Addr().String()
+
+ serving := make(chan bool, 1)
+ SetTestHookServerServe(func(s *Server, ln net.Listener) {
+ serving <- true
+ })
+ handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
+ s := &Server{
+ Addr: addr,
+ TLSConfig: tlsConf,
+ Handler: handler,
+ }
+ errc := make(chan error, 1)
+ go func() { errc <- s.ServeTLS(ln, "", "") }()
+ select {
+ case err := <-errc:
+ t.Fatalf("ServeTLS: %v", err)
+ case <-serving:
+ }
+
+ c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"h2", "http/1.1"},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
+ t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
+ }
+ if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
+ t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
+ }
+}
+
+// Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests.
+func TestTLSServerRejectHTTPRequests(t *testing.T) {
+ run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
+}
+func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Error("unexpected HTTPS request")
+ }), func(ts *httptest.Server) {
+ var errBuf bytes.Buffer
+ ts.Config.ErrorLog = log.New(&errBuf, "", 0)
+ }).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
+ slurp, err := io.ReadAll(conn)
+ if err != nil {
+ t.Fatal(err)
+ }
+ const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
+ if !strings.HasPrefix(string(slurp), wantPrefix) {
+ t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
+ }
+}
+
+// Issue 15908
+func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
+ testAutomaticHTTP2_Serve(t, nil, true)
+}
+
+func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
+ testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
+}
+
+func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
+ testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
+}
+
+func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
+ setParallel(t)
+ defer afterTest(t)
+ ln := newLocalListener(t)
+ ln.Close() // immediately (not a defer!)
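+ // Serve returns an error immediately on the closed listener, but by then
+ // it has already decided whether to enable HTTP/2, which is what the
+ // TLSNextProto check below inspects.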
+ var s Server
+ s.TLSConfig = tlsConf
+ if err := s.Serve(ln); err == nil {
+ t.Fatal("expected an error")
+ }
+ gotH2 := s.TLSNextProto["h2"] != nil
+ if gotH2 != wantH2 {
+ t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
+ }
+}
+
+func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ln := newLocalListener(t)
+ ln.Close() // immediately (not a defer!)
+ var s Server
+ // Set the TLSConfig. In reality, this would be the
+ // *tls.Config given to tls.NewListener.
+ s.TLSConfig = &tls.Config{
+ NextProtos: []string{"h2"},
+ }
+ if err := s.Serve(ln); err == nil {
+ t.Fatal("expected an error")
+ }
+ on := s.TLSNextProto["h2"] != nil
+ if !on {
+ t.Errorf("http2 wasn't automatically enabled")
+ }
+}
+
+func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ })
+}
+
+func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
+ GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return &cert, nil
+ },
+ })
+}
+
+func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
+ CondSkipHTTP2(t)
+ // Not parallel: uses global test hooks.
+ defer afterTest(t)
+ defer SetTestHookServerServe(nil)
+ var ok bool
+ var s *Server
+ const maxTries = 5
+ var ln net.Listener
+Try:
+ for try := 0; try < maxTries; try++ {
+ ln = newLocalListener(t)
+ addr := ln.Addr().String()
+ ln.Close()
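+ // The listener is closed and its address reused by ListenAndServeTLS
+ // below; another process can grab the port in between, which is why
+ // this runs in a retry loop.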
+ t.Logf("Got %v", addr)
+ lnc := make(chan net.Listener, 1)
+ SetTestHookServerServe(func(s *Server, ln net.Listener) {
+ lnc <- ln
+ })
+ s = &Server{
+ Addr: addr,
+ TLSConfig: tlsConf,
+ }
+ errc := make(chan error, 1)
+ go func() { errc <- s.ListenAndServeTLS("", "") }()
+ select {
+ case err := <-errc:
+ t.Logf("On try #%v: %v", try+1, err)
+ continue
+ case ln = <-lnc:
+ ok = true
+ t.Logf("Listening on %v", ln.Addr().String())
+ break Try
+ }
+ }
+ if !ok {
+ t.Fatalf("Failed to start up after %d tries", maxTries)
+ }
+ defer ln.Close()
+ c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"h2", "http/1.1"},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
+ t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
+ }
+ if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
+ t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
+ }
+}
+
+type serverExpectTest struct {
+ contentLength int // of request body
+ chunked bool
+ expectation string // e.g. "100-continue"
+ readBody bool // whether handler should read the body (if false, sends StatusUnauthorized)
+ expectedResponse string // expected substring in first line of http response
+}
+
+func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
+ return serverExpectTest{
+ contentLength: contentLength,
+ expectation: expectation,
+ readBody: readBody,
+ expectedResponse: expectedResponse,
+ }
+}
+
+var serverExpectTests = []serverExpectTest{
+ // Normal 100-continues, case-insensitive.
+ expectTest(100, "100-continue", true, "100 Continue"),
+ expectTest(100, "100-cOntInUE", true, "100 Continue"),
+
+ // No 100-continue.
+ expectTest(100, "", true, "200 OK"),
+
+ // 100-continue but requesting client to deny us,
+ // so it never reads the body.
+ expectTest(100, "100-continue", false, "401 Unauthorized"),
+ // Likewise without 100-continue:
+ expectTest(100, "", false, "401 Unauthorized"),
+
+ // Non-standard expectations are failures
+ expectTest(0, "a-pony", false, "417 Expectation Failed"),
+
+ // Expect-100 requested but no body (is apparently okay: Issue 7625)
+ expectTest(0, "100-continue", true, "200 OK"),
+ // Expect-100 requested but handler doesn't read the body
+ expectTest(0, "100-continue", false, "401 Unauthorized"),
+ // Expect-100 requested with no Content-Length, but a chunked body.
+ {
+ expectation: "100-continue",
+ readBody: true,
+ chunked: true,
+ expectedResponse: "100 Continue",
+ },
+}
+
+// Tests that the server responds to the "Expect" request header
+// correctly.
+func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
+func testServerExpect(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Note: not using r.FormValue("readbody") because for POST
+ // requests that would read from r.Body, which we only
+ // conditionally want to do.
+ if strings.Contains(r.URL.RawQuery, "readbody=true") {
+ io.ReadAll(r.Body)
+ w.Write([]byte("Hi"))
+ } else {
+ w.WriteHeader(StatusUnauthorized)
+ }
+ })).ts
+
+ runTest := func(test serverExpectTest) {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer conn.Close()
+
+ // Only send the body immediately if we're acting like an HTTP client
+ // that doesn't send 100-continue expectations.
+ writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
+
+ wg := sync.WaitGroup{}
+ wg.Add(1)
+ defer wg.Wait()
+
+ go func() {
+ defer wg.Done()
+
+ contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
+ if test.chunked {
+ contentLen = "Transfer-Encoding: chunked"
+ }
+ _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
+ "Connection: close\r\n"+
+ "%s\r\n"+
+ "Expect: %s\r\nHost: foo\r\n\r\n",
+ test.readBody, contentLen, test.expectation)
+ if err != nil {
+ t.Errorf("On test %#v, error writing request headers: %v", test, err)
+ return
+ }
+ if writeBody {
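+ // Wrap conn in a WriteCloser so the chunked and non-chunked paths
+ // below share the same write-then-Close sequence; the plain path
+ // gets a no-op Closer.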
+ var targ io.WriteCloser = struct {
+ io.Writer
+ io.Closer
+ }{
+ conn,
+ io.NopCloser(nil),
+ }
+ if test.chunked {
+ targ = httputil.NewChunkedWriter(conn)
+ }
+ body := strings.Repeat("A", test.contentLength)
+ _, err = fmt.Fprint(targ, body)
+ if err == nil {
+ err = targ.Close()
+ }
+ if err != nil {
+ if !test.readBody {
+ // Server likely already hung up on us.
+ // See larger comment below.
+ t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
+ return
+ }
+ t.Errorf("On test %#v, error writing request body: %v", test, err)
+ }
+ }
+ }()
+ bufr := bufio.NewReader(conn)
+ line, err := bufr.ReadString('\n')
+ if err != nil {
+ if writeBody && !test.readBody {
+ // This is an acceptable failure due to a possible TCP race:
+ // We were still writing data and the server hung up on us. A TCP
+ // implementation may send a RST if our request body data was known
+ // to be lost, which may trigger our reads to fail.
+ // See RFC 1122 page 88.
+ t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
+ return
+ }
+ t.Fatalf("On test %#v, ReadString: %v", test, err)
+ }
+ if !strings.Contains(line, test.expectedResponse) {
+ t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
+ }
+ }
+
+ for _, test := range serverExpectTests {
+ runTest(test)
+ }
+}
+
+// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should consume client request bodies that a handler didn't read.
+func TestServerUnreadRequestBodyLittle(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ conn := new(testConn)
+ body := strings.Repeat("x", 100<<10)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+
+ done := make(chan bool)
+
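+ // readBufLen reports how many bytes of the request the server has not
+ // yet consumed from the fake connection.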
+ readBufLen := func() int {
+ conn.readMu.Lock()
+ defer conn.readMu.Unlock()
+ return conn.readBuf.Len()
+ }
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer close(done)
+ if bufLen := readBufLen(); bufLen < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
+ }
+ rw.WriteHeader(200)
+ rw.(Flusher).Flush()
+ if g, e := readBufLen(), 0; g != e {
+ t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
+ }
+ if c := rw.Header().Get("Connection"); c != "" {
+ t.Errorf(`Connection header = %q; want ""`, c)
+ }
+ }))
+ <-done
+}
+
+// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server
+// should ignore client request bodies that a handler didn't read
+// and close the connection.
+func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ setParallel(t)
+ if testing.Short() && testenv.Builder() == "" {
+ t.Log("skipping in short mode")
+ }
+ conn := new(testConn)
+ body := strings.Repeat("x", 1<<20)
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+ conn.closec = make(chan bool, 1)
+
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ rw.WriteHeader(200)
+ rw.(Flusher).Flush()
+ if conn.readBuf.Len() < len(body)/2 {
+ t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
+ }
+ }))
+ <-conn.closec
+
+ if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
+ t.Errorf("Expected a Connection: close header; got response: %s", res)
+ }
+}
+
+type handlerBodyCloseTest struct {
+ bodySize int
+ bodyChunked bool
+ reqConnClose bool
+
+ wantEOFSearch bool // should Handler's Body.Close do Reads, looking for EOF?
+ wantNextReq bool // should it find the next request on the same conn?
+}
+
+func (t handlerBodyCloseTest) connectionHeader() string {
+ if t.reqConnClose {
+ return "Connection: close\r\n"
+ }
+ return ""
+}
+
+var handlerBodyCloseTests = [...]handlerBodyCloseTest{
+ // Small enough to slurp past to the next request +
+ // has Content-Length.
+ 0: {
+ bodySize: 20 << 10,
+ bodyChunked: false,
+ reqConnClose: false,
+ wantEOFSearch: true,
+ wantNextReq: true,
+ },
+
+ // Small enough to slurp past to the next request +
+ // is chunked.
+ 1: {
+ bodySize: 20 << 10,
+ bodyChunked: true,
+ reqConnClose: false,
+ wantEOFSearch: true,
+ wantNextReq: true,
+ },
+
+ // Small enough to slurp past to the next request +
+ // has Content-Length +
+ // declares Connection: close (so pointless to read more).
+ 2: {
+ bodySize: 20 << 10,
+ bodyChunked: false,
+ reqConnClose: true,
+ wantEOFSearch: false,
+ wantNextReq: false,
+ },
+
+ // Small enough to slurp past to the next request +
+ // declares Connection: close,
+ // but chunked, so it might have trailers.
+ // TODO: maybe skip this search if no trailers were declared
+ // in the headers.
+ 3: {
+ bodySize: 20 << 10,
+ bodyChunked: true,
+ reqConnClose: true,
+ wantEOFSearch: true,
+ wantNextReq: false,
+ },
+
+ // Big with Content-Length, so give up immediately if we know it's too big.
+ 4: {
+ bodySize: 1 << 20,
+ bodyChunked: false, // has a Content-Length
+ reqConnClose: false,
+ wantEOFSearch: false,
+ wantNextReq: false,
+ },
+
+ // Big chunked, so read a bit before giving up.
+ 5: {
+ bodySize: 1 << 20,
+ bodyChunked: true,
+ reqConnClose: false,
+ wantEOFSearch: true,
+ wantNextReq: false,
+ },
+
+ // Big with Connection: close, but chunked, so search for trailers.
+ // TODO: maybe skip this search if no trailers were declared
+ // in the headers.
+ 6: {
+ bodySize: 1 << 20,
+ bodyChunked: true,
+ reqConnClose: true,
+ wantEOFSearch: true,
+ wantNextReq: false,
+ },
+
+ // Big with Connection: close, so don't do any reads on Close.
+ // With Content-Length.
+ 7: {
+ bodySize: 1 << 20,
+ bodyChunked: false,
+ reqConnClose: true,
+ wantEOFSearch: false,
+ wantNextReq: false,
+ },
+}
+
+func TestHandlerBodyClose(t *testing.T) {
+ setParallel(t)
+ if testing.Short() && testenv.Builder() == "" {
+ t.Skip("skipping in -short mode")
+ }
+ for i, tt := range handlerBodyCloseTests {
+ testHandlerBodyClose(t, i, tt)
+ }
+}
+
+func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
+ conn := new(testConn)
+ body := strings.Repeat("x", tt.bodySize)
+ if tt.bodyChunked {
+ conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ tt.connectionHeader() +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n")
+ cw := internal.NewChunkedWriter(&conn.readBuf)
+ io.WriteString(cw, body)
+ cw.Close()
+ conn.readBuf.WriteString("\r\n")
+ } else {
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n"+
+ "Host: test\r\n"+
+ tt.connectionHeader()+
+ "Content-Length: %d\r\n"+
+ "\r\n", len(body))))
+ conn.readBuf.Write([]byte(body))
+ }
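+ // Unless the request declared Connection: close, pipeline a second
+ // request so we can tell whether Body.Close left the connection
+ // usable for it.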
+ if !tt.reqConnClose {
+ conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
+ }
+ conn.closec = make(chan bool, 1)
+
+ readBufLen := func() int {
+ conn.readMu.Lock()
+ defer conn.readMu.Unlock()
+ return conn.readBuf.Len()
+ }
+
+ ls := &oneConnListener{conn}
+ var numReqs int
+ var size0, size1 int
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ numReqs++
+ if numReqs == 1 {
+ size0 = readBufLen()
+ req.Body.Close()
+ size1 = readBufLen()
+ }
+ }))
+ <-conn.closec
+ if numReqs < 1 || numReqs > 2 {
+ t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
+ }
+ didSearch := size0 != size1
+ if didSearch != tt.wantEOFSearch {
+ t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
+ }
+ if tt.wantNextReq && numReqs != 2 {
+ t.Errorf("%d. numReq = %d; want 2", i, numReqs)
+ }
+}
+
+// testHandlerBodyConsumer represents a function injected into a test handler to
+// vary work done on a request Body.
+type testHandlerBodyConsumer struct {
+ name string
+ f func(io.ReadCloser)
+}
+
+var testHandlerBodyConsumers = []testHandlerBodyConsumer{
+ {"nil", func(io.ReadCloser) {}},
+ {"close", func(r io.ReadCloser) { r.Close() }},
+ {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
+}
+
+func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ for _, handler := range testHandlerBodyConsumers {
+ conn := new(testConn)
+ conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "hax\r\n" + // Invalid chunked encoding
+ "GET /secret HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "\r\n")
+
+ conn.closec = make(chan bool, 1)
+ ls := &oneConnListener{conn}
+ var numReqs int
+ go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
+ numReqs++
+ if strings.Contains(req.URL.Path, "secret") {
+ t.Error("Request for /secret encountered, should not have happened.")
+ }
+ handler.f(req.Body)
+ }))
+ <-conn.closec
+ if numReqs != 1 {
+ t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
+ }
+ }
+}
+
+func TestInvalidTrailerClosesConnection(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ for _, handler := range testHandlerBodyConsumers {
+ conn := new(testConn)
+ conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Trailer: hack\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n" +
+ "3\r\n" +
+ "hax\r\n" +
+ "0\r\n" +
+ "I'm not a valid trailer\r\n" +
+ "GET /secret HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "\r\n")
+
+ conn.closec = make(chan bool, 1)
+ ln := &oneConnListener{conn}
+ var numReqs int
+ go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
+ numReqs++
+ if strings.Contains(req.URL.Path, "secret") {
+ t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
+ }
+ handler.f(req.Body)
+ }))
+ <-conn.closec
+ if numReqs != 1 {
+ t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
+ }
+ }
+}
+
+// slowTestConn is a net.Conn that provides a means to simulate parts of a
+// request being received piecemeal. Deadlines can be set and enforced in both
+// Read and Write.
+type slowTestConn struct {
+ // Over multiple calls to Read, time.Durations are slept and strings are read.
+ script []any
+ closec chan bool
+
+ mu sync.Mutex // guards rd/wd
+ rd, wd time.Time // read, write deadline
+ noopConn
+}
+
+func (c *slowTestConn) SetDeadline(t time.Time) error {
+ c.SetReadDeadline(t)
+ c.SetWriteDeadline(t)
+ return nil
+}
+
+func (c *slowTestConn) SetReadDeadline(t time.Time) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.rd = t
+ return nil
+}
+
+func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.wd = t
+ return nil
+}
+
+func (c *slowTestConn) Read(b []byte) (n int, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+restart:
+ if !c.rd.IsZero() && time.Now().After(c.rd) {
+ return 0, syscall.ETIMEDOUT
+ }
+ if len(c.script) == 0 {
+ return 0, io.EOF
+ }
+
+ switch cue := c.script[0].(type) {
+ case time.Duration:
+ if !c.rd.IsZero() {
+ // If the deadline falls in the middle of our sleep window, deduct
+ // part of the sleep, then return a timeout.
+ if remaining := time.Until(c.rd); remaining < cue {
+ c.script[0] = cue - remaining
+ time.Sleep(remaining)
+ return 0, syscall.ETIMEDOUT
+ }
+ }
+ c.script = c.script[1:]
+ time.Sleep(cue)
+ goto restart
+
+ case string:
+ n = copy(b, cue)
+ // If cue is too big for the buffer, leave the end for the next Read.
+ if len(cue) > n {
+ c.script[0] = cue[n:]
+ } else {
+ c.script = c.script[1:]
+ }
+
+ default:
+ panic("unknown cue in slowTestConn script")
+ }
+
+ return
+}
+
+func (c *slowTestConn) Close() error {
+ select {
+ case c.closec <- true:
+ default:
+ }
+ return nil
+}
+
+func (c *slowTestConn) Write(b []byte) (int, error) {
+ if !c.wd.IsZero() && time.Now().After(c.wd) {
+ return 0, syscall.ETIMEDOUT
+ }
+ return len(b), nil
+}
+
+func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ defer afterTest(t)
+ for _, handler := range testHandlerBodyConsumers {
+ conn := &slowTestConn{
+ script: []any{
+ "POST /public HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 10000\r\n" +
+ "\r\n",
+ "foo bar baz",
+ 600 * time.Millisecond, // Request deadline should hit here
+ "GET /secret HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "\r\n",
+ },
+ closec: make(chan bool, 1),
+ }
+ ls := &oneConnListener{conn}
+
+ var numReqs int
+ s := Server{
+ Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
+ numReqs++
+ if strings.Contains(req.URL.Path, "secret") {
+ t.Error("Request for /secret encountered, should not have happened.")
+ }
+ handler.f(req.Body)
+ }),
+ ReadTimeout: 400 * time.Millisecond,
+ }
+ go s.Serve(ls)
+ <-conn.closec
+
+ if numReqs != 1 {
+ t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
+ }
+ }
+}
+
+// cancelableTimeoutContext reports context.DeadlineExceeded once the
+// underlying Context has been canceled.
+type cancelableTimeoutContext struct {
+ context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+ if c.Context.Err() != nil {
+ return context.DeadlineExceeded
+ }
+ return nil
+}
+
+func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
+func testTimeoutHandler(t *testing.T, mode testMode) {
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
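+ // The cancelable context stands in for the timeout: canceling it below
+ // makes the TimeoutHandler behave as if its timer had fired.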
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, mode, h)
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ cancel()
+
+ res, err = cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = io.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+ if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
+ t.Errorf("response content-type = %q; want %q", g, w)
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
+
+// See issues 8209 and 8414.
+func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
+func testTimeoutHandlerRace(t *testing.T, mode testMode) {
+ delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ ms, _ := strconv.Atoi(r.URL.Path[1:])
+ if ms == 0 {
+ ms = 1
+ }
+ for i := 0; i < ms; i++ {
+ w.Write([]byte("hi"))
+ time.Sleep(time.Millisecond)
+ }
+ })
+
+ ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
+
+ c := ts.Client()
+
+ var wg sync.WaitGroup
+ gate := make(chan bool, 10)
+ n := 50
+ if testing.Short() {
+ n = 10
+ gate = make(chan bool, 3)
+ }
+ for i := 0; i < n; i++ {
+ gate <- true
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() { <-gate }()
+ res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
+ if err == nil {
+ io.Copy(io.Discard, res.Body)
+ res.Body.Close()
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// See issues 8209 and 8414.
+// Both issues involved panics in the implementation of TimeoutHandler.
+func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
+func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
+ delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(204)
+ })
+
+ ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
+
+ var wg sync.WaitGroup
+ gate := make(chan bool, 50)
+ n := 500
+ if testing.Short() {
+ n = 10
+ }
+
+ c := ts.Client()
+ for i := 0; i < n; i++ {
+ gate <- true
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() { <-gate }()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ // We see ECONNRESET from the connection occasionally,
+ // and that's OK: this test is checking that the server does not panic.
+ t.Log(err)
+ return
+ }
+ defer res.Body.Close()
+ io.Copy(io.Discard, res.Body)
+ }()
+ }
+ wg.Wait()
+}
+
+// Issue 9162
+func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
+func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, mode, h)
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ cancel()
+
+ res, err = cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = io.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
+
+// Issue 14568.
+func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
+ run(t, testTimeoutHandlerStartTimerWhenServing)
+}
+func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping sleeping test in -short mode")
+ }
+ var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
+ w.WriteHeader(StatusNoContent)
+ }
+ timeout := 300 * time.Millisecond
+ ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
+ defer ts.Close()
+
+ c := ts.Client()
+
+ // The issue was caused by the timeout handler starting its timer when
+ // the handler was created, not when the request arrived. So wait for
+ // more than the timeout to ensure that's not the case.
+ time.Sleep(2 * timeout)
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusNoContent {
+ t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
+ }
+}
+
+func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
+func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ var err error
+ // The request context has already been canceled, but
+ // retry the write for a while to give the timeout handler
+ // a chance to notice.
+ for i := 0; i < 100; i++ {
+ _, err = w.Write([]byte("a"))
+ if err != nil {
+ break
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ writeErrors <- err
+ })
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ h := NewTestTimeoutHandler(sayHi, ctx)
+ cst := newClientServerTest(t, mode, h)
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), ""; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := <-writeErrors, context.Canceled; g != e {
+ t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
+ }
+}
+
+// https://golang.org/issue/15948
+func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
+func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
+ var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
+ // No response.
+ }
+ timeout := 300 * time.Millisecond
+ ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
+
+ c := ts.Client()
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusOK {
+ t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
+ }
+}
+
+// https://golang.org/issues/22084
+func TestTimeoutHandlerPanicRecovery(t *testing.T) {
+ wrapper := func(h Handler) Handler {
+ return TimeoutHandler(h, time.Second, "")
+ }
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
+ }, testNotParallel)
+}
+
+func TestRedirectBadPath(t *testing.T) {
+ // This used to crash. It's not valid input (bad path), but it
+ // shouldn't crash.
+ rr := httptest.NewRecorder()
+ req := &Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Path: "not-empty-but-no-leading-slash", // bogus
+ },
+ }
+ Redirect(rr, req, "", 304)
+ if rr.Code != 304 {
+ t.Errorf("Code = %d; want 304", rr.Code)
+ }
+}
+
+// Test different URL formats and schemes
+func TestRedirect(t *testing.T) {
+ req, _ := NewRequest("GET", "http://example.com/qux/", nil)
+
+ var tests = []struct {
+ in string
+ want string
+ }{
+ // normal http
+ {"http://foobar.com/baz", "http://foobar.com/baz"},
+ // normal https
+ {"https://foobar.com/baz", "https://foobar.com/baz"},
+ // custom scheme
+ {"test://foobar.com/baz", "test://foobar.com/baz"},
+ // schemeless
+ {"//foobar.com/baz", "//foobar.com/baz"},
+ // relative to the root
+ {"/foobar.com/baz", "/foobar.com/baz"},
+ // relative to the current path
+ {"foobar.com/baz", "/qux/foobar.com/baz"},
+ // relative to the current path (+ going upwards)
+ {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
+ // incorrect number of slashes
+ {"///foobar.com/baz", "/foobar.com/baz"},
+
+ // Verifies we don't path.Clean() on the wrong parts in redirects:
+ {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
+ {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
+ "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
+
+ {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
+ {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
+ }
+
+ for _, tt := range tests {
+ rec := httptest.NewRecorder()
+ Redirect(rec, req, tt.in, 302)
+ if got, want := rec.Code, 302; got != want {
+ t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
+ }
+ if got := rec.Header().Get("Location"); got != tt.want {
+ t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
+ }
+ }
+}
+
+// Test that Redirect sets Content-Type header for GET and HEAD requests
+// and writes a short HTML body, unless the request already has a Content-Type header.
+func TestRedirectContentTypeAndBody(t *testing.T) {
+ type ctHeader struct {
+ Values []string
+ }
+
+ var tests = []struct {
+ method string
+ ct *ctHeader // Optional Content-Type header to set.
+ wantCT string
+ wantBody string
+ }{
+ {MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
+ {MethodHead, nil, "text/html; charset=utf-8", ""},
+ {MethodPost, nil, "", ""},
+ {MethodDelete, nil, "", ""},
+ {"foo", nil, "", ""},
+ {MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
+ {MethodGet, &ctHeader{[]string{}}, "", ""},
+ {MethodGet, &ctHeader{nil}, "", ""},
+ }
+ for _, tt := range tests {
+ req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
+ rec := httptest.NewRecorder()
+ if tt.ct != nil {
+ rec.Header()["Content-Type"] = tt.ct.Values
+ }
+ Redirect(rec, req, "/foo", 302)
+ if got, want := rec.Code, 302; got != want {
+ t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
+ }
+ if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
+ t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
+ }
+ resp := rec.Result()
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(body), tt.wantBody; got != want {
+ t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
+ }
+ }
+}
+
+// TestZeroLengthPostAndResponse exercises an optimization done by the Transport:
+// when there is no body (either because the method doesn't permit a body, or an
+// explicit Content-Length of zero is present), then the transport can re-use the
+// connection immediately. But when it re-uses the connection, it typically closes
+// the previous request's body, which is not optimal for zero-length bodies,
+// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
+func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
+
+func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ all, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("handler ReadAll: %v", err)
+ }
+ if len(all) != 0 {
+ t.Errorf("handler got %d bytes; expected 0", len(all))
+ }
+ rw.Header().Set("Content-Length", "0")
+ }))
+
+ req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0
+
+ var resp [5]*Response
+ for i := range resp {
+ resp[i], err = cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("client post #%d: %v", i, err)
+ }
+ }
+
+ for i := range resp {
+ all, err := io.ReadAll(resp[i].Body)
+ if err != nil {
+ t.Fatalf("req #%d: client ReadAll: %v", i, err)
+ }
+ if len(all) != 0 {
+ t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
+ }
+ }
+}
+
+func TestHandlerPanicNil(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, nil, nil)
+ }, testNotParallel)
+}
+
+func TestHandlerPanic(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, false, mode, nil, "intentional death for testing")
+ }, testNotParallel)
+}
+
+func TestHandlerPanicWithHijack(t *testing.T) {
+ // Only testing HTTP/1, and our http2 server doesn't support hijacking.
+ run(t, func(t *testing.T, mode testMode) {
+ testHandlerPanic(t, true, mode, nil, "intentional death for testing")
+ }, []testMode{http1Mode})
+}
+
+func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
+ // Direct log output to a pipe.
+ //
+ // We read from the pipe to verify that the handler actually caught the panic
+ // and logged something.
+ //
+ // We use a pipe rather than a buffer, because when testing connection hijacking
+ // server shutdown doesn't wait for the hijacking handler to return, so the
+ // log may occur after the server has shut down.
+ pr, pw := io.Pipe()
+ defer pw.Close()
+
+ var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ if withHijack {
+ rwc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Logf("unexpected error: %v", err)
+ }
+ defer rwc.Close()
+ }
+ panic(panicValue)
+ })
+ if wrapper != nil {
+ handler = wrapper(handler)
+ }
+ cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(pw, "", 0)
+ })
+
+ // Do a blocking read on the log output pipe.
+ done := make(chan bool, 1)
+ go func() {
+ buf := make([]byte, 4<<10)
+ _, err := pr.Read(buf)
+ pr.Close()
+ if err != nil && err != io.EOF {
+ t.Error(err)
+ }
+ done <- true
+ }()
+
+ _, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ t.Logf("expected an error")
+ }
+
+ if panicValue == nil {
+ return
+ }
+
+ <-done
+}
+
+type terrorWriter struct{ t *testing.T }
+
+func (w terrorWriter) Write(p []byte) (int, error) {
+ w.t.Errorf("%s", p)
+ return len(p), nil
+}
+
+// Issue 16456: allow writing 0 bytes on hijacked conn to test hijack
+// without any log spam.
+func TestServerWriteHijackZeroBytes(t *testing.T) {
+ run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
+}
+func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
+ done := make(chan struct{})
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(done)
+ w.(Flusher).Flush()
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ defer conn.Close()
+ _, err = w.Write(nil)
+ if err != ErrHijacked {
+ t.Errorf("Write error = %v; want ErrHijacked", err)
+ }
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
+ }).ts
+
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ <-done
+}
+
+func TestServerNoDate(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testServerNoHeader(t, mode, "Date")
+ })
+}
+
+func TestServerNoContentType(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testServerNoHeader(t, mode, "Content-Type")
+ })
+}
+
+func testServerNoHeader(t *testing.T, mode testMode, header string) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()[header] = nil
+ io.WriteString(w, "<html>foo</html>") // non-empty
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got, ok := res.Header[header]; ok {
+ t.Fatalf("Expected no %s header; got %q", header, got)
+ }
+}
+
+func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
+func testStripPrefix(t *testing.T, mode testMode) {
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Path", r.URL.Path)
+ w.Header().Set("X-RawPath", r.URL.RawPath)
+ })
+ ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
+
+ c := ts.Client()
+
+ cases := []struct {
+ reqPath string
+ path string // If empty we want a 404.
+ rawPath string
+ }{
+ {"/foo/bar/qux", "/qux", ""},
+ {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
+ {"/foo%2Fbar/qux", "", ""}, // Escaped prefix does not match.
+ {"/bar", "", ""}, // No prefix match.
+ }
+ for _, tc := range cases {
+ t.Run(tc.reqPath, func(t *testing.T) {
+ res, err := c.Get(ts.URL + tc.reqPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if tc.path == "" {
+ if res.StatusCode != StatusNotFound {
+ t.Errorf("got %q, want 404 Not Found", res.Status)
+ }
+ return
+ }
+ if res.StatusCode != StatusOK {
+ t.Fatalf("got %q, want 200 OK", res.Status)
+ }
+ if g, w := res.Header.Get("X-Path"), tc.path; g != w {
+ t.Errorf("got Path %q, want %q", g, w)
+ }
+ if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
+ t.Errorf("got RawPath %q, want %q", g, w)
+ }
+ })
+ }
+}
+
+// https://golang.org/issue/18952.
+func TestStripPrefixNotModifyRequest(t *testing.T) {
+ h := StripPrefix("/foo", NotFoundHandler())
+ req := httptest.NewRequest("GET", "/foo/bar", nil)
+ h.ServeHTTP(httptest.NewRecorder(), req)
+ if req.URL.Path != "/foo/bar" {
+ t.Errorf("StripPrefix should not modify the provided Request, but it did")
+ }
+}
+
+func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
+func testRequestLimit(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Fatalf("didn't expect to get request in Handler")
+ }), optQuietLog)
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ var bytesPerHeader = len("header12345: val12345\r\n")
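+ // Send enough distinct headers to comfortably exceed the server's
+ // DefaultMaxHeaderBytes limit.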
+ for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
+ req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
+ }
+ res, err := cst.c.Do(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if mode == http2Mode {
+ // In HTTP/2, the result depends on a race. If the client has received the
+ // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip
+ // will fail with an error. Otherwise, the client should receive a 431 from the
+ // server.
+ if err == nil && res.StatusCode != 431 {
+ t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
+ }
+ } else {
+ // In HTTP/1, we expect a 431 from the server.
+ // Some HTTP clients may fail on this undefined behavior (server replying and
+ // closing the connection while the request is still being written), but
+ // we do support it (at least currently), so we expect a response below.
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if res.StatusCode != 431 {
+ t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
+ }
+ }
+}
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+type countReader struct {
+ r io.Reader
+ n *int64
+}
+
+func (cr countReader) Read(p []byte) (n int, err error) {
+ n, err = cr.r.Read(p)
+ atomic.AddInt64(cr.n, int64(n))
+ return
+}
+
+func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
+func testRequestBodyLimit(t *testing.T, mode testMode) {
+ const limit = 1 << 20
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.Body = MaxBytesReader(w, r.Body, limit)
+ n, err := io.Copy(io.Discard, r.Body)
+ if err == nil {
+ t.Errorf("expected error from io.Copy")
+ }
+ if n != limit {
+ t.Errorf("io.Copy = %d, want %d", n, limit)
+ }
+ mbErr, ok := err.(*MaxBytesError)
+ if !ok {
+ t.Errorf("expected MaxBytesError, got %T", err)
+ }
+ if mbErr.Limit != limit {
+ t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
+ }
+ }))
+
+ nWritten := new(int64)
+ req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
+
+ // Send the POST, but don't care it succeeds or not. The
+ // remote side is going to reply and then close the TCP
+ // connection, and HTTP doesn't really define if that's
+ // allowed or not. Some HTTP clients will get the response
+ // and some (like ours, currently) will complain that the
+ // request write failed, without reading the response.
+ //
+ // But that's okay, since what we're really testing is that
+ // the remote side hung up on us before we wrote too much.
+ resp, err := cst.c.Do(req)
+ if err == nil {
+ resp.Body.Close()
+ }
+
+ if atomic.LoadInt64(nWritten) > limit*100 {
+ t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
+ limit, atomic.LoadInt64(nWritten))
+ }
+}
+
+// TestClientWriteShutdown tests that if the client shuts down the write
+// side of their TCP connection, the server doesn't send a 400 Bad Request.
+func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
+func testClientWriteShutdown(t *testing.T, mode testMode) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see https://golang.org/issue/17906")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ err = conn.(*net.TCPConn).CloseWrite()
+ if err != nil {
+ t.Fatalf("CloseWrite: %v", err)
+ }
+
+ bs, err := io.ReadAll(conn)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ got := string(bs)
+ if got != "" {
+ t.Errorf("read %q from server; want nothing", got)
+ }
+}
+
+// Tests that chunked server responses that write 1 byte at a time are
+// buffered before chunk headers are added, not after chunk headers.
+func TestServerBufferedChunking(t *testing.T) {
+ conn := new(testConn)
+ conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ conn.closec = make(chan bool, 1)
+ ls := &oneConnListener{conn}
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.(Flusher).Flush() // force the Header to be sent, in chunking mode, not counting the length
+ rw.Write([]byte{'x'})
+ rw.Write([]byte{'y'})
+ rw.Write([]byte{'z'})
+ }))
+ <-conn.closec
+ if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
+ t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
+ conn.writeBuf.Bytes())
+ }
+}
+
+// Tests that the server flushes its response headers out when it's
+// ignoring the request body and waits a bit before forcefully
+// closing the TCP connection, causing the client to get a RST.
+// See https://golang.org/issue/3595
+func TestServerGracefulClose(t *testing.T) {
+ run(t, testServerGracefulClose, []testMode{http1Mode})
+}
+func testServerGracefulClose(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, "bye", StatusUnauthorized)
+ })).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ const bodySize = 5 << 20
+ req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+ for i := 0; i < bodySize; i++ {
+ req = append(req, 'x')
+ }
+ writeErr := make(chan error)
+ go func() {
+ _, err := conn.Write(req)
+ writeErr <- err
+ }()
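+ // Read the response while the large request body is still being
+ // written; the first line should report the 401 even though the
+ // server never consumed the full body.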
+ br := bufio.NewReader(conn)
+ lineNum := 0
+ for {
+ line, err := br.ReadString('\n')
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ t.Fatalf("ReadLine: %v", err)
+ }
+ lineNum++
+ if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+ t.Errorf("Response line = %q; want a 401", line)
+ }
+ }
+ // Wait for write to finish. This is a broken pipe on both
+ // Darwin and Linux, but checking this isn't the point of
+ // the test.
+ <-writeErr
+}
+
+func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
+func testCaseSensitiveMethod(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "get" {
+ t.Errorf(`Got method %q; want "get"`, r.Method)
+ }
+ }))
+ defer cst.close()
+ req, _ := NewRequest("get", cst.ts.URL, nil)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ res.Body.Close()
+}
+
+// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1
+// request (both keep-alive), when a Handler never writes any
+// response, the net/http package adds a "Content-Length: 0" response
+// header.
+func TestContentLengthZero(t *testing.T) {
+ run(t, testContentLengthZero, []testMode{http1Mode})
+}
+func testContentLengthZero(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
+
+ for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
+ if err != nil {
+ t.Fatalf("error writing: %v", err)
+ }
+ req, _ := NewRequest("GET", "/", nil)
+ res, err := ReadResponse(bufio.NewReader(conn), req)
+ if err != nil {
+ t.Fatalf("error reading response: %v", err)
+ }
+ if te := res.TransferEncoding; len(te) > 0 {
+ t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
+ }
+ if cl := res.ContentLength; cl != 0 {
+ t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
+ }
+ conn.Close()
+ }
+}
+
+func TestCloseNotifier(t *testing.T) {
+ run(t, testCloseNotifier, []testMode{http1Mode})
+}
+func testCloseNotifier(t *testing.T, mode testMode) {
+ gotReq := make(chan bool, 1)
+ sawClose := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ <-cc
+ sawClose <- true
+ })).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool)
+ go func() {
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-diec
+ conn.Close()
+ }()
+For:
+ for {
+ select {
+ case <-gotReq:
+ diec <- true
+ case <-sawClose:
+ break For
+ }
+ }
+ ts.Close()
+}
+
+// Tests that a pipelined request does not cause the first request's
+// Handler's CloseNotify channel to fire.
+//
+// See Issue 13165 (where this used to deadlock); the behavior changed with Issue 23921.
+func TestCloseNotifierPipelined(t *testing.T) {
+ run(t, testCloseNotifierPipelined, []testMode{http1Mode})
+}
+func testCloseNotifierPipelined(t *testing.T, mode testMode) {
+ gotReq := make(chan bool, 2)
+ sawClose := make(chan bool, 2)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ select {
+ case <-cc:
+ t.Error("unexpected CloseNotify")
+ case <-time.After(100 * time.Millisecond):
+ }
+ sawClose <- true
+ })).ts
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool, 1)
+ defer close(diec)
+ go func() {
+ const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
+ _, err = io.WriteString(conn, req+req) // two requests
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-diec
+ conn.Close()
+ }()
+ reqs := 0
+ closes := 0
+ for {
+ select {
+ case <-gotReq:
+ reqs++
+ if reqs > 2 {
+ t.Fatal("too many requests")
+ }
+ case <-sawClose:
+ closes++
+ if closes > 1 {
+ return
+ }
+ }
+ }
+}
+
+func TestCloseNotifierChanLeak(t *testing.T) {
+ defer afterTest(t)
+ req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
+ for i := 0; i < 20; i++ {
+ var output bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &output,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ // Ignore the return value and never read from
+ // it, testing that we don't leak goroutines
+ // on the sending side:
+ _ = rw.(CloseNotifier).CloseNotify()
+ })
+ go Serve(ln, handler)
+ <-conn.closec
+ }
+}
+
+// Tests that we can use CloseNotifier in one request, and later call Hijack
+// on a second request on the same connection.
+//
+// It also tests that, when CloseNotifier doesn't fire, the connReader
+// stitches its background 1-byte read (done on behalf of CloseNotifier)
+// back together with the rest of the second HTTP request later.
+//
+// Issue 9763.
+// HTTP/1-only test. (http2 doesn't have Hijack)
+func TestHijackAfterCloseNotifier(t *testing.T) {
+ run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
+}
+func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
+ script := make(chan string, 2)
+ script <- "closenotify"
+ script <- "hijack"
+ close(script)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ plan := <-script
+ switch plan {
+ default:
+ panic("bogus plan; too many requests")
+ case "closenotify":
+ w.(CloseNotifier).CloseNotify() // discard result
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ case "hijack":
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack in Handler: %v", err)
+ return
+ }
+ if _, ok := c.(*net.TCPConn); !ok {
+ // Verify it's not wrapped in some type.
+ // Not strictly a go1 compat issue, but in practice it probably is.
+ t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
+ }
+ fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
+ c.Close()
+ return
+ }
+ })).ts
+ res1, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res2, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ addr1 := res1.Header.Get("X-Addr")
+ addr2 := res2.Header.Get("X-Addr")
+ if addr1 == "" || addr1 != addr2 {
+ t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
+ }
+}
+
+func TestHijackBeforeRequestBodyRead(t *testing.T) {
+ run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
+}
+func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ bodyOkay := make(chan bool, 1)
+ gotCloseNotify := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(bodyOkay) // caller will read false if nothing else
+
+ reqBody := r.Body
+ r.Body = nil // to test that server.go doesn't use this value.
+
+ gone := w.(CloseNotifier).CloseNotify()
+ slurp, err := io.ReadAll(reqBody)
+ if err != nil {
+ t.Errorf("Body read: %v", err)
+ return
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ return
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ return
+ }
+ bodyOkay <- true
+ <-gone
+ gotCloseNotify <- true
+ })).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
+ len(requestBody), requestBody)
+ if !<-bodyOkay {
+ // already failed.
+ return
+ }
+ conn.Close()
+ <-gotCloseNotify
+}
+
+func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
+func testOptions(t *testing.T, mode testMode) {
+ uric := make(chan string, 2) // only expect 1, but leave space for 2
+ mux := NewServeMux()
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
+ uric <- r.RequestURI
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // An OPTIONS * request should succeed.
+ _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ br := bufio.NewReader(conn)
+ res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
+ }
+
+ // A GET * request on a ServeMux should fail.
+ _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err = ReadResponse(br, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 400 {
+ t.Errorf("Got non-400 response to GET *: %#v", res)
+ }
+
+ res, err = Get(ts.URL + "/second")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got := <-uric; got != "/second" {
+ t.Errorf("Handler saw request for %q; want /second", got)
+ }
+}
+
+func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
+func testOptionsHandler(t *testing.T, mode testMode) {
+ rc := make(chan *Request, 1)
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ rc <- r
+ }), func(ts *httptest.Server) {
+ ts.Config.DisableGeneralOptionsHandler = true
+ }).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
+ t.Errorf("Expected OPTIONS * request, got %v", got)
+ }
+}
+
+// Tests regarding the ordering of Write, WriteHeader, Header, and
+// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the
+// (*response).header to the wire. In Go 1.1, the actual wire flush is
+// delayed, so we could maybe tack on a Content-Length and better
+// Content-Type after we see more (or all) of the output. To preserve
+// compatibility with Go 1, we need to be careful to track which
+// headers were live at the time of WriteHeader, so we write the same
+// ones, even if the handler modifies them (~erroneously) after the
+// first Write.
+func TestHeaderToWire(t *testing.T) {
+ tests := []struct {
+ name string
+ handler func(ResponseWriter, *Request)
+ check func(got, logs string) error
+ }{
+ {
+ name: "write without Header",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("hello world"))
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Content-Length:") {
+ return errors.New("no content-length")
+ }
+ if !strings.Contains(got, "Content-Type: text/plain") {
+ return errors.New("no content-type")
+ }
+ return nil
+ },
+ },
+ {
+ name: "Header mutation before write",
+ handler: func(rw ResponseWriter, r *Request) {
+ h := rw.Header()
+ h.Set("Content-Type", "some/type")
+ rw.Write([]byte("hello world"))
+ h.Set("Too-Late", "bogus")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Content-Length:") {
+ return errors.New("no content-length")
+ }
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-type")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("don't want too-late header")
+ }
+ return nil
+ },
+ },
+ {
+ name: "write then useless Header mutation",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("hello world"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got, logs string) error {
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ return nil
+ },
+ },
+ {
+ name: "flush then write",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.(Flusher).Flush()
+ rw.Write([]byte("post-flush"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Transfer-Encoding: chunked") {
+ return errors.New("not chunked")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ return nil
+ },
+ },
+ {
+ name: "header then flush",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "some/type")
+ rw.(Flusher).Flush()
+ rw.Write([]byte("post-flush"))
+ rw.Header().Set("Too-Late", "Write already wrote headers")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Transfer-Encoding: chunked") {
+ return errors.New("not chunked")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("header appeared from after WriteHeader")
+ }
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-type")
+ }
+ return nil
+ },
+ },
+ {
+ name: "sniff-on-first-write content-type",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Write([]byte("<html><head></head><body>some html</body></html>"))
+ rw.Header().Set("Content-Type", "x/wrong")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Content-Type: text/html") {
+ return errors.New("wrong content-type; want html")
+ }
+ return nil
+ },
+ },
+ {
+ name: "explicit content-type wins",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "some/type")
+ rw.Write([]byte("<html><head></head><body>some html</body></html>"))
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Content-Type: some/type") {
+ return errors.New("wrong content-type; want html")
+ }
+ return nil
+ },
+ },
+ {
+ name: "empty handler",
+ handler: func(rw ResponseWriter, r *Request) {
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Content-Length: 0") {
+ return errors.New("want 0 content-length")
+ }
+ return nil
+ },
+ },
+ {
+ name: "only Header, no write",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Some-Header", "some-value")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "Some-Header") {
+ return errors.New("didn't get header")
+ }
+ return nil
+ },
+ },
+ {
+ name: "WriteHeader call",
+ handler: func(rw ResponseWriter, r *Request) {
+ rw.WriteHeader(404)
+ rw.Header().Set("Too-Late", "some-value")
+ },
+ check: func(got, logs string) error {
+ if !strings.Contains(got, "404") {
+ return errors.New("wrong status")
+ }
+ if strings.Contains(got, "Too-Late") {
+ return errors.New("shouldn't have seen Too-Late")
+ }
+ return nil
+ },
+ },
+ }
+ for _, tc := range tests {
+ ht := newHandlerTest(HandlerFunc(tc.handler))
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
+ logs := ht.logbuf.String()
+ if err := tc.check(got, logs); err != nil {
+ t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
+ }
+ }
+}
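+
+// Illustrative sketch, not part of the original tests: the practical rule the
+// table above encodes, from a handler author's point of view. Headers must be
+// set before the first WriteHeader or Write; mutations after that never reach
+// the wire. The name goodOrderHandler is hypothetical and unused elsewhere.
+var goodOrderHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "application/json") // before any write: sent
+ w.WriteHeader(StatusOK)                            // headers are frozen here
+ io.WriteString(w, `{"ok":true}`)
+ w.Header().Set("X-Too-Late", "ignored") // after the write: silently dropped
+})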
+
+type errorListener struct {
+ errs []error
+}
+
+func (l *errorListener) Accept() (c net.Conn, err error) {
+ if len(l.errs) == 0 {
+ return nil, io.EOF
+ }
+ err = l.errs[0]
+ l.errs = l.errs[1:]
+ return
+}
+
+func (l *errorListener) Close() error {
+ return nil
+}
+
+func (l *errorListener) Addr() net.Addr {
+ return dummyAddr("test-address")
+}
+
+func TestAcceptMaxFds(t *testing.T) {
+ setParallel(t)
+
+ ln := &errorListener{[]error{
+ &net.OpError{
+ Op: "accept",
+ Err: syscall.EMFILE,
+ }}}
+ server := &Server{
+ Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
+ ErrorLog: log.New(io.Discard, "", 0), // noisy otherwise
+ }
+ err := server.Serve(ln)
+ if err != io.EOF {
+ t.Errorf("got error %v, want EOF", err)
+ }
+}
+
+func TestWriteAfterHijack(t *testing.T) {
+ req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
+ var buf strings.Builder
+ wrotec := make(chan bool, 1)
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ conn, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ go func() {
+ bufrw.Write([]byte("[hijack-to-bufw]"))
+ bufrw.Flush()
+ conn.Write([]byte("[hijack-to-conn]"))
+ conn.Close()
+ wrotec <- true
+ }()
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ <-wrotec
+ if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
+ t.Errorf("wrote %q; want %q", g, w)
+ }
+}
+
+func TestDoubleHijack(t *testing.T) {
+ req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ conn, _, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ _, _, err = rw.(Hijacker).Hijack()
+ if err == nil {
+ t.Errorf("got err = nil; want err != nil")
+ }
+ conn.Close()
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+}
+
+// https://golang.org/issue/5955
+// Note that this does not test the "request too large"
+// exit path from the http server. This is intentional;
+// not sending Connection: close is just a minor wire
+// optimization and is pointless if dealing with a
+// badly behaved client.
+func TestHTTP10ConnectionHeader(t *testing.T) {
+ run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
+}
+func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
+ ts := newClientServerTest(t, mode, mux).ts
+
+ // net/http uses HTTP/1.1 for requests, so write requests manually
+ tests := []struct {
+ req string // raw http request
+ expect []string // expected Connection header(s)
+ }{
+ {
+ req: "GET / HTTP/1.0\r\n\r\n",
+ expect: nil,
+ },
+ {
+ req: "OPTIONS * HTTP/1.0\r\n\r\n",
+ expect: nil,
+ },
+ {
+ req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
+ expect: []string{"keep-alive"},
+ },
+ }
+
+ for _, tt := range tests {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal("dial err:", err)
+ }
+
+ _, err = fmt.Fprint(conn, tt.req)
+ if err != nil {
+ t.Fatal("conn write err:", err)
+ }
+
+ resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse err:", err)
+ }
+ conn.Close()
+ resp.Body.Close()
+
+ got := resp.Header["Connection"]
+ if !reflect.DeepEqual(got, tt.expect) {
+ t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
+ }
+ }
+}
+
+// See golang.org/issue/5660
+func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
+func testServerReaderFromOrder(t *testing.T, mode testMode) {
+ pr, pw := io.Pipe()
+ const size = 3 << 20
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path
+ done := make(chan bool)
+ go func() {
+ io.Copy(rw, pr)
+ close(done)
+ }()
+ time.Sleep(25 * time.Millisecond) // give Copy a chance to break things
+ n, err := io.Copy(io.Discard, req.Body)
+ if err != nil {
+ t.Errorf("handler Copy: %v", err)
+ return
+ }
+ if n != size {
+ t.Errorf("handler Copy = %d; want %d", n, size)
+ }
+ pw.Write([]byte("hi"))
+ pw.Close()
+ <-done
+ }))
+
+ req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ all, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if string(all) != "hi" {
+ t.Errorf("Body = %q; want hi", all)
+ }
+}
+
+// Issue 6157, Issue 6685
+func TestCodesPreventingContentTypeAndBody(t *testing.T) {
+ for _, code := range []int{StatusNotModified, StatusNoContent} {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/header" {
+ w.Header().Set("Content-Length", "123")
+ }
+ w.WriteHeader(code)
+ if r.URL.Path == "/more" {
+ w.Write([]byte("stuff"))
+ }
+ }))
+ for _, req := range []string{
+ "GET / HTTP/1.0",
+ "GET /header HTTP/1.0",
+ "GET /more HTTP/1.0",
+ "GET / HTTP/1.1\nHost: foo",
+ "GET /header HTTP/1.1\nHost: foo",
+ "GET /more HTTP/1.1\nHost: foo",
+ } {
+ got := ht.rawResponse(req)
+ wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
+ if !strings.Contains(got, wantStatus) {
+ t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
+ } else if strings.Contains(got, "Content-Length") {
+ t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
+ } else if strings.Contains(got, "stuff") {
+ t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
+ }
+ }
+ }
+}
+
+func TestContentTypeOkayOn204(t *testing.T) {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "123") // suppressed
+ w.Header().Set("Content-Type", "foo/bar")
+ w.WriteHeader(204)
+ }))
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
+ if !strings.Contains(got, "Content-Type: foo/bar") {
+ t.Errorf("Response = %q; want Content-Type: foo/bar", got)
+ }
+ if strings.Contains(got, "Content-Length: 123") {
+ t.Errorf("Response = %q; don't want a Content-Length", got)
+ }
+}
+
+// Issue 6995
+// A server Handler can receive a Request, and then turn around and
+// give a copy of that Request.Body out to the Transport (e.g. any
+// proxy). So then two people own that Request.Body (both the server
+// and the http client), and both think they can close it on failure.
+// Therefore, all incoming server requests Bodies need to be thread-safe.
+func TestTransportAndServerSharedBodyRace(t *testing.T) {
+ run(t, testTransportAndServerSharedBodyRace)
+}
+func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
+ const bodySize = 1 << 20
+
+ // errorf is like t.Errorf, but also prints to standard error via
+ // println. When this test fails it tends to hang, so the extra output
+ // helps with debugging; it was added "temporarily" enough times that
+ // it is now here permanently.
+ errorf := func(format string, args ...any) {
+ v := fmt.Sprintf(format, args...)
+ println(v)
+ t.Error(v)
+ }
+
+ unblockBackend := make(chan bool)
+ backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gone := rw.(CloseNotifier).CloseNotify()
+ didCopy := make(chan any)
+ go func() {
+ n, err := io.CopyN(rw, req.Body, bodySize)
+ didCopy <- []any{n, err}
+ }()
+ isGone := false
+ Loop:
+ for {
+ select {
+ case <-didCopy:
+ break Loop
+ case <-gone:
+ isGone = true
+ case <-time.After(time.Second):
+ println("1 second passes in backend, proxygone=", isGone)
+ }
+ }
+ <-unblockBackend
+ }))
+ defer backend.close()
+
+ backendRespc := make(chan *Response, 1)
+ var proxy *clientServerTest
+ proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
+ req2.ContentLength = bodySize
+ cancel := make(chan struct{})
+ req2.Cancel = cancel
+
+ bresp, err := proxy.c.Do(req2)
+ if err != nil {
+ errorf("Proxy outbound request: %v", err)
+ return
+ }
+ _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
+ if err != nil {
+ errorf("Proxy copy error: %v", err)
+ return
+ }
+ backendRespc <- bresp // to close later
+
+ // Try to cause a race: Both the Transport and the proxy handler's Server
+ // will try to read/close req.Body (aka req2.Body)
+ if mode == http2Mode {
+ close(cancel)
+ } else {
+ proxy.c.Transport.(*Transport).CancelRequest(req2)
+ }
+ rw.Write([]byte("OK"))
+ }))
+ defer proxy.close()
+
+ defer close(unblockBackend)
+ req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
+ res, err := proxy.c.Do(req)
+ if err != nil {
+ t.Fatalf("Original request: %v", err)
+ }
+
+ // Cleanup, so we don't leak goroutines.
+ res.Body.Close()
+ select {
+ case res := <-backendRespc:
+ res.Body.Close()
+ default:
+ // We failed earlier. (e.g. on proxy.c.Do(req2))
+ }
+}
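+
+// Illustrative sketch, not part of the original tests: the pattern described
+// in the comment above TestTransportAndServerSharedBodyRace. A handler hands
+// its own r.Body to an outbound request, so the Server and the Transport both
+// own (and may both close) the same body. The names sharedBodyProxy and
+// backendURL are hypothetical.
+var backendURL = "http://backend.invalid/upload" // hypothetical target
+var sharedBodyProxy = HandlerFunc(func(w ResponseWriter, r *Request) {
+ out, err := NewRequest("POST", backendURL, r.Body) // r.Body is now shared
+ if err != nil {
+ Error(w, err.Error(), StatusInternalServerError)
+ return
+ }
+ res, err := DefaultClient.Do(out)
+ if err != nil {
+ Error(w, err.Error(), StatusBadGateway)
+ return
+ }
+ defer res.Body.Close()
+ io.Copy(w, res.Body)
+})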
+
+// Test that a hanging Request.Body.Read from another goroutine can't
+// cause the Handler goroutine's Request.Body.Close to block.
+// See issue 7121.
+func TestRequestBodyCloseDoesntBlock(t *testing.T) {
+ run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
+}
+func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+
+ readErrCh := make(chan error, 1)
+ errCh := make(chan error, 2)
+
+ server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go func(body io.Reader) {
+ _, err := body.Read(make([]byte, 100))
+ readErrCh <- err
+ }(req.Body)
+ time.Sleep(500 * time.Millisecond)
+ })).ts
+
+ closeConn := make(chan bool)
+ defer close(closeConn)
+ go func() {
+ conn, err := net.Dial("tcp", server.Listener.Addr().String())
+ if err != nil {
+ errCh <- err
+ return
+ }
+ defer conn.Close()
+ _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
+ if err != nil {
+ errCh <- err
+ return
+ }
+ // And now just block, making the server block on our
+ // 100000 bytes of body that will never arrive.
+ <-closeConn
+ }()
+ select {
+ case err := <-readErrCh:
+ if err == nil {
+ t.Error("Read was nil. Expected error.")
+ }
+ case err := <-errCh:
+ t.Error(err)
+ }
+}
+
+// test that ResponseWriter implements io.StringWriter.
+func TestResponseWriterWriteString(t *testing.T) {
+ okc := make(chan bool, 1)
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, ok := w.(io.StringWriter)
+ okc <- ok
+ }))
+ ht.rawResponse("GET / HTTP/1.0")
+ select {
+ case ok := <-okc:
+ if !ok {
+ t.Error("ResponseWriter did not implement io.StringWriter")
+ }
+ default:
+ t.Error("handler was never called")
+ }
+}
+
+func TestAppendTime(t *testing.T) {
+ var b [len(TimeFormat)]byte
+ t1 := time.Date(2013, 9, 21, 15, 41, 0, 0, time.FixedZone("CEST", 2*60*60))
+ res := ExportAppendTime(b[:0], t1)
+ t2, err := ParseTime(string(res))
+ if err != nil {
+ t.Fatalf("Error parsing time: %s", err)
+ }
+ if !t1.Equal(t2) {
+ t.Fatalf("Times differ; expected: %v, got %v (%s)", t1, t2, string(res))
+ }
+}
+
+func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
+func testServerConnState(t *testing.T, mode testMode) {
+ handler := map[string]func(w ResponseWriter, r *Request){
+ "/": func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/close": func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ fmt.Fprintf(w, "Hello.")
+ },
+ "/hijack": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ },
+ "/hijack-panic": func(w ResponseWriter, r *Request) {
+ c, _, _ := w.(Hijacker).Hijack()
+ c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
+ c.Close()
+ panic("intentional panic")
+ },
+ }
+
+ // A stateLog is a log of states over the lifetime of a connection.
+ type stateLog struct {
+ active net.Conn // The connection for which the log is recorded; set to the first connection seen in StateNew.
+ got []ConnState
+ want []ConnState
+ complete chan<- struct{} // If non-nil, closed when either 'got' is equal to 'want', or 'got' is no longer a prefix of 'want'.
+ }
+ activeLog := make(chan *stateLog, 1)
+
+ // wantLog invokes doRequests, then waits for the resulting connection to
+ // either pass through the sequence of states in want or enter a state outside
+ // of that sequence.
+ wantLog := func(doRequests func(), want ...ConnState) {
+ t.Helper()
+ complete := make(chan struct{})
+ activeLog <- &stateLog{want: want, complete: complete}
+
+ doRequests()
+
+ <-complete
+ sl := <-activeLog
+ if !reflect.DeepEqual(sl.got, sl.want) {
+ t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
+ }
+ // Don't return sl to activeLog: we don't expect any further states after
+ // this point, and want to keep the ConnState callback blocked until the
+ // next call to wantLog.
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ handler[r.URL.Path](w, r)
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(io.Discard, "", 0)
+ ts.Config.ConnState = func(c net.Conn, state ConnState) {
+ if c == nil {
+ t.Errorf("nil conn seen in state %s", state)
+ return
+ }
+ sl := <-activeLog
+ if sl.active == nil && state == StateNew {
+ sl.active = c
+ } else if sl.active != c {
+ t.Errorf("unexpected conn in state %s", state)
+ activeLog <- sl
+ return
+ }
+ sl.got = append(sl.got, state)
+ if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) {
+ close(sl.complete)
+ sl.complete = nil
+ }
+ activeLog <- sl
+ }
+ }).ts
+ defer func() {
+ activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete.
+ ts.Close()
+ }()
+
+ c := ts.Client()
+
+ mustGet := func(url string, headers ...string) {
+ t.Helper()
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for len(headers) > 0 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+ res, err := c.Do(req)
+ if err != nil {
+ t.Errorf("Error fetching %s: %v", url, err)
+ return
+ }
+ _, err = io.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Errorf("Error reading %s: %v", url, err)
+ }
+ }
+
+ wantLog(func() {
+ mustGet(ts.URL + "/")
+ mustGet(ts.URL + "/close")
+ }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
+
+ wantLog(func() {
+ mustGet(ts.URL + "/")
+ mustGet(ts.URL+"/", "Connection", "close")
+ }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
+
+ wantLog(func() {
+ mustGet(ts.URL + "/hijack")
+ }, StateNew, StateActive, StateHijacked)
+
+ wantLog(func() {
+ mustGet(ts.URL + "/hijack-panic")
+ }, StateNew, StateActive, StateHijacked)
+
+ wantLog(func() {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }, StateNew, StateClosed)
+
+ wantLog(func() {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ c.Read(make([]byte, 1)) // block until server hangs up on us
+ c.Close()
+ }, StateNew, StateActive, StateClosed)
+
+ wantLog(func() {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ res, err := ReadResponse(bufio.NewReader(c), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ }, StateNew, StateActive, StateIdle, StateClosed)
+}
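+
+// Illustrative sketch, not part of the original tests: the ConnState hook
+// exercised above is commonly used to track open connections. StateHijacked
+// and StateClosed are the terminal states, so they decrement the counter.
+// The names openConns and trackingServer are hypothetical.
+var openConns int64
+var trackingServer = &Server{
+ Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
+ ConnState: func(c net.Conn, state ConnState) {
+ switch state {
+ case StateNew:
+ atomic.AddInt64(&openConns, 1)
+ case StateHijacked, StateClosed:
+ atomic.AddInt64(&openConns, -1)
+ }
+ },
+}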
+
+func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
+ run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
+}
+func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ }), func(ts *httptest.Server) {
+ ts.Config.SetKeepAlivesEnabled(false)
+ }).ts
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if !res.Close {
+ t.Errorf("Body.Close == false; want true")
+ }
+}
+
+// golang.org/issue/7856
+func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
+func testServerEmptyBodyRace(t *testing.T, mode testMode) {
+ var n int32
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ atomic.AddInt32(&n, 1)
+ }), optQuietLog)
+ var wg sync.WaitGroup
+ const reqs = 20
+ for i := 0; i < reqs; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ // Try to deflake spurious "connection reset by peer" under load.
+ // See golang.org/issue/22540.
+ time.Sleep(10 * time.Millisecond)
+ res, err = cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ defer res.Body.Close()
+ _, err = io.Copy(io.Discard, res.Body)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }()
+ }
+ wg.Wait()
+ if got := atomic.LoadInt32(&n); got != reqs {
+ t.Errorf("handler ran %d times; want %d", got, reqs)
+ }
+}
+
+func TestServerConnStateNew(t *testing.T) {
+ sawNew := false // if the test is buggy, we'll race on this variable.
+ srv := &Server{
+ ConnState: func(c net.Conn, state ConnState) {
+ if state == StateNew {
+ sawNew = true // testing that this write isn't racy
+ }
+ },
+ Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}), // irrelevant
+ }
+ srv.Serve(&oneConnListener{
+ conn: &rwTestConn{
+ Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
+ Writer: io.Discard,
+ },
+ })
+ if !sawNew { // testing that this read isn't racy
+ t.Error("StateNew not seen")
+ }
+}
+
+type closeWriteTestConn struct {
+ rwTestConn
+ didCloseWrite bool
+}
+
+func (c *closeWriteTestConn) CloseWrite() error {
+ c.didCloseWrite = true
+ return nil
+}
+
+func TestCloseWrite(t *testing.T) {
+ setParallel(t)
+ var srv Server
+ var testConn closeWriteTestConn
+ c := ExportServerNewConn(&srv, &testConn)
+ ExportCloseWriteAndWait(c)
+ if !testConn.didCloseWrite {
+ t.Error("didn't see CloseWrite call")
+ }
+}
+
+// This verifies that a handler can Flush and then Hijack.
+//
+// A similar test crashed once during development, but it was only
+// testing this tangentially and temporarily until another TODO was
+// fixed.
+//
+// So add an explicit test for this.
+func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
+func testServerFlushAndHijack(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "Hello, ")
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
+ if err := buf.Flush(); err != nil {
+ t.Error(err)
+ }
+ if err := conn.Close(); err != nil {
+ t.Error(err)
+ }
+ })).ts
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ all, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := "Hello, world!"; string(all) != want {
+ t.Errorf("Got %q; want %q", all, want)
+ }
+}
+
+// golang.org/issue/8534 -- the Server shouldn't reuse a connection
+// for keep-alive after it's seen any Write error (e.g. a timeout) on
+// that net.Conn.
+//
+// To test, verify that we don't time out and that we don't see fewer unique
+// client addresses (== unique connections) than requests.
+func TestServerKeepAliveAfterWriteError(t *testing.T) {
+ run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
+}
+func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in -short mode")
+ }
+ const numReq = 3
+ addrc := make(chan string, numReq)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrc <- r.RemoteAddr
+ time.Sleep(500 * time.Millisecond)
+ w.(Flusher).Flush()
+ }), func(ts *httptest.Server) {
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ }).ts
+
+ errc := make(chan error, numReq)
+ go func() {
+ defer close(errc)
+ for i := 0; i < numReq; i++ {
+ res, err := Get(ts.URL)
+ if res != nil {
+ res.Body.Close()
+ }
+ errc <- err
+ }
+ }()
+
+ addrSeen := map[string]bool{}
+ numOkay := 0
+ for {
+ select {
+ case v := <-addrc:
+ addrSeen[v] = true
+ case err, ok := <-errc:
+ if !ok {
+ if len(addrSeen) != numReq {
+ t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
+ }
+ if numOkay != 0 {
+ t.Errorf("got %d successful client requests; want 0", numOkay)
+ }
+ return
+ }
+ if err == nil {
+ numOkay++
+ }
+ }
+ }
+}
+
+// Issue 9987: shouldn't add automatic Content-Length (or
+// Content-Type) if a Transfer-Encoding was set by the handler.
+func TestNoContentLengthIfTransferEncoding(t *testing.T) {
+ run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
+}
+func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Transfer-Encoding", "foo")
+ io.WriteString(w, "<html>")
+ })).ts
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+ if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
+ t.Fatal(err)
+ }
+ bs := bufio.NewScanner(c)
+ var got strings.Builder
+ for bs.Scan() {
+ if strings.TrimSpace(bs.Text()) == "" {
+ break
+ }
+ got.WriteString(bs.Text())
+ got.WriteByte('\n')
+ }
+ if err := bs.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if strings.Contains(got.String(), "Content-Length") {
+ t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
+ }
+ if strings.Contains(got.String(), "Content-Type") {
+ t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
+ }
+}
+
+// Tolerate extra CRLF(s) before the Request-Line on subsequent requests on a conn.
+// Issue 10876.
+func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
+ req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
+ "\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it
+ "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ numReq := 0
+ go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ numReq++
+ }))
+ <-conn.closec
+ if numReq != 2 {
+ t.Errorf("num requests = %d; want 2", numReq)
+ t.Logf("Res: %s", buf.Bytes())
+ }
+}
+
+func TestIssue13893_Expect100(t *testing.T) {
+ // test that the Server doesn't filter out Expect headers.
+ req := reqBytes(`PUT /readbody HTTP/1.1
+User-Agent: PycURL/7.22.0
+Host: 127.0.0.1:9000
+Accept: */*
+Expect: 100-continue
+Content-Length: 10
+
+HelloWorld
+
+`)
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if _, ok := r.Header["Expect"]; !ok {
+ t.Error("Expect header should not be filtered out")
+ }
+ }))
+ <-conn.closec
+}
+
+func TestIssue11549_Expect100(t *testing.T) {
+ req := reqBytes(`PUT /readbody HTTP/1.1
+User-Agent: PycURL/7.22.0
+Host: 127.0.0.1:9000
+Accept: */*
+Expect: 100-continue
+Content-Length: 10
+
+HelloWorldPUT /noreadbody HTTP/1.1
+User-Agent: PycURL/7.22.0
+Host: 127.0.0.1:9000
+Accept: */*
+Expect: 100-continue
+Content-Length: 10
+
+GET /should-be-ignored HTTP/1.1
+Host: foo
+
+`)
+ var buf strings.Builder
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ numReq := 0
+ go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
+ numReq++
+ if r.URL.Path == "/readbody" {
+ io.ReadAll(r.Body)
+ }
+ io.WriteString(w, "Hello world!")
+ }))
+ <-conn.closec
+ if numReq != 2 {
+ t.Errorf("num requests = %d; want 2", numReq)
+ }
+ if !strings.Contains(buf.String(), "Connection: close\r\n") {
+ t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
+ }
+}
+
+// If a Handler finishes and there's an unread request body, verify that
+// the server skips the implicit read of it before replying when the
+// declared Content-Length is too large to be worth draining.
+func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
+ setParallel(t)
+ conn := &testConn{closec: make(chan bool)}
+ conn.readBuf.Write([]byte(fmt.Sprintf(
+ "POST / HTTP/1.1\r\n" +
+ "Host: test\r\n" +
+ "Content-Length: 9999999999\r\n" +
+ "\r\n" + strings.Repeat("a", 1<<20))))
+
+ ls := &oneConnListener{conn}
+ var inHandlerLen int
+ go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ inHandlerLen = conn.readBuf.Len()
+ rw.WriteHeader(404)
+ }))
+ <-conn.closec
+ afterHandlerLen := conn.readBuf.Len()
+
+ if afterHandlerLen != inHandlerLen {
+ t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
+ }
+}
+
+func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
+func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.Body = nil
+ fmt.Fprintf(w, "%v", r.RemoteAddr)
+ }))
+ get := func() string {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+ }
+ a, b := get(), get()
+ if a != b {
+ t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
+ }
+}
+
+// Test that we validate the Host header.
+// Issue 11206 (invalid bytes in Host) and 13624 (Host present in HTTP/1.1)
+func TestServerValidatesHostHeader(t *testing.T) {
+ tests := []struct {
+ proto string
+ host string
+ want int
+ }{
+ {"HTTP/0.9", "", 505},
+
+ {"HTTP/1.1", "", 400},
+ {"HTTP/1.1", "Host: \r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
+ {"HTTP/1.1", "Host: ::1\r\n", 200},
+ {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
+ {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
+ {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: \x06\r\n", 400},
+ {"HTTP/1.1", "Host: \xff\r\n", 400},
+ {"HTTP/1.1", "Host: {\r\n", 400},
+ {"HTTP/1.1", "Host: }\r\n", 400},
+ {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
+
+ // HTTP/1.0 can lack a host header, but if present
+ // must play by the rules too:
+ {"HTTP/1.0", "", 200},
+ {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
+ {"HTTP/1.0", "Host: \xff\r\n", 400},
+
+ // Make an exception for HTTP upgrade requests:
+ {"PRI * HTTP/2.0", "", 200},
+
+ // Also an exception for CONNECT requests: (Issue 18215)
+ {"CONNECT golang.org:443 HTTP/1.1", "", 200},
+
+ // But not other HTTP/2 stuff:
+ {"PRI / HTTP/2.0", "", 505},
+ {"GET / HTTP/2.0", "", 505},
+ {"GET / HTTP/3.0", "", 505},
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool, 1)}
+ methodTarget := "GET / "
+ if !strings.HasPrefix(tt.proto, "HTTP/") {
+ methodTarget = ""
+ }
+ io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
+
+ ln := &oneConnListener{conn}
+ srv := Server{
+ ErrorLog: quietLog,
+ Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
+ }
+ go srv.Serve(ln)
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestServerHandlersCanHandleH2PRI(t *testing.T) {
+ run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
+}
+func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
+ const upgradeResponse = "upgrade here"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, br, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ if r.Method != "PRI" || r.RequestURI != "*" {
+ t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
+ return
+ }
+ if !r.Close {
+ t.Errorf("Request.Close = true; want false")
+ }
+ const want = "SM\r\n\r\n"
+ buf := make([]byte, len(want))
+ n, err := io.ReadFull(br, buf)
+ if err != nil || string(buf[:n]) != want {
+ t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
+ return
+ }
+ io.WriteString(conn, upgradeResponse)
+ })).ts
+
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer c.Close()
+ io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
+ slurp, err := io.ReadAll(c)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != upgradeResponse {
+ t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
+ }
+}
+
+// Test that we validate which bytes are allowed in HTTP/1 headers.
+// Issue 11207.
+func TestServerValidatesHeaders(t *testing.T) {
+ setParallel(t)
+ tests := []struct {
+ header string
+ want int
+ }{
+ {"", 200},
+ {"Foo: bar\r\n", 200},
+ {"X-Foo: bar\r\n", 200},
+ {"Foo: a space\r\n", 200},
+
+ {"A space: foo\r\n", 400}, // space in header
+ {"foo\xffbar: foo\r\n", 400}, // binary in header
+ {"foo\x00bar: foo\r\n", 400}, // binary in header
+ {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431}, // header too large
+ // Spaces between the header key and colon are not allowed.
+ // See RFC 7230, Section 3.2.4.
+ {"Foo : bar\r\n", 400},
+ {"Foo\t: bar\r\n", 400},
+
+ {"foo: foo foo\r\n", 200}, // LWS space is okay
+ {"foo: foo\tfoo\r\n", 200}, // LWS tab is okay
+ {"foo: foo\x00foo\r\n", 400}, // CTL 0x00 in value is bad
+ {"foo: foo\x7ffoo\r\n", 400}, // CTL 0x7f in value is bad
+ {"foo: foo\xfffoo\r\n", 200}, // non-ASCII high octets in value are fine
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool, 1)}
+ io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
+
+ ln := &oneConnListener{conn}
+ srv := Server{
+ ErrorLog: quietLog,
+ Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
+ }
+ go srv.Serve(ln)
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %q, ReadResponse: %v", tt.header, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
+ run(t, testServerRequestContextCancel_ServeHTTPDone)
+}
+func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
+ ctxc := make(chan context.Context, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctx := r.Context()
+ select {
+ case <-ctx.Done():
+ t.Error("should not be Done in ServeHTTP")
+ default:
+ }
+ ctxc <- ctx
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ ctx := <-ctxc
+ select {
+ case <-ctx.Done():
+ default:
+ t.Error("context should be done after ServeHTTP completes")
+ }
+}
+
+// Tests that the Request.Context available to the Handler is canceled
+// if the peer closes their TCP connection. This requires that the server
+// is always blocked in a Read call so it notices the EOF from the client.
+// See issues 15927 and 15224.
+func TestServerRequestContextCancel_ConnClose(t *testing.T) {
+ run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
+}
+func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
+ inHandler := make(chan struct{})
+ handlerDone := make(chan struct{})
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(inHandler)
+ <-r.Context().Done()
+ close(handlerDone)
+ })).ts
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
+ <-inHandler
+ c.Close() // this should trigger the context being done
+ <-handlerDone
+}
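+
+// Illustrative sketch, not part of the original tests: a handler that relies
+// on the same Request.Context cancellation exercised above to stop work when
+// the peer goes away. The 5-second delay stands in for real work, and the
+// name cancelAwareHandler is hypothetical.
+var cancelAwareHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ select {
+ case <-time.After(5 * time.Second): // placeholder for real work
+ io.WriteString(w, "done")
+ case <-r.Context().Done():
+ // The connection closed or the request finished; give up early.
+ }
+})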
+
+func TestServerContext_ServerContextKey(t *testing.T) {
+ run(t, testServerContext_ServerContextKey)
+}
+func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ctx := r.Context()
+ got := ctx.Value(ServerContextKey)
+ if _, ok := got.(*Server); !ok {
+ t.Errorf("context value = %T; want *http.Server", got)
+ }
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+func TestServerContext_LocalAddrContextKey(t *testing.T) {
+ run(t, testServerContext_LocalAddrContextKey)
+}
+func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
+ ch := make(chan any, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ch <- r.Context().Value(LocalAddrContextKey)
+ }))
+ if _, err := cst.c.Head(cst.ts.URL); err != nil {
+ t.Fatal(err)
+ }
+
+ host := cst.ts.Listener.Addr().String()
+ got := <-ch
+ if addr, ok := got.(net.Addr); !ok {
+ t.Errorf("local addr value = %T; want net.Addr", got)
+ } else if fmt.Sprint(addr) != host {
+ t.Errorf("local addr = %v; want %v", addr, host)
+ }
+}
+
+// https://golang.org/issue/15960
+func TestHandlerSetTransferEncodingChunked(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Transfer-Encoding", "chunked")
+ w.Write([]byte("hello"))
+ }))
+ resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
+ const hdr = "Transfer-Encoding: chunked"
+ if n := strings.Count(resp, hdr); n != 1 {
+ t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
+ }
+}
+
+// https://golang.org/issue/16063
+func TestHandlerSetTransferEncodingGzip(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Transfer-Encoding", "gzip")
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte("hello"))
+ gz.Close()
+ }))
+ resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
+ for _, v := range []string{"gzip", "chunked"} {
+ hdr := "Transfer-Encoding: " + v
+ if n := strings.Count(resp, hdr); n != 1 {
+ t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
+ }
+ }
+}
+
+func BenchmarkClientServer(b *testing.B) {
+ run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func benchmarkClientServer(b *testing.B, mode testMode) {
+ b.ReportAllocs()
+ b.StopTimer()
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ })).ts
+ b.StartTimer()
+
+ c := ts.Client()
+ for i := 0; i < b.N; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ b.Fatal("Get:", err)
+ }
+ all, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Fatal("ReadAll:", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ b.Fatal("Got body:", body)
+ }
+ }
+
+ b.StopTimer()
+}
+
+func BenchmarkClientServerParallel(b *testing.B) {
+ for _, parallelism := range []int{4, 64} {
+ b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
+ run(b, func(b *testing.B, mode testMode) {
+ benchmarkClientServerParallel(b, parallelism, mode)
+ }, []testMode{http1Mode, https1Mode, http2Mode})
+ })
+ }
+}
+
+func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
+ b.ReportAllocs()
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ fmt.Fprintf(rw, "Hello world.\n")
+ })).ts
+ b.ResetTimer()
+ b.SetParallelism(parallelism)
+ b.RunParallel(func(pb *testing.PB) {
+ c := ts.Client()
+ for pb.Next() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ b.Logf("Get: %v", err)
+ continue
+ }
+ all, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Logf("ReadAll: %v", err)
+ continue
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ panic("Got body: " + body)
+ }
+ }
+ })
+}
+
+// A benchmark for profiling the server without the HTTP client code.
+// The client code runs in a subprocess.
+//
+// For use like:
+//
+// $ go test -c
+// $ ./http.test -test.run=XX -test.bench=BenchmarkServer -test.benchtime=15s -test.cpuprofile=http.prof
+// $ go tool pprof http.test http.prof
+// (pprof) web
+func BenchmarkServer(b *testing.B) {
+ b.ReportAllocs()
+ // Child process mode.
+ if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
+ n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
+ if err != nil {
+ panic(err)
+ }
+ for i := 0; i < n; i++ {
+ res, err := Get(url)
+ if err != nil {
+ log.Panicf("Get: %v", err)
+ }
+ all, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Panicf("ReadAll: %v", err)
+ }
+ body := string(all)
+ if body != "Hello world.\n" {
+ log.Panicf("Got body: %q", body)
+ }
+ }
+ os.Exit(0)
+ return
+ }
+
+ var res = []byte("Hello world.\n")
+ b.StopTimer()
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ }))
+ defer ts.Close()
+ b.StartTimer()
+
+ cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkServer$")
+ cmd.Env = append([]string{
+ fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
+ fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
+ }, os.Environ()...)
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ b.Errorf("Test failure: %v, with output: %s", err, out)
+ }
+}
+
+// getNoBody wraps Get but closes any Response.Body before returning the response.
+func getNoBody(urlStr string) (*Response, error) {
+ res, err := Get(urlStr)
+ if err != nil {
+ return nil, err
+ }
+ res.Body.Close()
+ return res, nil
+}
+
+// A benchmark for profiling the client without the HTTP server code.
+// The server code runs in a subprocess.
+func BenchmarkClient(b *testing.B) {
+ b.ReportAllocs()
+ b.StopTimer()
+ defer afterTest(b)
+
+ var data = []byte("Hello world.\n")
+ if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
+ // Server process mode.
+ port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user
+ if port == "" {
+ port = "0"
+ }
+ ln, err := net.Listen("tcp", "localhost:"+port)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, err.Error())
+ os.Exit(1)
+ }
+ fmt.Println(ln.Addr().String())
+ HandleFunc("/", func(w ResponseWriter, r *Request) {
+ r.ParseForm()
+ if r.Form.Get("stop") != "" {
+ os.Exit(0)
+ }
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.Write(data)
+ })
+ var srv Server
+ log.Fatal(srv.Serve(ln))
+ }
+
+ // Start server process.
+ ctx, cancel := context.WithCancel(context.Background())
+ cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
+ cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
+ cmd.Stderr = os.Stderr
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if err := cmd.Start(); err != nil {
+ b.Fatalf("subprocess failed to start: %v", err)
+ }
+
+ done := make(chan error, 1)
+ go func() {
+ done <- cmd.Wait()
+ close(done)
+ }()
+ defer func() {
+ cancel()
+ <-done
+ }()
+
+ // Wait for the server in the child process to respond and tell us
+ // its listening address, once it's started listening:
+ bs := bufio.NewScanner(stdout)
+ if !bs.Scan() {
+ b.Fatalf("failed to read listening URL from child: %v", bs.Err())
+ }
+ url := "http://" + strings.TrimSpace(bs.Text()) + "/"
+ if _, err := getNoBody(url); err != nil {
+ b.Fatalf("initial probe of child process failed: %v", err)
+ }
+
+ // Do b.N requests to the server.
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ res, err := Get(url)
+ if err != nil {
+ b.Fatalf("Get: %v", err)
+ }
+ body, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ b.Fatalf("ReadAll: %v", err)
+ }
+ if !bytes.Equal(body, data) {
+ b.Fatalf("Got body: %q", body)
+ }
+ }
+ b.StopTimer()
+
+ // Instruct server process to stop.
+ getNoBody(url + "?stop=yes")
+ if err := <-done; err != nil {
+ b.Fatalf("subprocess failed: %v", err)
+ }
+}
+
+func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.0
+Host: golang.org
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &testConn{
+ // testConn.Close will not push into the channel
+ // if it's full.
+ closec: make(chan bool, 1),
+ }
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ })
+ ln := new(oneConnListener)
+ for i := 0; i < b.N; i++ {
+ conn.readBuf.Reset()
+ conn.writeBuf.Reset()
+ conn.readBuf.Write(req)
+ ln.conn = conn
+ Serve(ln, handler)
+ <-conn.closec
+ }
+}
+
+// repeatReader reads content count times, then EOFs.
+type repeatReader struct {
+ content []byte
+ count int
+ off int
+}
+
+func (r *repeatReader) Read(p []byte) (n int, err error) {
+ if r.count <= 0 {
+ return 0, io.EOF
+ }
+ n = copy(p, r.content[r.off:])
+ r.off += n
+ if r.off == len(r.content) {
+ r.count--
+ r.off = 0
+ }
+ return
+}
+
+func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
+ b.ReportAllocs()
+
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
+User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: io.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ rw.Header().Set("Content-Type", "text/html; charset=utf-8")
+ rw.Write(res)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+// Same as above, but representing the simplest possible request
+// and handler. Notably, the handler does not call rw.Header().
+func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
+ b.ReportAllocs()
+
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ res := []byte("Hello world!\n")
+
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: io.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ rw.Write(res)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+const someResponse = "<html>some response</html>"
+
+// A response body that's just under 2KB, the buffer-before-chunking threshold.
+var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
+
+// Both Content-Type and Content-Length set. Should be no buffering.
+func BenchmarkServerHandlerTypeLen(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Header().Set("Content-Length", strconv.Itoa(len(response)))
+ w.Write(response)
+ }))
+}
+
+// A Content-Type is set, but no length. No sniffing, but will count the Content-Length.
+func BenchmarkServerHandlerNoLen(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Write(response)
+ }))
+}
+
+// A Content-Length is set, but the Content-Type will be sniffed.
+func BenchmarkServerHandlerNoType(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", strconv.Itoa(len(response)))
+ w.Write(response)
+ }))
+}
+
+// Neither a Content-Type nor a Content-Length is set, so the body is sniffed and counted.
+func BenchmarkServerHandlerNoHeader(b *testing.B) {
+ benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write(response)
+ }))
+}
+
+func benchmarkHandler(b *testing.B, h Handler) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ conn := &rwTestConn{
+ Reader: &repeatReader{content: req, count: b.N},
+ Writer: io.Discard,
+ closec: make(chan bool, 1),
+ }
+ handled := 0
+ handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
+ handled++
+ h.ServeHTTP(rw, r)
+ })
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, handler)
+ <-conn.closec
+ if b.N != handled {
+ b.Errorf("b.N=%d but handled %d", b.N, handled)
+ }
+}
+
+func BenchmarkServerHijack(b *testing.B) {
+ b.ReportAllocs()
+ req := reqBytes(`GET / HTTP/1.1
+Host: golang.org
+`)
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ panic(err)
+ }
+ conn.Close()
+ })
+ conn := &rwTestConn{
+ Writer: io.Discard,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ for i := 0; i < b.N; i++ {
+ conn.Reader = bytes.NewReader(req)
+ ln.conn = conn
+ Serve(ln, h)
+ <-conn.closec
+ }
+}
+
+func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
+func benchmarkCloseNotifier(b *testing.B, mode testMode) {
+ b.ReportAllocs()
+ b.StopTimer()
+ sawClose := make(chan bool)
+ ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ <-rw.(CloseNotifier).CloseNotify()
+ sawClose <- true
+ })).ts
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ b.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ b.Fatal(err)
+ }
+ conn.Close()
+ <-sawClose
+ }
+ b.StopTimer()
+}
+
+// Verify this doesn't race (Issue 16505)
+func TestConcurrentServerServe(t *testing.T) {
+ setParallel(t)
+ for i := 0; i < 100; i++ {
+ ln1 := &oneConnListener{conn: nil}
+ ln2 := &oneConnListener{conn: nil}
+ srv := Server{}
+ go func() { srv.Serve(ln1) }()
+ go func() { srv.Serve(ln2) }()
+ }
+}
+
+func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
+func testServerIdleTimeout(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ io.WriteString(w, r.RemoteAddr)
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = 1 * time.Second
+ ts.Config.IdleTimeout = 2 * time.Second
+ }).ts
+ c := ts.Client()
+
+ get := func() string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+ }
+
+ a1, a2 := get(), get()
+ if a1 != a2 {
+ t.Fatalf("did requests on different connections")
+ }
+ time.Sleep(3 * time.Second)
+ a3 := get()
+ if a2 == a3 {
+ t.Fatal("request three unexpectedly on same connection")
+ }
+
+ // And test that ReadHeaderTimeout still works:
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
+ time.Sleep(2 * time.Second)
+ if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
+ t.Fatal("copy byte succeeded; want err")
+ }
+}
+
+func get(t *testing.T, c *Client, url string) string {
+ res, err := c.Get(url)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+}
+
+// Tests that calls to Server.SetKeepAlivesEnabled(false) close any
+// currently-open connections.
+func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
+ run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
+}
+func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, r.RemoteAddr)
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ get := func() string { return get(t, c, ts.URL) }
+
+ a1, a2 := get(), get()
+ if a1 == a2 {
+ t.Logf("made two requests from a single conn %q (as expected)", a1)
+ } else {
+ t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
+ }
+
+ // The two requests should have used the same connection,
+ // and there should not have been a second connection that
+ // was created by racing dial against reuse.
+ // (The first get was completed when the second get started.)
+ if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
+ t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
+ }
+
+ // SetKeepAlivesEnabled should discard idle conns.
+ ts.Config.SetKeepAlivesEnabled(false)
+
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
+ if d > 0 {
+ t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
+ }
+ return false
+ }
+ return true
+ })
+
+ // If we make a third request it should use a new connection, but in general
+ // we have no way to verify that: the new connection could happen to reuse the
+ // exact same ports from the previous connection.
+}
+
+func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
+func testServerShutdown(t *testing.T, mode testMode) {
+ var cst *clientServerTest
+
+ var once sync.Once
+ statesRes := make(chan map[ConnState]int, 1)
+ shutdownRes := make(chan error, 1)
+ gotOnShutdown := make(chan struct{})
+ handler := HandlerFunc(func(w ResponseWriter, r *Request) {
+ first := false
+ once.Do(func() {
+ statesRes <- cst.ts.Config.ExportAllConnsByState()
+ go func() {
+ shutdownRes <- cst.ts.Config.Shutdown(context.Background())
+ }()
+ first = true
+ })
+
+ if first {
+ // Shutdown is graceful, so it should not interrupt this in-flight response
+ // but should reject new requests. (Since this request is still in flight,
+ // the server's port should not be reused for another server yet.)
+ <-gotOnShutdown
+ // TODO(#59038): The HTTP/2 server empirically does not always reject new
+ // requests. As a workaround, loop until we see a failure.
+ for !t.Failed() {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ break
+ }
+ out, _ := io.ReadAll(res.Body)
+ res.Body.Close()
+ if mode == http2Mode {
+ t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
+ t.Logf("Retrying to work around https://go.dev/issue/59038.")
+ continue
+ }
+ t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
+ }
+ }
+
+ io.WriteString(w, r.RemoteAddr)
+ })
+
+ cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
+ srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
+ })
+
+ out := get(t, cst.c, cst.ts.URL) // calls t.Fail on failure
+ t.Logf("%v: %q", cst.ts.URL, out)
+
+ if err := <-shutdownRes; err != nil {
+ t.Fatalf("Shutdown: %v", err)
+ }
+ <-gotOnShutdown // Will hang if RegisterOnShutdown is broken.
+
+ if states := <-statesRes; states[StateActive] != 1 {
+ t.Errorf("connection in wrong state, %v", states)
+ }
+}
+
+func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
+func testServerShutdownStateNew(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("test takes 5-6 seconds; skipping in short mode")
+ }
+
+ var connAccepted sync.WaitGroup
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // nothing.
+ }), func(ts *httptest.Server) {
+ ts.Config.ConnState = func(conn net.Conn, state ConnState) {
+ if state == StateNew {
+ connAccepted.Done()
+ }
+ }
+ }).ts
+
+ // Start a connection but never write to it.
+ connAccepted.Add(1)
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ // Wait for the connection to be accepted by the server. Otherwise, if
+ // Shutdown happens to run first, the server will be closed when
+ // encountering the connection, in which case it will be rejected
+ // immediately.
+ connAccepted.Wait()
+
+ shutdownRes := make(chan error, 1)
+ go func() {
+ shutdownRes <- ts.Config.Shutdown(context.Background())
+ }()
+ readRes := make(chan error, 1)
+ go func() {
+ _, err := c.Read([]byte{0})
+ readRes <- err
+ }()
+
+ // TODO(#59037): This timeout is hard-coded in closeIdleConnections.
+ // It is undocumented, and some users may find it surprising.
+ // Either document it, or switch to a less surprising behavior.
+ const expectTimeout = 5 * time.Second
+
+ t0 := time.Now()
+ select {
+ case got := <-shutdownRes:
+ d := time.Since(t0)
+ if got != nil {
+ t.Fatalf("shutdown error after %v: %v", d, err)
+ }
+ if d < expectTimeout/2 {
+ t.Errorf("shutdown too soon after %v", d)
+ }
+ case <-time.After(expectTimeout * 3 / 2):
+ t.Fatalf("timeout waiting for shutdown")
+ }
+
+ // Wait for c.Read to unblock; should be already done at this point,
+ // or within a few milliseconds.
+ if err := <-readRes; err == nil {
+ t.Error("expected error from Read")
+ }
+}
+
+// Issue 17878: tests that we can call Close twice.
+func TestServerCloseDeadlock(t *testing.T) {
+ var s Server
+ s.Close()
+ s.Close()
+}
+
+// Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by
+// both HTTP/1 and HTTP/2.
+func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
+func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
+ defer restore()
+ }
+ // Not parallel: messes with a global variable (http2goAwayTimeout).
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer cst.close()
+ srv := cst.ts.Config
+ srv.SetKeepAlivesEnabled(false)
+ for try := 0; try < 2; try++ {
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ if !srv.ExportAllConnsIdle() {
+ if d > 0 {
+ t.Logf("test server still has active conns after %v", d)
+ }
+ return false
+ }
+ return true
+ })
+ conns := 0
+ var info httptrace.GotConnInfo
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(v httptrace.GotConnInfo) {
+ conns++
+ info = v
+ },
+ })
+ req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if conns != 1 {
+ t.Fatalf("request %v: got %v conns, want 1", try, conns)
+ }
+ if info.Reused || info.WasIdle {
+ t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
+ }
+ }
+}
+
+// Issue 18447: test that the Server's ReadTimeout is stopped while
+// the server's doing its 1-byte background read between requests,
+// waiting for the connection to maybe close.
+func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
+func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
+ runTimeSensitiveTest(t, []time.Duration{
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 250 * time.Millisecond,
+ time.Second,
+ 2 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ select {
+ case <-time.After(2 * timeout):
+ fmt.Fprint(w, "ok")
+ case <-r.Context().Done():
+ fmt.Fprint(w, r.Context().Err())
+ }
+ }), func(ts *httptest.Server) {
+ ts.Config.ReadTimeout = timeout
+ }).ts
+
+ c := ts.Client()
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ return fmt.Errorf("Get: %v", err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ return fmt.Errorf("Body ReadAll: %v", err)
+ }
+ if string(slurp) != "ok" {
+ return fmt.Errorf("got: %q, want ok", slurp)
+ }
+ return nil
+ })
+}
+
+// Issue 54784: test that the Server's ReadHeaderTimeout only starts once the
+// beginning of a request has been received, rather than including time the
+// connection spent idle.
+func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
+ run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
+}
+func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
+ runTimeSensitiveTest(t, []time.Duration{
+ 10 * time.Millisecond,
+ 50 * time.Millisecond,
+ 250 * time.Millisecond,
+ time.Second,
+ 2 * time.Second,
+ }, func(t *testing.T, timeout time.Duration) error {
+ ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
+ ts.Config.ReadHeaderTimeout = timeout
+ ts.Config.IdleTimeout = 0 // disable idle timeout
+ }).ts
+
+ // rather than using an http.Client, create a single connection, so that
+ // we can ensure this connection is not closed.
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("dial failed: %v", err)
+ }
+ br := bufio.NewReader(conn)
+ defer conn.Close()
+
+ if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
+ return fmt.Errorf("writing first request failed: %v", err)
+ }
+
+ if _, err := ReadResponse(br, nil); err != nil {
+ return fmt.Errorf("first response (before timeout) failed: %v", err)
+ }
+
+ // wait for longer than the server's ReadHeaderTimeout, and then send
+ // another request
+ time.Sleep(timeout * 3 / 2)
+
+ if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
+ return fmt.Errorf("writing second request failed: %v", err)
+ }
+
+ if _, err := ReadResponse(br, nil); err != nil {
+ return fmt.Errorf("second response (after timeout) failed: %v", err)
+ }
+
+ return nil
+ })
+}
+
+// runTimeSensitiveTest runs test with the provided durations until one passes.
+// If they all fail, t.Fatal is called with the last one's duration and error value.
+func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
+ for i, d := range durations {
+ err := test(t, d)
+ if err == nil {
+ return
+ }
+ if i == len(durations)-1 {
+ t.Fatalf("failed with duration %v: %v", d, err)
+ }
+ }
+}
+
+// Issue 18535: test that the Server doesn't try to do a background
+// read if it's already done one.
+func TestServerDuplicateBackgroundRead(t *testing.T) {
+ run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
+}
+func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
+ if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
+ testenv.SkipFlaky(t, 24826)
+ }
+
+ goroutines := 5
+ requests := 2000
+ if testing.Short() {
+ goroutines = 3
+ requests = 100
+ }
+
+ hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
+
+ reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
+
+ var wg sync.WaitGroup
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ cn, err := net.Dial("tcp", hts.Listener.Addr().String())
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer cn.Close()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ io.Copy(io.Discard, cn)
+ }()
+
+ for j := 0; j < requests; j++ {
+ if t.Failed() {
+ return
+ }
+ _, err := cn.Write(reqBytes)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// Test that the bufio.Reader returned by Hijack includes any buffered
+// byte (from the Server's backgroundRead) in its buffer. We want the
+// Handler code to be able to tell that a byte is available via
+// bufio.Reader.Buffered(), without resorting to Reading it
+// (potentially blocking) to get at it.
+func TestServerHijackGetsBackgroundByte(t *testing.T) {
+ run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
+}
+func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see https://golang.org/issue/18657")
+ }
+ done := make(chan struct{})
+ inHandler := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(done)
+
+ // Tell the client to send more data after the GET request.
+ inHandler <- true
+
+ conn, buf, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+
+ peek, err := buf.Reader.Peek(3)
+ if string(peek) != "foo" || err != nil {
+ t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
+ }
+
+ select {
+ case <-r.Context().Done():
+ t.Error("context unexpectedly canceled")
+ default:
+ }
+ })).ts
+
+ cn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cn.Close()
+ if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
+ t.Fatal(err)
+ }
+ <-inHandler
+ if _, err := cn.Write([]byte("foo")); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
+ t.Fatal(err)
+ }
+ <-done
+}
+
+// Like TestServerHijackGetsBackgroundByte above, but sending an
+// immediate 8KB of data to the server to fill up the server's 4KB
+// buffer.
+func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
+ run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
+}
+func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping test; see https://golang.org/issue/18657")
+ }
+ done := make(chan struct{})
+ const size = 8 << 10
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(done)
+
+ conn, buf, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ slurp, err := io.ReadAll(buf.Reader)
+ if err != nil {
+ t.Errorf("Copy: %v", err)
+ }
+ allX := true
+ for _, v := range slurp {
+ if v != 'x' {
+ allX = false
+ }
+ }
+ if len(slurp) != size {
+ t.Errorf("read %d; want %d", len(slurp), size)
+ } else if !allX {
+ t.Errorf("read %q; want %d 'x'", slurp, size)
+ }
+ })).ts
+
+ cn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cn.Close()
+ if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
+ strings.Repeat("x", size)); err != nil {
+ t.Fatal(err)
+ }
+ if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
+ t.Fatal(err)
+ }
+
+ <-done
+}
+
+// Issue 18319: test that the Server validates the request method.
+func TestServerValidatesMethod(t *testing.T) {
+ tests := []struct {
+ method string
+ want int
+ }{
+ {"GET", 200},
+ {"GE(T", 400},
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool, 1)}
+ io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
+
+ ln := &oneConnListener{conn}
+ go Serve(ln, serve(200))
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %s, ReadResponse: %v", tt.method, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
+ }
+ }
+}
+
+// Listener for TestServerListenNotComparableListener.
+type eofListenerNotComparable []int
+
+func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
+func (eofListenerNotComparable) Addr() net.Addr { return nil }
+func (eofListenerNotComparable) Close() error { return nil }
+
+// Issue 24812: don't crash on non-comparable Listener
+func TestServerListenNotComparableListener(t *testing.T) {
+ var s Server
+ s.Serve(make(eofListenerNotComparable, 1)) // used to panic
+}
+
+// countCloseListener is a Listener wrapper that counts the number of Close calls.
+type countCloseListener struct {
+ net.Listener
+ closes int32 // atomic
+}
+
+func (p *countCloseListener) Close() error {
+ var err error
+ if n := atomic.AddInt32(&p.closes, 1); n == 1 && p.Listener != nil {
+ err = p.Listener.Close()
+ }
+ return err
+}
+
+// Issue 24803: don't call Listener.Close on Server.Shutdown.
+func TestServerCloseListenerOnce(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ cl := &countCloseListener{Listener: ln}
+ server := &Server{}
+ sdone := make(chan bool, 1)
+
+ go func() {
+ server.Serve(cl)
+ sdone <- true
+ }()
+ time.Sleep(10 * time.Millisecond)
+ server.Shutdown(context.Background())
+ ln.Close()
+ <-sdone
+
+ nclose := atomic.LoadInt32(&cl.closes)
+ if nclose != 1 {
+ t.Errorf("Close calls = %v; want 1", nclose)
+ }
+}
+
+// Issue 20239: don't block in Serve if Shutdown is called first.
+func TestServerShutdownThenServe(t *testing.T) {
+ var srv Server
+ cl := &countCloseListener{Listener: nil}
+ srv.Shutdown(context.Background())
+ got := srv.Serve(cl)
+ if got != ErrServerClosed {
+ t.Errorf("Serve err = %v; want ErrServerClosed", got)
+ }
+ nclose := atomic.LoadInt32(&cl.closes)
+ if nclose != 1 {
+ t.Errorf("Close calls = %v; want 1", nclose)
+ }
+}
+
+// Issue 23351: document and test behavior of ServeMux with ports
+func TestStripPortFromHost(t *testing.T) {
+ mux := NewServeMux()
+
+ mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "OK")
+ })
+ mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "uh-oh!")
+ })
+
+ req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
+ rw := httptest.NewRecorder()
+
+ mux.ServeHTTP(rw, req)
+
+ response := rw.Body.String()
+ if response != "OK" {
+ t.Errorf("Response gotten was %q", response)
+ }
+}
+
+func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
+func testServerContexts(t *testing.T, mode testMode) {
+ type baseKey struct{}
+ type connKey struct{}
+ ch := make(chan context.Context, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ ch <- r.Context()
+ }), func(ts *httptest.Server) {
+ ts.Config.BaseContext = func(ln net.Listener) context.Context {
+ if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
+ t.Errorf("unexpected onceClose listener type %T", ln)
+ }
+ return context.WithValue(context.Background(), baseKey{}, "base")
+ }
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
+ }
+ }).ts
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ ctx := <-ch
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("base context key = %#v; want %q", got, want)
+ }
+ if got, want := ctx.Value(connKey{}), "conn"; got != want {
+ t.Errorf("conn context key = %#v; want %q", got, want)
+ }
+}
+
+// Issue 35750: check that ConnContext does not modify the context of other connections
+func TestConnContextNotModifyingAllContexts(t *testing.T) {
+ run(t, testConnContextNotModifyingAllContexts)
+}
+func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
+ type connKey struct{}
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ rw.Header().Set("Connection", "close")
+ }), func(ts *httptest.Server) {
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got := ctx.Value(connKey{}); got != nil {
+ t.Errorf("in ConnContext, unexpected context key = %#v", got)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
+ }
+ }).ts
+
+ var res *Response
+ var err error
+
+ res, err = ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+
+ res, err = ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+// Issue 30710: ensure that as per the spec, a server responds
+// with 501 Not Implemented for unsupported transfer-encodings.
+func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
+ run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
+}
+func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello, World!"))
+ })).ts
+
+ serverURL, err := url.Parse(cst.URL)
+ if err != nil {
+ t.Fatalf("Failed to parse server URL: %v", err)
+ }
+
+ unsupportedTEs := []string{
+ "fugazi",
+ "foo-bar",
+ "unknown",
+ `" chunked"`,
+ }
+
+ for _, badTE := range unsupportedTEs {
+ http1ReqBody := fmt.Sprintf(""+
+ "POST / HTTP/1.1\r\nConnection: close\r\n"+
+ "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
+
+ gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
+ if err != nil {
+ t.Errorf("%q. unexpected error: %v", badTE, err)
+ continue
+ }
+
+ wantBody := fmt.Sprintf("" +
+ "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
+ "Connection: close\r\n\r\nUnsupported transfer encoding")
+
+ if string(gotBody) != wantBody {
+ t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
+ }
+ }
+}
+
+// Issue 31753: don't sniff when Content-Encoding is set
+func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
+func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
+ type setting struct {
+ name string
+ body []byte
+
+ // setting contentEncoding as an interface instead of a string
+ // directly, so as to differentiate between 3 states:
+ // unset, empty string "" and set string "foo/bar".
+ contentEncoding any
+ wantContentType string
+ }
+
+ settings := []*setting{
+ {
+ name: "gzip content-encoding, gzipped", // don't sniff.
+ contentEncoding: "application/gzip",
+ wantContentType: "",
+ body: func() []byte {
+ buf := new(bytes.Buffer)
+ gzw := gzip.NewWriter(buf)
+ gzw.Write([]byte("doctype html><p>Hello</p>"))
+ gzw.Close()
+ return buf.Bytes()
+ }(),
+ },
+ {
+ name: "zlib content-encoding, zlibbed", // don't sniff.
+ contentEncoding: "application/zlib",
+ wantContentType: "",
+ body: func() []byte {
+ buf := new(bytes.Buffer)
+ zw := zlib.NewWriter(buf)
+ zw.Write([]byte("doctype html><p>Hello</p>"))
+ zw.Close()
+ return buf.Bytes()
+ }(),
+ },
+ {
+ name: "no content-encoding", // must sniff.
+ wantContentType: "application/x-gzip",
+ body: func() []byte {
+ buf := new(bytes.Buffer)
+ gzw := gzip.NewWriter(buf)
+ gzw.Write([]byte("doctype html><p>Hello</p>"))
+ gzw.Close()
+ return buf.Bytes()
+ }(),
+ },
+ {
+ name: "phony content-encoding", // don't sniff.
+ contentEncoding: "foo/bar",
+ body: []byte("doctype html><p>Hello</p>"),
+ },
+ {
+ name: "empty but set content-encoding",
+ contentEncoding: "",
+ wantContentType: "audio/mpeg",
+ body: []byte("ID3"),
+ },
+ }
+
+ for _, tt := range settings {
+ t.Run(tt.name, func(t *testing.T) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
+ if tt.contentEncoding != nil {
+ rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
+ }
+ rw.Write(tt.body)
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Failed to fetch URL: %v", err)
+ }
+ defer res.Body.Close()
+
+ if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
+ if w != nil { // The case where contentEncoding was set explicitly.
+ t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
+ } else if g != "" { // "" should be the equivalent when the contentEncoding is unset.
+ t.Errorf("Unexpected Content-Encoding %q", g)
+ }
+ }
+
+ if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
+ t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+ })
+ }
+}
+
+// Issue 30803: ensure that TimeoutHandler logs spurious
+// WriteHeader calls, for consistency with other Handlers.
+func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
+ run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
+}
+func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ pc, curFile, _, _ := runtime.Caller(0)
+ curFileBaseName := filepath.Base(curFile)
+ testFuncName := runtime.FuncForPC(pc).Name()
+
+ timeoutMsg := "timed out here!"
+
+ tests := []struct {
+ name string
+ mustTimeout bool
+ wantResp string
+ }{
+ {
+ name: "return before timeout",
+ wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
+ },
+ {
+ name: "return after timeout",
+ mustTimeout: true,
+ wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
+ len(timeoutMsg), timeoutMsg),
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ exitHandler := make(chan bool, 1)
+ defer close(exitHandler)
+ lastLine := make(chan int, 1)
+
+ sh := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ w.WriteHeader(404)
+ _, _, line, _ := runtime.Caller(0)
+ lastLine <- line
+ <-exitHandler
+ })
+
+ if !tt.mustTimeout {
+ exitHandler <- true
+ }
+
+ logBuf := new(strings.Builder)
+ srvLog := log.New(logBuf, "", 0)
+ // When expecting to timeout, we'll keep the duration short.
+ dur := 20 * time.Millisecond
+ if !tt.mustTimeout {
+ // Otherwise, make it arbitrarily long to reduce the risk of flakes.
+ dur = 10 * time.Second
+ }
+ th := TimeoutHandler(sh, dur, timeoutMsg)
+ cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ // Deliberately removing the "Date" header since it is highly ephemeral
+ // and will cause failure if we try to match it exactly.
+ res.Header.Del("Date")
+ res.Header.Del("Content-Type")
+
+ // Match the response.
+ blob, _ := httputil.DumpResponse(res, true)
+ if g, w := string(blob), tt.wantResp; g != w {
+ t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
+ }
+
+ // Given 4 w.WriteHeader calls, only the first one is valid
+ // and the rest should be reported as the 3 spurious logs.
+ logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
+ if g, w := len(logEntries), 3; g != w {
+ blob, _ := json.MarshalIndent(logEntries, "", " ")
+ t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
+ }
+
+ lastSpuriousLine := <-lastLine
+ firstSpuriousLine := lastSpuriousLine - 3
+ // Now ensure that the regexes match exactly.
+ // "http: superfluous response.WriteHeader call from <fn>.func\d.\d (<curFile>:lastSpuriousLine-[1, 3]"
+ for i, logEntry := range logEntries {
+ wantLine := firstSpuriousLine + i
+ pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
+ testFuncName, curFileBaseName, wantLine)
+ re := regexp.MustCompile(pat)
+ if !re.MatchString(logEntry) {
+ t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
+ }
+ }
+ })
+ }
+}
+
+// fetchWireResponse is a helper that dials host, sends http1ReqBody
+// as the payload, and returns the response as it was sent on the wire.
+func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
+ conn, err := net.Dial("tcp", host)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ if _, err := conn.Write(http1ReqBody); err != nil {
+ return nil, err
+ }
+ return io.ReadAll(conn)
+}
+
+func BenchmarkResponseStatusLine(b *testing.B) {
+ b.ReportAllocs()
+ b.RunParallel(func(pb *testing.PB) {
+ bw := bufio.NewWriter(io.Discard)
+ var buf3 [3]byte
+ for pb.Next() {
+ Export_writeStatusLine(bw, true, 200, buf3[:])
+ }
+ })
+}
+
+func TestDisableKeepAliveUpgrade(t *testing.T) {
+ run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
+}
+func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "Upgrade")
+ w.Header().Set("Upgrade", "someProto")
+ w.WriteHeader(StatusSwitchingProtocols)
+ c, buf, err := w.(Hijacker).Hijack()
+ if err != nil {
+ return
+ }
+ defer c.Close()
+
+ // Copy from the *bufio.ReadWriter, which may contain buffered data.
+ // Copy to the net.Conn, to avoid buffering the output.
+ io.Copy(c, buf)
+ }), func(ts *httptest.Server) {
+ ts.Config.SetKeepAlivesEnabled(false)
+ }).ts
+
+ cl := s.Client()
+ cl.Transport.(*Transport).DisableKeepAlives = true
+
+ resp, err := cl.Get(s.URL)
+ if err != nil {
+ t.Fatalf("failed to perform request: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != StatusSwitchingProtocols {
+ t.Fatalf("unexpected status code: %v", resp.StatusCode)
+ }
+
+ rwc, ok := resp.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
+ }
+
+ _, err = rwc.Write([]byte("hello"))
+ if err != nil {
+ t.Fatalf("failed to write to body: %v", err)
+ }
+
+ b := make([]byte, 5)
+ _, err = io.ReadFull(rwc, b)
+ if err != nil {
+ t.Fatalf("failed to read from body: %v", err)
+ }
+
+ if string(b) != "hello" {
+ t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
+ }
+}
+
+type tlogWriter struct{ t *testing.T }
+
+func (w tlogWriter) Write(p []byte) (int, error) {
+ w.t.Log(string(p))
+ return len(p), nil
+}
+
+func TestWriteHeaderSwitchingProtocols(t *testing.T) {
+ run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
+}
+func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
+ const wantBody = "want"
+ const wantUpgrade = "someProto"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "Upgrade")
+ w.Header().Set("Upgrade", wantUpgrade)
+ w.WriteHeader(StatusSwitchingProtocols)
+ NewResponseController(w).Flush()
+
+ // Writing headers or the body after sending a 101 header should fail.
+ w.WriteHeader(200)
+ if _, err := w.Write([]byte("x")); err == nil {
+ t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
+ }
+
+ c, _, err := NewResponseController(w).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ defer c.Close()
+ if _, err := c.Write([]byte(wantBody)); err != nil {
+ t.Errorf("Write to hijacked body: %v", err)
+ }
+ }), func(ts *httptest.Server) {
+ // Don't spam log with warning about superfluous WriteHeader call.
+ ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
+ }).ts
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("net.Dial: %v", err)
+ }
+ _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ if err != nil {
+ t.Fatalf("conn.Write: %v", err)
+ }
+ defer conn.Close()
+
+ r := bufio.NewReader(conn)
+ res, err := ReadResponse(r, &Request{Method: "GET"})
+ if err != nil {
+ t.Fatal("ReadResponse error:", err)
+ }
+ if res.StatusCode != StatusSwitchingProtocols {
+ t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
+ }
+ if got := res.Header.Get("Upgrade"); got != wantUpgrade {
+ t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
+ }
+ body, err := io.ReadAll(r)
+ if err != nil {
+ t.Error(err)
+ }
+ if string(body) != wantBody {
+ t.Errorf("Response body = %q, want %q", string(body), wantBody)
+ }
+}
+
+func TestMuxRedirectRelative(t *testing.T) {
+ setParallel(t)
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
+ if err != nil {
+ t.Errorf("%s", err)
+ }
+ mux := NewServeMux()
+ resp := httptest.NewRecorder()
+ mux.ServeHTTP(resp, req)
+ if got, want := resp.Header().Get("Location"), "/"; got != want {
+ t.Errorf("Location header expected %q; got %q", want, got)
+ }
+ if got, want := resp.Code, StatusMovedPermanently; got != want {
+ t.Errorf("Expected response code %d; got %d", want, got)
+ }
+}
+
+// TestQuerySemicolon tests the behavior of semicolons in queries. See Issue 25192.
+func TestQuerySemicolon(t *testing.T) {
+ t.Cleanup(func() { afterTest(t) })
+
+ tests := []struct {
+ query string
+ xNoSemicolons string
+ xWithSemicolons string
+ expectParseFormErr bool
+ }{
+ {"?a=1;x=bad&x=good", "good", "bad", true},
+ {"?a=1;b=bad&x=good", "good", "good", true},
+ {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
+ {"?a=1;x=good;x=bad", "", "good", true},
+ }
+
+ run(t, func(t *testing.T, mode testMode) {
+ for _, tt := range tests {
+ t.Run(tt.query+"/allow=false", func(t *testing.T) {
+ allowSemicolons := false
+ testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
+ })
+ t.Run(tt.query+"/allow=true", func(t *testing.T) {
+ allowSemicolons, expectParseFormErr := true, false
+ testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
+ })
+ }
+ })
+}
+
+func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
+ writeBackX := func(w ResponseWriter, r *Request) {
+ x := r.URL.Query().Get("x")
+ if expectParseFormErr {
+ if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
+ t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
+ }
+ } else {
+ if err := r.ParseForm(); err != nil {
+ t.Errorf("expected no error from ParseForm, got %v", err)
+ }
+ }
+ if got := r.FormValue("x"); x != got {
+ t.Errorf("got %q from FormValue, want %q", got, x)
+ }
+ fmt.Fprintf(w, "%s", x)
+ }
+
+ h := Handler(HandlerFunc(writeBackX))
+ if allowSemicolons {
+ h = AllowQuerySemicolons(h)
+ }
+
+ logBuf := &strings.Builder{}
+ ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(logBuf, "", 0)
+ }).ts
+
+ req, _ := NewRequest("GET", ts.URL+query, nil)
+ res, err := ts.Client().Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, _ := io.ReadAll(res.Body)
+ res.Body.Close()
+ if got, want := res.StatusCode, 200; got != want {
+ t.Errorf("Status = %d; want = %d", got, want)
+ }
+ if got, want := string(slurp), wantX; got != want {
+ t.Errorf("Body = %q; want = %q", got, want)
+ }
+}
+
+func TestMaxBytesHandler(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ for _, maxSize := range []int64{100, 1_000, 1_000_000} {
+ for _, requestSize := range []int64{100, 1_000, 1_000_000} {
+ t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
+ func(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testMaxBytesHandler(t, mode, maxSize, requestSize)
+ })
+ })
+ }
+ }
+}
+
+func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
+ var (
+ handlerN int64
+ handlerErr error
+ )
+ echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+ var buf bytes.Buffer
+ handlerN, handlerErr = io.Copy(&buf, r.Body)
+ io.Copy(w, &buf)
+ })
+
+ ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts
+ defer ts.Close()
+
+ c := ts.Client()
+
+ body := strings.Repeat("a", int(requestSize))
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ getBody := func() (io.ReadCloser, error) {
+ wg.Add(1)
+ body := &wgReadCloser{
+ Reader: strings.NewReader(body),
+ wg: &wg,
+ }
+ return body, nil
+ }
+ reqBody, _ := getBody()
+ req, err := NewRequest("POST", ts.URL, reqBody)
+ if err != nil {
+ reqBody.Close()
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body))
+ req.GetBody = getBody
+ req.Header.Set("Content-Type", "text/plain")
+
+ var buf strings.Builder
+ res, err := c.Do(req)
+ if err != nil {
+ t.Errorf("unexpected connection error: %v", err)
+ } else {
+ _, err = io.Copy(&buf, res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("unexpected read error: %v", err)
+ }
+ }
+ if handlerN > maxSize {
+ t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
+ }
+ if requestSize > maxSize && handlerErr == nil {
+ t.Error("expected error on handler side; got nil")
+ }
+ if requestSize <= maxSize {
+ if handlerErr != nil {
+ t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+ }
+ if handlerN != requestSize {
+ t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+ }
+ }
+ if buf.Len() != int(handlerN) {
+ t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+ }
+}
+
+func TestEarlyHints(t *testing.T) {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ h := w.Header()
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ w.Write([]byte("stuff"))
+ }))
+
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
+ expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: " // dynamic content expected
+ if !strings.Contains(got, expected) {
+ t.Errorf("unexpected response; got %q; should start by %q", got, expected)
+ }
+}
+func TestProcessing(t *testing.T) {
+ ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusProcessing)
+ w.Write([]byte("stuff"))
+ }))
+
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
+ expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: " // dynamic content expected
+ if !strings.Contains(got, expected) {
+ t.Errorf("unexpected response; got %q; should start by %q", got, expected)
+ }
+}
+
+func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
+func testParseFormCleanup(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/20253")
+ }
+
+ const maxMemory = 1024
+ const key = "file"
+
+ if runtime.GOOS == "windows" {
+ // Windows sometimes refuses to remove a file that was just closed.
+ t.Skip("https://go.dev/issue/25965")
+ }
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.ParseMultipartForm(maxMemory)
+ f, _, err := r.FormFile(key)
+ if err != nil {
+ t.Errorf("r.FormFile(%q) = %v", key, err)
+ return
+ }
+ of, ok := f.(*os.File)
+ if !ok {
+ t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
+ return
+ }
+ w.Write([]byte(of.Name()))
+ }))
+
+ fBuf := new(bytes.Buffer)
+ mw := multipart.NewWriter(fBuf)
+ mf, err := mw.CreateFormFile(key, "myfile.txt")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
+ t.Fatal(err)
+ }
+ if err := mw.Close(); err != nil {
+ t.Fatal(err)
+ }
+ req, err := NewRequest("POST", cst.ts.URL, fBuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ fname, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cst.close()
+ if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
+ t.Errorf("file %q exists after HTTP handler returned", string(fname))
+ }
+}
+
+func TestHeadBody(t *testing.T) {
+ const identityMode = false
+ const chunkedMode = true
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
+ t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
+ })
+}
+
+func TestGetBody(t *testing.T) {
+ const identityMode = false
+ const chunkedMode = true
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
+ t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
+ })
+}
+
+func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ b, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("server reading body: %v", err)
+ return
+ }
+ w.Header().Set("X-Request-Body", string(b))
+ w.Header().Set("Content-Length", "0")
+ }))
+ defer cst.close()
+ for _, reqBody := range []string{
+ "",
+ "",
+ "request_body",
+ "",
+ } {
+ var bodyReader io.Reader
+ if reqBody != "" {
+ bodyReader = strings.NewReader(reqBody)
+ if chunked {
+ bodyReader = bufio.NewReader(bodyReader)
+ }
+ }
+ req, err := NewRequest(method, cst.ts.URL, bodyReader)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if got, want := res.StatusCode, 200; got != want {
+ t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
+ }
+ if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
+ t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
+ }
+ }
+}
+
+// TestDisableContentLength verifies that the Content-Length is set by default
+// or disabled when the header is set to nil.
+func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
+func testDisableContentLength(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("skipping until h2_bundle.go is updated; see https://go-review.googlesource.com/c/net/+/471535")
+ }
+
+ noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Length"] = nil // disable the default Content-Length response
+ fmt.Fprintf(w, "OK")
+ }))
+
+ res, err := noCL.c.Get(noCL.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length: %q", got)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "OK")
+ }))
+
+ res, err = withCL.c.Get(withCL.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got := res.Header.Get("Content-Length"); got != "2" {
+ t.Errorf("Content-Length: %q; want 2", got)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/src/net/http/server.go b/src/net/http/server.go
new file mode 100644
index 0000000..8f63a90
--- /dev/null
+++ b/src/net/http/server.go
@@ -0,0 +1,3645 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP server. See RFC 7230 through 7235.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "internal/godebug"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "net/textproto"
+ "net/url"
+ urlpkg "net/url"
+ "path"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// Errors used by the HTTP server.
+var (
+ // ErrBodyNotAllowed is returned by ResponseWriter.Write calls
+ // when the HTTP method or response code does not permit a
+ // body.
+ ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body")
+
+ // ErrHijacked is returned by ResponseWriter.Write calls when
+ // the underlying connection has been hijacked using the
+ // Hijacker interface. A zero-byte write on a hijacked
+ // connection will return ErrHijacked without any other side
+ // effects.
+ ErrHijacked = errors.New("http: connection has been hijacked")
+
+ // ErrContentLength is returned by ResponseWriter.Write calls
+ // when a Handler set a Content-Length response header with a
+ // declared size and then attempted to write more bytes than
+ // declared.
+ ErrContentLength = errors.New("http: wrote more than the declared Content-Length")
+
+ // Deprecated: ErrWriteAfterFlush is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
+ ErrWriteAfterFlush = errors.New("unused")
+)
+
+// A Handler responds to an HTTP request.
+//
+// ServeHTTP should write reply headers and data to the ResponseWriter
+// and then return. Returning signals that the request is finished; it
+// is not valid to use the ResponseWriter or read from the
+// Request.Body after or concurrently with the completion of the
+// ServeHTTP call.
+//
+// Depending on the HTTP client software, HTTP protocol version, and
+// any intermediaries between the client and the Go server, it may not
+// be possible to read from the Request.Body after writing to the
+// ResponseWriter. Cautious handlers should read the Request.Body
+// first, and then reply.
+//
+// Except for reading the body, handlers should not modify the
+// provided Request.
+//
+// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
+// that the effect of the panic was isolated to the active request.
+// It recovers the panic, logs a stack trace to the server error log,
+// and either closes the network connection or sends an HTTP/2
+// RST_STREAM, depending on the HTTP protocol. To abort a handler so
+// the client sees an interrupted response but the server doesn't log
+// an error, panic with the value ErrAbortHandler.
+type Handler interface {
+ ServeHTTP(ResponseWriter, *Request)
+}
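+
+// As a rough illustration of the contract above, a minimal Handler might
+// look like the sketch below (the hello type is hypothetical); in practice
+// the HandlerFunc adapter is often used instead of declaring a new type:
+//
+//	type hello struct{}
+//
+//	func (hello) ServeHTTP(w ResponseWriter, r *Request) {
+//		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+//		io.WriteString(w, "hello, world\n")
+//	}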
+
+// A ResponseWriter interface is used by an HTTP handler to
+// construct an HTTP response.
+//
+// A ResponseWriter may not be used after the Handler.ServeHTTP method
+// has returned.
+type ResponseWriter interface {
+ // Header returns the header map that will be sent by
+ // WriteHeader. The Header map also is the mechanism with which
+ // Handlers can set HTTP trailers.
+ //
+ // Changing the header map after a call to WriteHeader (or
+ // Write) has no effect unless the HTTP status code was of the
+ // 1xx class or the modified headers are trailers.
+ //
+ // There are two ways to set Trailers. The preferred way is to
+ // predeclare in the headers which trailers you will later
+ // send by setting the "Trailer" header to the names of the
+ // trailer keys which will come later. In this case, those
+ // keys of the Header map are treated as if they were
+ // trailers. See the example. The second way, for trailer
+ // keys not known to the Handler until after the first Write,
+ // is to prefix the Header map keys with the TrailerPrefix
+ // constant value. See TrailerPrefix.
+ //
+ // To suppress automatic response headers (such as "Date"), set
+ // their value to nil.
+ Header() Header
+
+ // Write writes the data to the connection as part of an HTTP reply.
+ //
+ // If WriteHeader has not yet been called, Write calls
+ // WriteHeader(http.StatusOK) before writing the data. If the Header
+ // does not contain a Content-Type line, Write adds a Content-Type set
+ // to the result of passing the initial 512 bytes of written data to
+ // DetectContentType. Additionally, if the total size of all written
+ // data is under a few KB and there are no Flush calls, the
+ // Content-Length header is added automatically.
+ //
+ // Depending on the HTTP protocol version and the client, calling
+ // Write or WriteHeader may prevent future reads on the
+ // Request.Body. For HTTP/1.x requests, handlers should read any
+ // needed request body data before writing the response. Once the
+ // headers have been flushed (due to either an explicit Flusher.Flush
+ // call or writing enough data to trigger a flush), the request body
+ // may be unavailable. For HTTP/2 requests, the Go HTTP server permits
+ // handlers to continue to read the request body while concurrently
+ // writing the response. However, such behavior may not be supported
+ // by all HTTP/2 clients. Handlers should read before writing if
+ // possible to maximize compatibility.
+ Write([]byte) (int, error)
+
+ // WriteHeader sends an HTTP response header with the provided
+ // status code.
+ //
+ // If WriteHeader is not called explicitly, the first call to Write
+ // will trigger an implicit WriteHeader(http.StatusOK).
+ // Thus explicit calls to WriteHeader are mainly used to
+ // send error codes or 1xx informational responses.
+ //
+ // The provided code must be a valid HTTP 1xx-5xx status code.
+ // Any number of 1xx headers may be written, followed by at most
+ // one 2xx-5xx header. 1xx headers are sent immediately, but 2xx-5xx
+ // headers may be buffered. Use the Flusher interface to send
+ // buffered data. The header map is cleared when 2xx-5xx headers are
+ // sent, but not with 1xx headers.
+ //
+ // The server will automatically send a 100 (Continue) header
+ // on the first read from the request body if the request has
+ // an "Expect: 100-continue" header.
+ WriteHeader(statusCode int)
+}
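+
+// As a sketch of the predeclared-trailer mechanism described in Header
+// above, a handler might do the following (the "X-Checksum" name and its
+// value are placeholders); on HTTP/1.1 the trailer can only be delivered
+// when the response ends up using chunked encoding:
+//
+//	func handler(w ResponseWriter, r *Request) {
+//		w.Header().Set("Trailer", "X-Checksum")
+//		w.WriteHeader(StatusOK)
+//		w.Write([]byte("body bytes"))
+//		w.Header().Set("X-Checksum", "abc123")
+//	}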
+
+// The Flusher interface is implemented by ResponseWriters that allow
+// an HTTP handler to flush buffered data to the client.
+//
+// The default HTTP/1.x and HTTP/2 ResponseWriter implementations
+// support Flusher, but ResponseWriter wrappers may not. Handlers
+// should always test for this ability at runtime.
+//
+// Note that even for ResponseWriters that support Flush,
+// if the client is connected through an HTTP proxy,
+// the buffered data may not reach the client until the response
+// completes.
+type Flusher interface {
+ // Flush sends any buffered data to the client.
+ Flush()
+}
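+
+// A sketch of the runtime test mentioned above; the handler body is only
+// illustrative:
+//
+//	func handler(w ResponseWriter, r *Request) {
+//		io.WriteString(w, "working...\n")
+//		if f, ok := w.(Flusher); ok {
+//			f.Flush() // send what has been written so far to the client
+//		}
+//	}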
+
+// The Hijacker interface is implemented by ResponseWriters that allow
+// an HTTP handler to take over the connection.
+//
+// The default ResponseWriter for HTTP/1.x connections supports
+// Hijacker, but HTTP/2 connections intentionally do not.
+// ResponseWriter wrappers may also not support Hijacker. Handlers
+// should always test for this ability at runtime.
+type Hijacker interface {
+ // Hijack lets the caller take over the connection.
+ // After a call to Hijack the HTTP server library
+ // will not do anything else with the connection.
+ //
+ // It becomes the caller's responsibility to manage
+ // and close the connection.
+ //
+ // The returned net.Conn may have read or write deadlines
+ // already set, depending on the configuration of the
+ // Server. It is the caller's responsibility to set
+ // or clear those deadlines as needed.
+ //
+ // The returned bufio.Reader may contain unprocessed buffered
+ // data from the client.
+ //
+ // After a call to Hijack, the original Request.Body must not
+ // be used. The original Request's Context remains valid and
+ // is not canceled until the Handler's ServeHTTP method
+ // returns.
+ Hijack() (net.Conn, *bufio.ReadWriter, error)
+}
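+
+// A sketch of typical Hijacker use; error handling is abbreviated, and
+// whatever protocol is spoken on conn after hijacking is up to the caller:
+//
+//	func handler(w ResponseWriter, r *Request) {
+//		hj, ok := w.(Hijacker)
+//		if !ok {
+//			Error(w, "hijacking not supported", StatusInternalServerError)
+//			return
+//		}
+//		conn, bufrw, err := hj.Hijack()
+//		if err != nil {
+//			Error(w, err.Error(), StatusInternalServerError)
+//			return
+//		}
+//		defer conn.Close()
+//		bufrw.WriteString("no longer HTTP from here on\n")
+//		bufrw.Flush()
+//	}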
+
+// The CloseNotifier interface is implemented by ResponseWriters which
+// allow detecting when the underlying connection has gone away.
+//
+// This mechanism can be used to cancel long operations on the server
+// if the client has disconnected before the response is ready.
+//
+// Deprecated: the CloseNotifier interface predates Go's context package.
+// New code should use Request.Context instead.
+type CloseNotifier interface {
+ // CloseNotify returns a channel that receives at most a
+ // single value (true) when the client connection has gone
+ // away.
+ //
+ // CloseNotify may wait to notify until Request.Body has been
+ // fully read.
+ //
+ // After the Handler has returned, there is no guarantee
+ // that the channel receives a value.
+ //
+ // If the protocol is HTTP/1.1 and CloseNotify is called while
+ // processing an idempotent request (such as a GET) while
+ // HTTP/1.1 pipelining is in use, the arrival of a subsequent
+ // pipelined request may cause a value to be sent on the
+ // returned channel. In practice HTTP/1.1 pipelining is not
+ // enabled in browsers and not seen often in the wild. If this
+ // is a problem, use HTTP/2 or only use CloseNotify on methods
+ // such as POST.
+ CloseNotify() <-chan bool
+}
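+
+// Illustrative sketch (editorial example, not part of this change): per
+// the deprecation note, new code should watch the request context rather
+// than CloseNotify. Assumes the usual "net/http", "io" and "time" imports.
+//
+//	func slow(w http.ResponseWriter, r *http.Request) {
+//		select {
+//		case <-time.After(5 * time.Second):
+//			io.WriteString(w, "done\n")
+//		case <-r.Context().Done():
+//			return // the client went away or the request was canceled
+//		}
+//	}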
+
+var (
+ // ServerContextKey is a context key. It can be used in HTTP
+ // handlers with Context.Value to access the server that
+ // started the handler. The associated value will be of
+ // type *Server.
+ ServerContextKey = &contextKey{"http-server"}
+
+ // LocalAddrContextKey is a context key. It can be used in
+ // HTTP handlers with Context.Value to access the local
+ // address the connection arrived on.
+ // The associated value will be of type net.Addr.
+ LocalAddrContextKey = &contextKey{"local-addr"}
+)
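+
+// Illustrative sketch (editorial example, not part of this change):
+// reading both context keys from inside a handler served by an
+// http.Server; the values are a *Server and a net.Addr, as documented
+// above. Assumes the usual "fmt", "net" and "net/http" imports.
+//
+//	func whoami(w http.ResponseWriter, r *http.Request) {
+//		srv := r.Context().Value(http.ServerContextKey).(*http.Server)
+//		addr := r.Context().Value(http.LocalAddrContextKey).(net.Addr)
+//		fmt.Fprintf(w, "served by %q on %v\n", srv.Addr, addr)
+//	}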
+
+// A conn represents the server side of an HTTP connection.
+type conn struct {
+ // server is the server on which the connection arrived.
+ // Immutable; never nil.
+ server *Server
+
+ // cancelCtx cancels the connection-level context.
+ cancelCtx context.CancelFunc
+
+ // rwc is the underlying network connection.
+ // This is never wrapped by other types and is the value given out
+ // to CloseNotifier callers. It is usually of type *net.TCPConn or
+ // *tls.Conn.
+ rwc net.Conn
+
+ // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously
+ // inside the Listener's Accept goroutine, as some implementations block.
+ // It is populated immediately inside the (*conn).serve goroutine.
+ // This is the value of a Handler's (*Request).RemoteAddr.
+ remoteAddr string
+
+ // tlsState is the TLS connection state when using TLS.
+ // nil means not TLS.
+ tlsState *tls.ConnectionState
+
+ // werr is set to the first write error to rwc.
+ // It is set via checkConnErrorWriter{c}, where bufw writes.
+ werr error
+
+ // r is bufr's read source. It's a wrapper around rwc that provides
+ // io.LimitedReader-style limiting (while reading request headers)
+ // and functionality to support CloseNotifier. See *connReader docs.
+ r *connReader
+
+ // bufr reads from r.
+ bufr *bufio.Reader
+
+ // bufw writes to checkConnErrorWriter{c}, which populates werr on error.
+ bufw *bufio.Writer
+
+ // lastMethod is the method of the most recent request
+ // on this connection, if any.
+ lastMethod string
+
+ curReq atomic.Pointer[response] // (which has a Request in it)
+
+ curState atomic.Uint64 // packed (unixtime<<8|uint8(ConnState))
+
+ // mu guards hijackedv
+ mu sync.Mutex
+
+ // hijackedv is whether this connection has been hijacked
+ // by a Handler with the Hijacker interface.
+ // It is guarded by mu.
+ hijackedv bool
+}
+
+func (c *conn) hijacked() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.hijackedv
+}
+
+// c.mu must be held.
+func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ if c.hijackedv {
+ return nil, nil, ErrHijacked
+ }
+ c.r.abortPendingRead()
+
+ c.hijackedv = true
+ rwc = c.rwc
+ rwc.SetDeadline(time.Time{})
+
+ buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
+ if c.r.hasByte {
+ if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
+ return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
+ }
+ }
+ c.setState(rwc, StateHijacked, runHooks)
+ return
+}
+
+// This should be >= 512 bytes for DetectContentType,
+// but otherwise it's somewhat arbitrary.
+const bufferBeforeChunkingSize = 2048
+
+// chunkWriter writes to a response's conn buffer, and is the writer
+// wrapped by the response.w buffered writer.
+//
+// chunkWriter also is responsible for finalizing the Header, including
+// conditionally setting the Content-Type and setting a Content-Length
+// in cases where the handler's final output is smaller than the buffer
+// size. It also conditionally adds chunk headers, when in chunking mode.
+//
+// See the comment above (*response).Write for the entire write flow.
+type chunkWriter struct {
+ res *response
+
+ // header is either nil or a deep clone of res.handlerHeader
+ // at the time of res.writeHeader, if res.writeHeader is
+ // called and extra buffering is being done to calculate
+ // Content-Type and/or Content-Length.
+ header Header
+
+ // wroteHeader tells whether the header's been written to "the
+ // wire" (or rather: w.conn.buf). this is unlike
+ // (*response).wroteHeader, which tells only whether it was
+ // logically written.
+ wroteHeader bool
+
+ // set by the writeHeader method:
+ chunking bool // using chunked transfer encoding for reply body
+}
+
+var (
+ crlf = []byte("\r\n")
+ colonSpace = []byte(": ")
+)
+
+func (cw *chunkWriter) Write(p []byte) (n int, err error) {
+ if !cw.wroteHeader {
+ cw.writeHeader(p)
+ }
+ if cw.res.req.Method == "HEAD" {
+ // Eat writes.
+ return len(p), nil
+ }
+ if cw.chunking {
+ _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p))
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ return
+ }
+ }
+ n, err = cw.res.conn.bufw.Write(p)
+ if cw.chunking && err == nil {
+ _, err = cw.res.conn.bufw.Write(crlf)
+ }
+ if err != nil {
+ cw.res.conn.rwc.Close()
+ }
+ return
+}
+
+func (cw *chunkWriter) flush() error {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ return cw.res.conn.bufw.Flush()
+}
+
+func (cw *chunkWriter) close() {
+ if !cw.wroteHeader {
+ cw.writeHeader(nil)
+ }
+ if cw.chunking {
+ bw := cw.res.conn.bufw // conn's bufio writer
+ // zero chunk to mark EOF
+ bw.WriteString("0\r\n")
+ if trailers := cw.res.finalTrailers(); trailers != nil {
+ trailers.Write(bw) // the writer handles noting errors
+ }
+ // final blank line after the trailers (whether
+ // present or not)
+ bw.WriteString("\r\n")
+ }
+}
+
+// A response represents the server side of an HTTP response.
+type response struct {
+ conn *conn
+ req *Request // request for this response
+ reqBody io.ReadCloser
+ cancelCtx context.CancelFunc // when ServeHTTP exits
+ wroteHeader bool // a non-1xx header has been (logically) written
+ wroteContinue bool // 100 Continue response was written
+ wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive"
+ wantsClose bool // HTTP request has Connection "close"
+
+ // canWriteContinue is an atomic boolean that says whether or
+ // not a 100 Continue header can be written to the
+ // connection.
+ // writeContinueMu must be held while writing the header.
+ // These two fields together synchronize the body reader (the
+ // expectContinueReader, which wants to write 100 Continue)
+ // against the main writer.
+ canWriteContinue atomic.Bool
+ writeContinueMu sync.Mutex
+
+ w *bufio.Writer // buffers output in chunks to chunkWriter
+ cw chunkWriter
+
+ // handlerHeader is the Header that Handlers get access to,
+ // which may be retained and mutated even after WriteHeader.
+ // handlerHeader is copied into cw.header at WriteHeader
+ // time, and privately mutated thereafter.
+ handlerHeader Header
+ calledHeader bool // handler accessed handlerHeader via Header
+
+ written int64 // number of bytes written in body
+ contentLength int64 // explicitly-declared Content-Length; or -1
+ status int // status code passed to WriteHeader
+
+ // close connection after this reply. set on request and
+ // updated after response from handler if there's a
+ // "Connection: keep-alive" response header and a
+ // Content-Length.
+ closeAfterReply bool
+
+ // When fullDuplex is false (the default), we consume any remaining
+ // request body before starting to write a response.
+ fullDuplex bool
+
+ // requestBodyLimitHit is set by requestTooLarge when
+ // maxBytesReader hits its max size. It is checked in
+ // WriteHeader, to make sure we don't consume the
+ // remaining request body to try to advance to the next HTTP
+ // request. Instead, when this is set, we stop reading
+ // subsequent requests on this connection and stop reading
+ // input from it.
+ requestBodyLimitHit bool
+
+ // trailers are the headers to be sent after the handler
+ // finishes writing the body. This field is initialized from
+ // the Trailer response header when the response header is
+ // written.
+ trailers []string
+
+ handlerDone atomic.Bool // set true when the handler exits
+
+ // Buffers for Date, Content-Length, and status code
+ dateBuf [len(TimeFormat)]byte
+ clenBuf [10]byte
+ statusBuf [3]byte
+
+ // closeNotifyCh is the channel returned by CloseNotify.
+ // TODO(bradfitz): this is currently (for Go 1.8) always
+ // non-nil. Make this lazily-created again as it used to be?
+ closeNotifyCh chan bool
+ didCloseNotify atomic.Bool // atomic (only false->true winner should send)
+}
+
+func (c *response) SetReadDeadline(deadline time.Time) error {
+ return c.conn.rwc.SetReadDeadline(deadline)
+}
+
+func (c *response) SetWriteDeadline(deadline time.Time) error {
+ return c.conn.rwc.SetWriteDeadline(deadline)
+}
+
+func (c *response) EnableFullDuplex() error {
+ c.fullDuplex = true
+ return nil
+}
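+
+// Illustrative sketch (editorial example, not part of this change): the
+// three methods above are the per-request hooks that a ResponseController
+// discovers via type assertions. Assuming a Go release that provides
+// http.NewResponseController and its EnableFullDuplex method, a handler
+// can opt in to concurrent reads and writes like this:
+//
+//	func duplex(w http.ResponseWriter, r *http.Request) {
+//		rc := http.NewResponseController(w)
+//		if err := rc.EnableFullDuplex(); err != nil {
+//			http.Error(w, "full duplex unsupported", http.StatusInternalServerError)
+//			return
+//		}
+//		io.Copy(w, r.Body) // echo while the client is still sending
+//	}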
+
+// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
+// that, if present, signals that the map entry is actually for
+// the response trailers, and not the response headers. The prefix
+// is stripped after the ServeHTTP call finishes and the values are
+// sent in the trailers.
+//
+// This mechanism is intended only for trailers that are not known
+// prior to the headers being written. If the set of trailers is fixed
+// or known before the header is written, the normal Go trailers mechanism
+// is preferred:
+//
+// https://pkg.go.dev/net/http#ResponseWriter
+// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers
+const TrailerPrefix = "Trailer:"
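+
+// Illustrative sketch (editorial example, not part of this change):
+// emitting a trailer whose value is only known after the body has been
+// written, using the TrailerPrefix escape hatch. The trailer name
+// "X-Body-Sha256" is made up for illustration; the Flush forces chunked
+// encoding so the trailer can actually be sent. Assumes the usual
+// "crypto/sha256", "encoding/hex", "io" and "net/http" imports.
+//
+//	func withChecksum(w http.ResponseWriter, r *http.Request) {
+//		h := sha256.New()
+//		io.WriteString(io.MultiWriter(w, h), "hello\n")
+//		w.(http.Flusher).Flush()
+//		w.Header().Set(http.TrailerPrefix+"X-Body-Sha256",
+//			hex.EncodeToString(h.Sum(nil)))
+//	}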
+
+// finalTrailers is called after the Handler exits and returns a non-nil
+// value if the Handler set any trailers.
+func (w *response) finalTrailers() Header {
+ var t Header
+ for k, vv := range w.handlerHeader {
+ if kk, found := strings.CutPrefix(k, TrailerPrefix); found {
+ if t == nil {
+ t = make(Header)
+ }
+ t[kk] = vv
+ }
+ }
+ for _, k := range w.trailers {
+ if t == nil {
+ t = make(Header)
+ }
+ for _, v := range w.handlerHeader[k] {
+ t.Add(k, v)
+ }
+ }
+ return t
+}
+
+// declareTrailer is called for each Trailer header when the
+// response header is written. It notes that a header will need to be
+// written in the trailers at the end of the response.
+func (w *response) declareTrailer(k string) {
+ k = CanonicalHeaderKey(k)
+ if !httpguts.ValidTrailerHeader(k) {
+ // Forbidden by RFC 7230, section 4.1.2
+ return
+ }
+ w.trailers = append(w.trailers, k)
+}
+
+// requestTooLarge is called by maxBytesReader when too much input has
+// been read from the client.
+func (w *response) requestTooLarge() {
+ w.closeAfterReply = true
+ w.requestBodyLimitHit = true
+ if !w.wroteHeader {
+ w.Header().Set("Connection", "close")
+ }
+}
+
+// writerOnly hides an io.Writer value's optional ReadFrom method
+// from io.Copy.
+type writerOnly struct {
+ io.Writer
+}
+
+// ReadFrom is here to optimize copying from an *os.File regular file
+// to a *net.TCPConn with sendfile, or from a supported src type such
+// as a *net.TCPConn on Linux with splice.
+func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
+ bufp := copyBufPool.Get().(*[]byte)
+ buf := *bufp
+ defer copyBufPool.Put(bufp)
+
+ // Our underlying w.conn.rwc is usually a *TCPConn (with its
+ // own ReadFrom method). If not, just fall back to the normal
+ // copy method.
+ rf, ok := w.conn.rwc.(io.ReaderFrom)
+ if !ok {
+ return io.CopyBuffer(writerOnly{w}, src, buf)
+ }
+
+ // Copy the first sniffLen bytes before switching to ReadFrom.
+ // This ensures we don't start writing the response before the
+ // source is available (see golang.org/issue/5660) and provides
+ // enough bytes to perform Content-Type sniffing when required.
+ if !w.cw.wroteHeader {
+ n0, err := io.CopyBuffer(writerOnly{w}, io.LimitReader(src, sniffLen), buf)
+ n += n0
+ if err != nil || n0 < sniffLen {
+ return n, err
+ }
+ }
+
+ w.w.Flush() // get rid of any previous writes
+ w.cw.flush() // make sure Header is written; flush data to rwc
+
+ // Now that cw has been flushed, its chunking field is guaranteed initialized.
+ if !w.cw.chunking && w.bodyAllowed() {
+ n0, err := rf.ReadFrom(src)
+ n += n0
+ w.written += n0
+ return n, err
+ }
+
+ n0, err := io.CopyBuffer(writerOnly{w}, src, buf)
+ n += n0
+ return n, err
+}
+
+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
+// Create new connection from rwc.
+func (srv *Server) newConn(rwc net.Conn) *conn {
+ c := &conn{
+ server: srv,
+ rwc: rwc,
+ }
+ if debugServerConnections {
+ c.rwc = newLoggingConn("server", c.rwc)
+ }
+ return c
+}
+
+type readResult struct {
+ _ incomparable
+ n int
+ err error
+ b byte // byte read, if n == 1
+}
+
+// connReader is the io.Reader wrapper used by *conn. It combines a
+// selectively-activated io.LimitedReader (to bound request header
+// read sizes) with support for selectively keeping an io.Reader.Read
+// call blocked in a background goroutine to wait for activity and
+// trigger a CloseNotifier channel.
+type connReader struct {
+ conn *conn
+
+ mu sync.Mutex // guards following
+ hasByte bool
+ byteBuf [1]byte
+ cond *sync.Cond
+ inRead bool
+ aborted bool // set true before conn.rwc deadline is set to past
+ remain int64 // bytes remaining
+}
+
+func (cr *connReader) lock() {
+ cr.mu.Lock()
+ if cr.cond == nil {
+ cr.cond = sync.NewCond(&cr.mu)
+ }
+}
+
+func (cr *connReader) unlock() { cr.mu.Unlock() }
+
+func (cr *connReader) startBackgroundRead() {
+ cr.lock()
+ defer cr.unlock()
+ if cr.inRead {
+ panic("invalid concurrent Body.Read call")
+ }
+ if cr.hasByte {
+ return
+ }
+ cr.inRead = true
+ cr.conn.rwc.SetReadDeadline(time.Time{})
+ go cr.backgroundRead()
+}
+
+func (cr *connReader) backgroundRead() {
+ n, err := cr.conn.rwc.Read(cr.byteBuf[:])
+ cr.lock()
+ if n == 1 {
+ cr.hasByte = true
+ // We were past the end of the previous request's body already
+ // (since we wouldn't be in a background read otherwise), so
+ // this is a pipelined HTTP request. Prior to Go 1.11 we used to
+ // send on the CloseNotify channel and cancel the context here,
+ // but the behavior was documented as only "may", and we only
+ // did that because that's how CloseNotify accidentally behaved
+ // in very early Go releases prior to context support. Once we
+ // added context support, people used a Handler's
+ // Request.Context() and passed it along. Having that context
+ // cancel on pipelined HTTP requests caused problems.
+ // Fortunately, almost nothing uses HTTP/1.x pipelining.
+ // Unfortunately, apt-get does, or sometimes does.
+ // New Go 1.11 behavior: don't fire CloseNotify or cancel
+ // contexts on pipelined requests. Shouldn't affect people, but
+ // fixes cases like Issue 23921. This does mean that a client
+ // closing their TCP connection after sending a pipelined
+ // request won't cancel the context, but we'll catch that on any
+ // write failure (in checkConnErrorWriter.Write).
+ // If the server never writes, yes, there are still contrived
+ // server & client behaviors where this fails to ever cancel the
+ // context, but that's kinda why HTTP/1.x pipelining died
+ // anyway.
+ }
+ if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() {
+ // Ignore this error. It's the expected error from
+ // another goroutine calling abortPendingRead.
+ } else if err != nil {
+ cr.handleReadError(err)
+ }
+ cr.aborted = false
+ cr.inRead = false
+ cr.unlock()
+ cr.cond.Broadcast()
+}
+
+func (cr *connReader) abortPendingRead() {
+ cr.lock()
+ defer cr.unlock()
+ if !cr.inRead {
+ return
+ }
+ cr.aborted = true
+ cr.conn.rwc.SetReadDeadline(aLongTimeAgo)
+ for cr.inRead {
+ cr.cond.Wait()
+ }
+ cr.conn.rwc.SetReadDeadline(time.Time{})
+}
+
+func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
+func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
+func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
+
+// handleReadError is called whenever a Read from the client returns a
+// non-nil error.
+//
+// The provided non-nil err is almost always io.EOF or a "use of
+// closed network connection". In any case, the error is not
+// particularly interesting, except perhaps for debugging during
+// development. Any error means the connection is dead and we should
+// tear down its context.
+//
+// It may be called from multiple goroutines.
+func (cr *connReader) handleReadError(_ error) {
+ cr.conn.cancelCtx()
+ cr.closeNotify()
+}
+
+// may be called from multiple goroutines.
+func (cr *connReader) closeNotify() {
+ res := cr.conn.curReq.Load()
+ if res != nil && !res.didCloseNotify.Swap(true) {
+ res.closeNotifyCh <- true
+ }
+}
+
+func (cr *connReader) Read(p []byte) (n int, err error) {
+ cr.lock()
+ if cr.inRead {
+ cr.unlock()
+ if cr.conn.hijacked() {
+ panic("invalid Body.Read call. After hijacked, the original Request must not be used")
+ }
+ panic("invalid concurrent Body.Read call")
+ }
+ if cr.hitReadLimit() {
+ cr.unlock()
+ return 0, io.EOF
+ }
+ if len(p) == 0 {
+ cr.unlock()
+ return 0, nil
+ }
+ if int64(len(p)) > cr.remain {
+ p = p[:cr.remain]
+ }
+ if cr.hasByte {
+ p[0] = cr.byteBuf[0]
+ cr.hasByte = false
+ cr.unlock()
+ return 1, nil
+ }
+ cr.inRead = true
+ cr.unlock()
+ n, err = cr.conn.rwc.Read(p)
+
+ cr.lock()
+ cr.inRead = false
+ if err != nil {
+ cr.handleReadError(err)
+ }
+ cr.remain -= int64(n)
+ cr.unlock()
+
+ cr.cond.Broadcast()
+ return n, err
+}
+
+var (
+ bufioReaderPool sync.Pool
+ bufioWriter2kPool sync.Pool
+ bufioWriter4kPool sync.Pool
+)
+
+var copyBufPool = sync.Pool{
+ New: func() any {
+ b := make([]byte, 32*1024)
+ return &b
+ },
+}
+
+func bufioWriterPool(size int) *sync.Pool {
+ switch size {
+ case 2 << 10:
+ return &bufioWriter2kPool
+ case 4 << 10:
+ return &bufioWriter4kPool
+ }
+ return nil
+}
+
+func newBufioReader(r io.Reader) *bufio.Reader {
+ if v := bufioReaderPool.Get(); v != nil {
+ br := v.(*bufio.Reader)
+ br.Reset(r)
+ return br
+ }
+ // Note: if this reader size is ever changed, update
+ // TestHandlerBodyClose's assumptions.
+ return bufio.NewReader(r)
+}
+
+func putBufioReader(br *bufio.Reader) {
+ br.Reset(nil)
+ bufioReaderPool.Put(br)
+}
+
+func newBufioWriterSize(w io.Writer, size int) *bufio.Writer {
+ pool := bufioWriterPool(size)
+ if pool != nil {
+ if v := pool.Get(); v != nil {
+ bw := v.(*bufio.Writer)
+ bw.Reset(w)
+ return bw
+ }
+ }
+ return bufio.NewWriterSize(w, size)
+}
+
+func putBufioWriter(bw *bufio.Writer) {
+ bw.Reset(nil)
+ if pool := bufioWriterPool(bw.Available()); pool != nil {
+ pool.Put(bw)
+ }
+}
+
+// DefaultMaxHeaderBytes is the maximum permitted size of the headers
+// in an HTTP request.
+// This can be overridden by setting Server.MaxHeaderBytes.
+const DefaultMaxHeaderBytes = 1 << 20 // 1 MB
+
+func (srv *Server) maxHeaderBytes() int {
+ if srv.MaxHeaderBytes > 0 {
+ return srv.MaxHeaderBytes
+ }
+ return DefaultMaxHeaderBytes
+}
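+
+// Illustrative sketch (editorial example, not part of this change):
+// overriding the DefaultMaxHeaderBytes limit on a Server, as the
+// constant's doc comment describes. Assumes the usual "log" and
+// "net/http" imports; the address is made up.
+//
+//	srv := &http.Server{
+//		Addr:           ":8080",
+//		MaxHeaderBytes: 64 << 10, // 64 KB instead of the 1 MB default
+//	}
+//	log.Fatal(srv.ListenAndServe())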
+
+func (srv *Server) initialReadLimitSize() int64 {
+ return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
+}
+
+// tlsHandshakeTimeout returns the time limit permitted for the TLS
+// handshake, or zero for unlimited.
+//
+// It returns the minimum of any positive ReadHeaderTimeout,
+// ReadTimeout, or WriteTimeout.
+func (srv *Server) tlsHandshakeTimeout() time.Duration {
+ var ret time.Duration
+ for _, v := range [...]time.Duration{
+ srv.ReadHeaderTimeout,
+ srv.ReadTimeout,
+ srv.WriteTimeout,
+ } {
+ if v <= 0 {
+ continue
+ }
+ if ret == 0 || v < ret {
+ ret = v
+ }
+ }
+ return ret
+}
+
+// expectContinueReader is a wrapper around io.ReadCloser which, on first
+// read, sends an HTTP/1.1 100 Continue header.
+type expectContinueReader struct {
+ resp *response
+ readCloser io.ReadCloser
+ closed atomic.Bool
+ sawEOF atomic.Bool
+}
+
+func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
+ if ecr.closed.Load() {
+ return 0, ErrBodyReadAfterClose
+ }
+ w := ecr.resp
+ if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() {
+ w.wroteContinue = true
+ w.writeContinueMu.Lock()
+ if w.canWriteContinue.Load() {
+ w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
+ w.conn.bufw.Flush()
+ w.canWriteContinue.Store(false)
+ }
+ w.writeContinueMu.Unlock()
+ }
+ n, err = ecr.readCloser.Read(p)
+ if err == io.EOF {
+ ecr.sawEOF.Store(true)
+ }
+ return
+}
+
+func (ecr *expectContinueReader) Close() error {
+ ecr.closed.Store(true)
+ return ecr.readCloser.Close()
+}
+
+// TimeFormat is the time format to use when generating times in HTTP
+// headers. It is like time.RFC1123 but hard-codes GMT as the time
+// zone. The time being formatted must be in UTC for Format to
+// generate the correct format.
+//
+// For parsing this time format, see ParseTime.
+const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
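+
+// Illustrative sketch (editorial example, not part of this change):
+// producing and parsing a header date with TimeFormat. The time must be
+// converted to UTC first, as noted above.
+//
+//	s := time.Now().UTC().Format(http.TimeFormat)
+//	t, err := http.ParseTime(s) // ParseTime accepts TimeFormat among other layouts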
+
+// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat))
+func appendTime(b []byte, t time.Time) []byte {
+ const days = "SunMonTueWedThuFriSat"
+ const months = "JanFebMarAprMayJunJulAugSepOctNovDec"
+
+ t = t.UTC()
+ yy, mm, dd := t.Date()
+ hh, mn, ss := t.Clock()
+ day := days[3*t.Weekday():]
+ mon := months[3*(mm-1):]
+
+ return append(b,
+ day[0], day[1], day[2], ',', ' ',
+ byte('0'+dd/10), byte('0'+dd%10), ' ',
+ mon[0], mon[1], mon[2], ' ',
+ byte('0'+yy/1000), byte('0'+(yy/100)%10), byte('0'+(yy/10)%10), byte('0'+yy%10), ' ',
+ byte('0'+hh/10), byte('0'+hh%10), ':',
+ byte('0'+mn/10), byte('0'+mn%10), ':',
+ byte('0'+ss/10), byte('0'+ss%10), ' ',
+ 'G', 'M', 'T')
+}
+
+var errTooLarge = errors.New("http: request too large")
+
+// Read next request from connection.
+func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
+ if c.hijacked() {
+ return nil, ErrHijacked
+ }
+
+ var (
+ wholeReqDeadline time.Time // or zero if none
+ hdrDeadline time.Time // or zero if none
+ )
+ t0 := time.Now()
+ if d := c.server.readHeaderTimeout(); d > 0 {
+ hdrDeadline = t0.Add(d)
+ }
+ if d := c.server.ReadTimeout; d > 0 {
+ wholeReqDeadline = t0.Add(d)
+ }
+ c.rwc.SetReadDeadline(hdrDeadline)
+ if d := c.server.WriteTimeout; d > 0 {
+ defer func() {
+ c.rwc.SetWriteDeadline(time.Now().Add(d))
+ }()
+ }
+
+ c.r.setReadLimit(c.server.initialReadLimitSize())
+ if c.lastMethod == "POST" {
+ // RFC 7230 section 3 tolerance for old buggy clients.
+ peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
+ c.bufr.Discard(numLeadingCRorLF(peek))
+ }
+ req, err := readRequest(c.bufr)
+ if err != nil {
+ if c.r.hitReadLimit() {
+ return nil, errTooLarge
+ }
+ return nil, err
+ }
+
+ if !http1ServerSupportsRequest(req) {
+ return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"}
+ }
+
+ c.lastMethod = req.Method
+ c.r.setInfiniteReadLimit()
+
+ hosts, haveHost := req.Header["Host"]
+ isH2Upgrade := req.isH2Upgrade()
+ if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade && req.Method != "CONNECT" {
+ return nil, badRequestError("missing required Host header")
+ }
+ if len(hosts) == 1 && !httpguts.ValidHostHeader(hosts[0]) {
+ return nil, badRequestError("malformed Host header")
+ }
+ for k, vv := range req.Header {
+ if !httpguts.ValidHeaderFieldName(k) {
+ return nil, badRequestError("invalid header name")
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ return nil, badRequestError("invalid header value")
+ }
+ }
+ }
+ delete(req.Header, "Host")
+
+ ctx, cancelCtx := context.WithCancel(ctx)
+ req.ctx = ctx
+ req.RemoteAddr = c.remoteAddr
+ req.TLS = c.tlsState
+ if body, ok := req.Body.(*body); ok {
+ body.doEarlyClose = true
+ }
+
+ // Adjust the read deadline if necessary.
+ if !hdrDeadline.Equal(wholeReqDeadline) {
+ c.rwc.SetReadDeadline(wholeReqDeadline)
+ }
+
+ w = &response{
+ conn: c,
+ cancelCtx: cancelCtx,
+ req: req,
+ reqBody: req.Body,
+ handlerHeader: make(Header),
+ contentLength: -1,
+ closeNotifyCh: make(chan bool, 1),
+
+ // We populate these ahead of time so we're not
+ // reading from req.Header after their Handler starts
+ // and maybe mutates it (Issue 14940)
+ wants10KeepAlive: req.wantsHttp10KeepAlive(),
+ wantsClose: req.wantsClose(),
+ }
+ if isH2Upgrade {
+ w.closeAfterReply = true
+ }
+ w.cw.res = w
+ w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize)
+ return w, nil
+}
+
+// http1ServerSupportsRequest reports whether Go's HTTP/1.x server
+// supports the given request.
+func http1ServerSupportsRequest(req *Request) bool {
+ if req.ProtoMajor == 1 {
+ return true
+ }
+ // Accept "PRI * HTTP/2.0" upgrade requests, so Handlers can
+ // wire up their own HTTP/2 upgrades.
+ if req.ProtoMajor == 2 && req.ProtoMinor == 0 &&
+ req.Method == "PRI" && req.RequestURI == "*" {
+ return true
+ }
+ // Reject HTTP/0.x, and all other HTTP/2+ requests (which
+ // aren't encoded in ASCII anyway).
+ return false
+}
+
+func (w *response) Header() Header {
+ if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader {
+ // Accessing the header between logically writing it
+ // and physically writing it means we need to allocate
+ // a clone to snapshot the logically written state.
+ w.cw.header = w.handlerHeader.Clone()
+ }
+ w.calledHeader = true
+ return w.handlerHeader
+}
+
+// maxPostHandlerReadBytes is the max number of Request.Body bytes not
+// consumed by a handler that the server will read from the client
+// in order to keep a connection alive. If there are more bytes than
+// this, the server, to be paranoid, instead sends a "Connection:
+// close" response.
+//
+// This number is approximately what a typical machine's TCP buffer
+// size is anyway. (if we have the bytes on the machine, we might as
+// well read them)
+const maxPostHandlerReadBytes = 256 << 10
+
+func checkWriteHeaderCode(code int) {
+ // Issue 22880: require valid WriteHeader status codes.
+ // For now we only enforce that it's three digits.
+ // In the future we might block things over 599 (600 and above aren't defined
+ // at https://httpwg.org/specs/rfc7231.html#status.codes).
+ // But for now any three digits.
+ //
+ // We used to send "HTTP/1.1 000 0" on the wire in responses but there's
+ // no equivalent bogus thing we can realistically send in HTTP/2,
+ // so we'll consistently panic instead and help people find their bugs
+ // early. (We can't return an error from WriteHeader even if we wanted to.)
+ if code < 100 || code > 999 {
+ panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+ }
+}
+
+// relevantCaller searches the call stack for the first function outside of net/http.
+// The purpose of this function is to provide more helpful error messages.
+func relevantCaller() runtime.Frame {
+ pc := make([]uintptr, 16)
+ n := runtime.Callers(1, pc)
+ frames := runtime.CallersFrames(pc[:n])
+ var frame runtime.Frame
+ for {
+ frame, more := frames.Next()
+ if !strings.HasPrefix(frame.Function, "net/http.") {
+ return frame
+ }
+ if !more {
+ break
+ }
+ }
+ return frame
+}
+
+func (w *response) WriteHeader(code int) {
+ if w.conn.hijacked() {
+ caller := relevantCaller()
+ w.conn.server.logf("http: response.WriteHeader on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
+ return
+ }
+ if w.wroteHeader {
+ caller := relevantCaller()
+ w.conn.server.logf("http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
+ return
+ }
+ checkWriteHeaderCode(code)
+
+ // Handle informational headers.
+ //
+ // We shouldn't send any further headers after 101 Switching Protocols,
+ // so it takes the non-informational path.
+ if code >= 100 && code <= 199 && code != StatusSwitchingProtocols {
+ // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read()
+ if code == 100 && w.canWriteContinue.Load() {
+ w.writeContinueMu.Lock()
+ w.canWriteContinue.Store(false)
+ w.writeContinueMu.Unlock()
+ }
+
+ writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
+
+ // Per RFC 8297 we must not clear the current header map
+ w.handlerHeader.WriteSubset(w.conn.bufw, excludedHeadersNoBody)
+ w.conn.bufw.Write(crlf)
+ w.conn.bufw.Flush()
+
+ return
+ }
+
+ w.wroteHeader = true
+ w.status = code
+
+ if w.calledHeader && w.cw.header == nil {
+ w.cw.header = w.handlerHeader.Clone()
+ }
+
+ if cl := w.handlerHeader.get("Content-Length"); cl != "" {
+ v, err := strconv.ParseInt(cl, 10, 64)
+ if err == nil && v >= 0 {
+ w.contentLength = v
+ } else {
+ w.conn.server.logf("http: invalid Content-Length of %q", cl)
+ w.handlerHeader.Del("Content-Length")
+ }
+ }
+}
+
+// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader.
+// This type is used to avoid extra allocations from cloning and/or populating
+// the response Header map and all its 1-element slices.
+type extraHeader struct {
+ contentType string
+ connection string
+ transferEncoding string
+ date []byte // written if not nil
+ contentLength []byte // written if not nil
+}
+
+// Sorted the same as extraHeader.Write's loop.
+var extraHeaderKeys = [][]byte{
+ []byte("Content-Type"),
+ []byte("Connection"),
+ []byte("Transfer-Encoding"),
+}
+
+var (
+ headerContentLength = []byte("Content-Length: ")
+ headerDate = []byte("Date: ")
+)
+
+// Write writes the headers described in h to w.
+//
+// This method has a value receiver, despite the somewhat large size
+// of h, because it prevents an allocation. The escape analysis isn't
+// smart enough to realize this function doesn't mutate h.
+func (h extraHeader) Write(w *bufio.Writer) {
+ if h.date != nil {
+ w.Write(headerDate)
+ w.Write(h.date)
+ w.Write(crlf)
+ }
+ if h.contentLength != nil {
+ w.Write(headerContentLength)
+ w.Write(h.contentLength)
+ w.Write(crlf)
+ }
+ for i, v := range []string{h.contentType, h.connection, h.transferEncoding} {
+ if v != "" {
+ w.Write(extraHeaderKeys[i])
+ w.Write(colonSpace)
+ w.WriteString(v)
+ w.Write(crlf)
+ }
+ }
+}
+
+// writeHeader finalizes the header sent to the client and writes it
+// to cw.res.conn.bufw.
+//
+// p is not written by writeHeader, but is the first chunk of the body
+// that will be written. It is sniffed for a Content-Type if none is
+// set explicitly. It's also used to set the Content-Length, if the
+// total body size was small and the handler has already finished
+// running.
+func (cw *chunkWriter) writeHeader(p []byte) {
+ if cw.wroteHeader {
+ return
+ }
+ cw.wroteHeader = true
+
+ w := cw.res
+ keepAlivesEnabled := w.conn.server.doKeepAlives()
+ isHEAD := w.req.Method == "HEAD"
+
+ // header is written out to w.conn.buf below. Depending on the
+ // state of the handler, we either own the map or not. If we
+ // don't own it, the exclude map is created lazily for
+ // WriteSubset to remove headers. The setHeader struct holds
+ // headers we need to add.
+ header := cw.header
+ owned := header != nil
+ if !owned {
+ header = w.handlerHeader
+ }
+ var excludeHeader map[string]bool
+ delHeader := func(key string) {
+ if owned {
+ header.Del(key)
+ return
+ }
+ if _, ok := header[key]; !ok {
+ return
+ }
+ if excludeHeader == nil {
+ excludeHeader = make(map[string]bool)
+ }
+ excludeHeader[key] = true
+ }
+ var setHeader extraHeader
+
+ // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix.
+ trailers := false
+ for k := range cw.header {
+ if strings.HasPrefix(k, TrailerPrefix) {
+ if excludeHeader == nil {
+ excludeHeader = make(map[string]bool)
+ }
+ excludeHeader[k] = true
+ trailers = true
+ }
+ }
+ for _, v := range cw.header["Trailer"] {
+ trailers = true
+ foreachHeaderElement(v, cw.res.declareTrailer)
+ }
+
+ te := header.get("Transfer-Encoding")
+ hasTE := te != ""
+
+ // If the handler is done but never sent a Content-Length
+ // response header and this is our first (and last) write, set
+ // it, even to zero. This helps HTTP/1.0 clients keep their
+ // "keep-alive" connections alive.
+ // Exceptions: 304/204/1xx responses never get Content-Length, and if
+ // it was a HEAD request, we don't know the difference between
+ // 0 actual bytes and 0 bytes because the handler noticed it
+ // was a HEAD request and chose not to write anything. So for
+ // HEAD, the handler should either write the Content-Length or
+ // write non-zero bytes. If it's actually 0 bytes and the
+ // handler never looked at the Request.Method, we just don't
+ // send a Content-Length header.
+ // Further, we don't send an automatic Content-Length if they
+ // set a Transfer-Encoding, because they're generally incompatible.
+ if w.handlerDone.Load() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && !header.has("Content-Length") && (!isHEAD || len(p) > 0) {
+ w.contentLength = int64(len(p))
+ setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
+ }
+
+ // If this was an HTTP/1.0 request with keep-alive and we sent a
+ // Content-Length back, we can make this a keep-alive response ...
+ if w.wants10KeepAlive && keepAlivesEnabled {
+ sentLength := header.get("Content-Length") != ""
+ if sentLength && header.get("Connection") == "keep-alive" {
+ w.closeAfterReply = false
+ }
+ }
+
+ // Check for an explicit (and valid) Content-Length header.
+ hasCL := w.contentLength != -1
+
+ if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) {
+ _, connectionHeaderSet := header["Connection"]
+ if !connectionHeaderSet {
+ setHeader.connection = "keep-alive"
+ }
+ } else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose {
+ w.closeAfterReply = true
+ }
+
+ if header.get("Connection") == "close" || !keepAlivesEnabled {
+ w.closeAfterReply = true
+ }
+
+ // If the client wanted a 100-continue but we never sent it to
+ // them (or, more strictly: we never finished reading their
+ // request body), don't reuse this connection because it's now
+ // in an unknown state: we might be sending this response at
+ // the same time the client is now sending its request body
+ // after a timeout. (Some HTTP clients send Expect:
+ // 100-continue but knowing that some servers don't support
+ // it, the clients set a timer and send the body later anyway)
+ // If we haven't seen EOF, we can't skip over the unread body
+ // because we don't know if the next bytes on the wire will be
+ // the body-following-the-timer or the subsequent request.
+ // See Issue 11549.
+ if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.Load() {
+ w.closeAfterReply = true
+ }
+
+ // We do this by default because there are a number of clients that
+ // send a full request before starting to read the response, and they
+ // can deadlock if we start writing the response with unconsumed body
+ // remaining. See Issue 15527 for some history.
+ //
+ // If full duplex mode has been enabled with ResponseController.EnableFullDuplex,
+ // then leave the request body alone.
+ if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex {
+ var discard, tooBig bool
+
+ switch bdy := w.req.Body.(type) {
+ case *expectContinueReader:
+ if bdy.resp.wroteContinue {
+ discard = true
+ }
+ case *body:
+ bdy.mu.Lock()
+ switch {
+ case bdy.closed:
+ if !bdy.sawEOF {
+ // Body was closed in handler with non-EOF error.
+ w.closeAfterReply = true
+ }
+ case bdy.unreadDataSizeLocked() >= maxPostHandlerReadBytes:
+ tooBig = true
+ default:
+ discard = true
+ }
+ bdy.mu.Unlock()
+ default:
+ discard = true
+ }
+
+ if discard {
+ _, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1)
+ switch err {
+ case nil:
+ // There must be even more data left over.
+ tooBig = true
+ case ErrBodyReadAfterClose:
+ // Body was already consumed and closed.
+ case io.EOF:
+ // The remaining body was just consumed, close it.
+ err = w.reqBody.Close()
+ if err != nil {
+ w.closeAfterReply = true
+ }
+ default:
+ // Some other kind of error occurred, like a read timeout, or
+ // corrupt chunked encoding. In any case, whatever remains
+ // on the wire must not be parsed as another HTTP request.
+ w.closeAfterReply = true
+ }
+ }
+
+ if tooBig {
+ w.requestTooLarge()
+ delHeader("Connection")
+ setHeader.connection = "close"
+ }
+ }
+
+ code := w.status
+ if bodyAllowedForStatus(code) {
+ // If no content type, apply sniffing algorithm to body.
+ _, haveType := header["Content-Type"]
+
+ // If the Content-Encoding was set and is non-blank,
+ // we shouldn't sniff the body. See Issue 31753.
+ ce := header.Get("Content-Encoding")
+ hasCE := len(ce) > 0
+ if !hasCE && !haveType && !hasTE && len(p) > 0 {
+ setHeader.contentType = DetectContentType(p)
+ }
+ } else {
+ for _, k := range suppressedHeaders(code) {
+ delHeader(k)
+ }
+ }
+
+ if !header.has("Date") {
+ setHeader.date = appendTime(cw.res.dateBuf[:0], time.Now())
+ }
+
+ if hasCL && hasTE && te != "identity" {
+ // TODO: return an error if WriteHeader gets a return parameter
+ // For now just ignore the Content-Length.
+ w.conn.server.logf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d",
+ te, w.contentLength)
+ delHeader("Content-Length")
+ hasCL = false
+ }
+
+ if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent {
+ // Response has no body.
+ delHeader("Transfer-Encoding")
+ } else if hasCL {
+ // Content-Length has been provided, so no chunking is to be done.
+ delHeader("Transfer-Encoding")
+ } else if w.req.ProtoAtLeast(1, 1) {
+ // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no
+ // content-length has been provided. The connection must be closed after the
+ // reply is written, and no chunking is to be done. This is the setup
+ // recommended in the Server-Sent Events candidate recommendation 11,
+ // section 8.
+ if hasTE && te == "identity" {
+ cw.chunking = false
+ w.closeAfterReply = true
+ delHeader("Transfer-Encoding")
+ } else {
+ // HTTP/1.1 or greater: use chunked transfer encoding
+ // to avoid closing the connection at EOF.
+ cw.chunking = true
+ setHeader.transferEncoding = "chunked"
+ if hasTE && te == "chunked" {
+ // We will send the chunked Transfer-Encoding header later.
+ delHeader("Transfer-Encoding")
+ }
+ }
+ } else {
+ // HTTP version < 1.1: cannot do chunked transfer
+ // encoding and we don't know the Content-Length so
+ // signal EOF by closing connection.
+ w.closeAfterReply = true
+ delHeader("Transfer-Encoding") // in case already set
+ }
+
+ // Cannot use Content-Length with non-identity Transfer-Encoding.
+ if cw.chunking {
+ delHeader("Content-Length")
+ }
+ if !w.req.ProtoAtLeast(1, 0) {
+ return
+ }
+
+ // Only override the Connection header if it is not a successful
+ // protocol switch response and if KeepAlives are not enabled.
+ // See https://golang.org/issue/36381.
+ delConnectionHeader := w.closeAfterReply &&
+ (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) &&
+ !isProtocolSwitchResponse(w.status, header)
+ if delConnectionHeader {
+ delHeader("Connection")
+ if w.req.ProtoAtLeast(1, 1) {
+ setHeader.connection = "close"
+ }
+ }
+
+ writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:])
+ cw.header.WriteSubset(w.conn.bufw, excludeHeader)
+ setHeader.Write(w.conn.bufw)
+ w.conn.bufw.Write(crlf)
+}
+
+// foreachHeaderElement splits v according to the "#rule" construction
+// in RFC 7230 section 7 and calls fn for each non-empty element.
+func foreachHeaderElement(v string, fn func(string)) {
+ v = textproto.TrimString(v)
+ if v == "" {
+ return
+ }
+ if !strings.Contains(v, ",") {
+ fn(v)
+ return
+ }
+ for _, f := range strings.Split(v, ",") {
+ if f = textproto.TrimString(f); f != "" {
+ fn(f)
+ }
+ }
+}
+
+// writeStatusLine writes an HTTP/1.x Status-Line (RFC 7230 Section 3.1.2)
+// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0.
+// code is the response status code.
+// scratch is an optional scratch buffer. If it has at least capacity 3, it's used.
+func writeStatusLine(bw *bufio.Writer, is11 bool, code int, scratch []byte) {
+ if is11 {
+ bw.WriteString("HTTP/1.1 ")
+ } else {
+ bw.WriteString("HTTP/1.0 ")
+ }
+ if text := StatusText(code); text != "" {
+ bw.Write(strconv.AppendInt(scratch[:0], int64(code), 10))
+ bw.WriteByte(' ')
+ bw.WriteString(text)
+ bw.WriteString("\r\n")
+ } else {
+ // don't worry about performance
+ fmt.Fprintf(bw, "%03d status code %d\r\n", code, code)
+ }
+}
+
+// bodyAllowed reports whether a Write is allowed for this response type.
+// It's illegal to call this before the header has been flushed.
+func (w *response) bodyAllowed() bool {
+ if !w.wroteHeader {
+ panic("")
+ }
+ return bodyAllowedForStatus(w.status)
+}
+
+// The Life Of A Write is like this:
+//
+// Handler starts. No header has been sent. The handler can either
+// write a header, or just start writing. Writing before sending a header
+// sends an implicitly empty 200 OK header.
+//
+// If the handler didn't declare a Content-Length up front, we either
+// go into chunking mode or, if the handler finishes running before
+// the chunking buffer size, we compute a Content-Length and send that
+// in the header instead.
+//
+// Likewise, if the handler didn't set a Content-Type, we sniff that
+// from the initial chunk of output.
+//
+// The Writers are wired together like:
+//
+// 1. *response (the ResponseWriter) ->
+// 2. (*response).w, a *bufio.Writer of bufferBeforeChunkingSize bytes ->
+// 3. chunkWriter.Writer (whose writeHeader finalizes Content-Length/Type)
+// and which writes the chunk headers, if needed ->
+// 4. conn.bufw, a *bufio.Writer of default (4kB) bytes, writing to ->
+// 5. checkConnErrorWriter{c}, which notes any non-nil error on Write
+// and populates c.werr with it if so, but otherwise writes to ->
+// 6. the rwc, the net.Conn.
+//
+// TODO(bradfitz): short-circuit some of the buffering when the
+// initial header contains both a Content-Type and Content-Length.
+// Also short-circuit in (1) when the header's been sent and not in
+// chunking mode, writing directly to (4) instead, if (2) has no
+// buffered data. More generally, we could short-circuit from (1) to
+// (3) even in chunking mode if the write size from (1) is over some
+// threshold and nothing is in (2). The answer might be mostly making
+// bufferBeforeChunkingSize smaller and having bufio's fast-paths deal
+// with this instead.
+func (w *response) Write(data []byte) (n int, err error) {
+ return w.write(len(data), data, "")
+}
+
+func (w *response) WriteString(data string) (n int, err error) {
+ return w.write(len(data), nil, data)
+}
+
+// At most one of dataB and dataS is non-zero; the other is left at its zero value.
+func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
+ if w.conn.hijacked() {
+ if lenData > 0 {
+ caller := relevantCaller()
+ w.conn.server.logf("http: response.Write on hijacked connection from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
+ }
+ return 0, ErrHijacked
+ }
+
+ if w.canWriteContinue.Load() {
+ // Body reader wants to write 100 Continue but hasn't yet.
+ // Tell it not to. The store must be done while holding the lock
+ // because the lock makes sure that there is not an active write
+ // this very moment.
+ w.writeContinueMu.Lock()
+ w.canWriteContinue.Store(false)
+ w.writeContinueMu.Unlock()
+ }
+
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ if lenData == 0 {
+ return 0, nil
+ }
+ if !w.bodyAllowed() {
+ return 0, ErrBodyNotAllowed
+ }
+
+ w.written += int64(lenData) // ignoring errors, for errorKludge
+ if w.contentLength != -1 && w.written > w.contentLength {
+ return 0, ErrContentLength
+ }
+ if dataB != nil {
+ return w.w.Write(dataB)
+ } else {
+ return w.w.WriteString(dataS)
+ }
+}
+
+func (w *response) finishRequest() {
+ w.handlerDone.Store(true)
+
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+
+ w.w.Flush()
+ putBufioWriter(w.w)
+ w.cw.close()
+ w.conn.bufw.Flush()
+
+ w.conn.r.abortPendingRead()
+
+ // Close the body (regardless of w.closeAfterReply) so we can
+ // re-use its bufio.Reader later safely.
+ w.reqBody.Close()
+
+ if w.req.MultipartForm != nil {
+ w.req.MultipartForm.RemoveAll()
+ }
+}
+
+// shouldReuseConnection reports whether the underlying TCP connection can be reused.
+// It must only be called after the handler is done executing.
+func (w *response) shouldReuseConnection() bool {
+ if w.closeAfterReply {
+ // The request or something set while executing the
+ // handler indicated we shouldn't reuse this
+ // connection.
+ return false
+ }
+
+ if w.req.Method != "HEAD" && w.contentLength != -1 && w.bodyAllowed() && w.contentLength != w.written {
+ // Did not write enough. Avoid getting out of sync.
+ return false
+ }
+
+ // There was some error writing to the underlying connection
+ // during the request, so don't re-use this conn.
+ if w.conn.werr != nil {
+ return false
+ }
+
+ if w.closedRequestBodyEarly() {
+ return false
+ }
+
+ return true
+}
+
+func (w *response) closedRequestBodyEarly() bool {
+ body, ok := w.req.Body.(*body)
+ return ok && body.didEarlyClose()
+}
+
+func (w *response) Flush() {
+ w.FlushError()
+}
+
+func (w *response) FlushError() error {
+ if !w.wroteHeader {
+ w.WriteHeader(StatusOK)
+ }
+ err := w.w.Flush()
+ e2 := w.cw.flush()
+ if err == nil {
+ err = e2
+ }
+ return err
+}
+
+func (c *conn) finalFlush() {
+ if c.bufr != nil {
+ // Steal the bufio.Reader (~4KB worth of memory) and its associated
+ // reader for a future connection.
+ putBufioReader(c.bufr)
+ c.bufr = nil
+ }
+
+ if c.bufw != nil {
+ c.bufw.Flush()
+ // Steal the bufio.Writer (~4KB worth of memory) and its associated
+ // writer for a future connection.
+ putBufioWriter(c.bufw)
+ c.bufw = nil
+ }
+}
+
+// Close the connection.
+func (c *conn) close() {
+ c.finalFlush()
+ c.rwc.Close()
+}
+
+// rstAvoidanceDelay is the amount of time we sleep after closing the
+// write side of a TCP connection before closing the entire socket.
+// By sleeping, we increase the chances that the client sees our FIN
+// and processes its final data before they process the subsequent RST
+// from closing a connection with known unread data.
+// This RST seems to occur mostly on BSD systems. (And Windows?)
+// This timeout is somewhat arbitrary (~latency around the planet).
+const rstAvoidanceDelay = 500 * time.Millisecond
+
+type closeWriter interface {
+ CloseWrite() error
+}
+
+var _ closeWriter = (*net.TCPConn)(nil)
+
+// closeWriteAndWait flushes any outstanding data and sends a FIN packet (if
+// client is connected via TCP), signaling that we're done. We then
+// pause for a bit, hoping the client processes it before any
+// subsequent RST.
+//
+// See https://golang.org/issue/3595
+func (c *conn) closeWriteAndWait() {
+ c.finalFlush()
+ if tcp, ok := c.rwc.(closeWriter); ok {
+ tcp.CloseWrite()
+ }
+ time.Sleep(rstAvoidanceDelay)
+}
+
+// validNextProto reports whether the proto is a valid ALPN protocol name.
+// Everything is valid except the empty string and built-in protocol types,
+// so that those can't be overridden with alternate implementations.
+func validNextProto(proto string) bool {
+ switch proto {
+ case "", "http/1.1", "http/1.0":
+ return false
+ }
+ return true
+}
+
+const (
+ runHooks = true
+ skipHooks = false
+)
+
+func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) {
+ srv := c.server
+ switch state {
+ case StateNew:
+ srv.trackConn(c, true)
+ case StateHijacked, StateClosed:
+ srv.trackConn(c, false)
+ }
+ if state > 0xff || state < 0 {
+ panic("internal error")
+ }
+ packedState := uint64(time.Now().Unix()<<8) | uint64(state)
+ c.curState.Store(packedState)
+ if !runHook {
+ return
+ }
+ if hook := srv.ConnState; hook != nil {
+ hook(nc, state)
+ }
+}
+
+func (c *conn) getState() (state ConnState, unixSec int64) {
+ packedState := c.curState.Load()
+ return ConnState(packedState & 0xff), int64(packedState >> 8)
+}
+
+// badRequestError is a literal string (used by the server in HTML,
+// unescaped) to tell the user why their request was bad. It should
+// be plain text without user info or other embedded errors.
+func badRequestError(e string) error { return statusError{StatusBadRequest, e} }
+
+// statusError is an error used to respond to a request with an HTTP status.
+// The text should be plain text without user info or other embedded errors.
+type statusError struct {
+ code int
+ text string
+}
+
+func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text }
+
+// ErrAbortHandler is a sentinel panic value to abort a handler.
+// While any panic from ServeHTTP aborts the response to the client,
+// panicking with ErrAbortHandler also suppresses logging of a stack
+// trace to the server's error log.
+var ErrAbortHandler = errors.New("net/http: abort Handler")
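+
+// Illustrative sketch (editorial example, not part of this change):
+// aborting a response without the usual panic stack trace appearing in
+// the server's error log.
+//
+//	func gated(w http.ResponseWriter, r *http.Request) {
+//		if r.Header.Get("X-Token") == "" { // header name made up for illustration
+//			panic(http.ErrAbortHandler)
+//		}
+//		io.WriteString(w, "ok\n")
+//	}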
+
+// isCommonNetReadError reports whether err is a common error
+// encountered during reading a request off the network when the
+// client has gone away or had its read fail somehow. This is used to
+// determine which logs are interesting enough to log about.
+func isCommonNetReadError(err error) bool {
+ if err == io.EOF {
+ return true
+ }
+ if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ return true
+ }
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
+ return true
+ }
+ return false
+}
+
+// Serve a new connection.
+func (c *conn) serve(ctx context.Context) {
+ if ra := c.rwc.RemoteAddr(); ra != nil {
+ c.remoteAddr = ra.String()
+ }
+ ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr())
+ var inFlightResponse *response
+ defer func() {
+ if err := recover(); err != nil && err != ErrAbortHandler {
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf)
+ }
+ if inFlightResponse != nil {
+ inFlightResponse.cancelCtx()
+ }
+ if !c.hijacked() {
+ if inFlightResponse != nil {
+ inFlightResponse.conn.r.abortPendingRead()
+ inFlightResponse.reqBody.Close()
+ }
+ c.close()
+ c.setState(c.rwc, StateClosed, runHooks)
+ }
+ }()
+
+ if tlsConn, ok := c.rwc.(*tls.Conn); ok {
+ tlsTO := c.server.tlsHandshakeTimeout()
+ if tlsTO > 0 {
+ dl := time.Now().Add(tlsTO)
+ c.rwc.SetReadDeadline(dl)
+ c.rwc.SetWriteDeadline(dl)
+ }
+ if err := tlsConn.HandshakeContext(ctx); err != nil {
+ // If the handshake failed due to the client not speaking
+ // TLS, assume they're speaking plaintext HTTP and write a
+ // 400 response on the TLS conn's underlying net.Conn.
+ if re, ok := err.(tls.RecordHeaderError); ok && re.Conn != nil && tlsRecordHeaderLooksLikeHTTP(re.RecordHeader) {
+ io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n")
+ re.Conn.Close()
+ return
+ }
+ c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
+ return
+ }
+ // Restore Conn-level deadlines.
+ if tlsTO > 0 {
+ c.rwc.SetReadDeadline(time.Time{})
+ c.rwc.SetWriteDeadline(time.Time{})
+ }
+ c.tlsState = new(tls.ConnectionState)
+ *c.tlsState = tlsConn.ConnectionState()
+ if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {
+ if fn := c.server.TLSNextProto[proto]; fn != nil {
+ h := initALPNRequest{ctx, tlsConn, serverHandler{c.server}}
+ // Mark freshly created HTTP/2 as active and prevent any server state hooks
+ // from being run on these connections. This prevents closeIdleConns from
+ // closing such connections. See issue https://golang.org/issue/39776.
+ c.setState(c.rwc, StateActive, skipHooks)
+ fn(c.server, tlsConn, h)
+ }
+ return
+ }
+ }
+
+ // HTTP/1.x from here on.
+
+ ctx, cancelCtx := context.WithCancel(ctx)
+ c.cancelCtx = cancelCtx
+ defer cancelCtx()
+
+ c.r = &connReader{conn: c}
+ c.bufr = newBufioReader(c.r)
+ c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
+
+ for {
+ w, err := c.readRequest(ctx)
+ if c.r.remain != c.server.initialReadLimitSize() {
+ // If we read any bytes off the wire, we're active.
+ c.setState(c.rwc, StateActive, runHooks)
+ }
+ if err != nil {
+ const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
+
+ switch {
+ case err == errTooLarge:
+ // Their HTTP client may or may not be
+ // able to read this if we're
+ // responding to them and hanging up
+ // while they're still writing their
+ // request. Undefined behavior.
+ const publicErr = "431 Request Header Fields Too Large"
+ fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
+ c.closeWriteAndWait()
+ return
+
+ case isUnsupportedTEError(err):
+ // Respond as per RFC 7230 Section 3.3.1 which says,
+ // A server that receives a request message with a
+ // transfer coding it does not understand SHOULD
+ // respond with 501 (Unimplemented).
+ code := StatusNotImplemented
+
+ // We purposefully aren't echoing back the transfer-encoding's value,
+ // so as to mitigate the risk of cross-site scripting by an attacker.
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, StatusText(code), errorHeaders)
+ return
+
+ case isCommonNetReadError(err):
+ return // don't reply
+
+ default:
+ if v, ok := err.(statusError); ok {
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text)
+ return
+ }
+ publicErr := "400 Bad Request"
+ fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
+ return
+ }
+ }
+
+ // Expect 100 Continue support
+ req := w.req
+ if req.expectsContinue() {
+ if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 {
+ // Wrap the Body reader with one that replies on the connection
+ req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
+ w.canWriteContinue.Store(true)
+ }
+ } else if req.Header.get("Expect") != "" {
+ w.sendExpectationFailed()
+ return
+ }
+
+ c.curReq.Store(w)
+
+ if requestBodyRemains(req.Body) {
+ registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
+ } else {
+ w.conn.r.startBackgroundRead()
+ }
+
+ // HTTP cannot have multiple simultaneous active requests.[*]
+ // Until the server replies to this request, it can't read another,
+ // so we might as well run the handler in this goroutine.
+ // [*] Not strictly true: HTTP pipelining. We could let them all process
+ // in parallel even if their responses need to be serialized.
+ // But we're not going to implement HTTP pipelining because it
+ // was never deployed in the wild and the answer is HTTP/2.
+ inFlightResponse = w
+ serverHandler{c.server}.ServeHTTP(w, w.req)
+ inFlightResponse = nil
+ w.cancelCtx()
+ if c.hijacked() {
+ return
+ }
+ w.finishRequest()
+ c.rwc.SetWriteDeadline(time.Time{})
+ if !w.shouldReuseConnection() {
+ if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
+ c.closeWriteAndWait()
+ }
+ return
+ }
+ c.setState(c.rwc, StateIdle, runHooks)
+ c.curReq.Store(nil)
+
+ if !w.conn.server.doKeepAlives() {
+ // We're in shutdown mode. We might've replied
+ // to the user without "Connection: close" and
+ // they might think they can send another
+ // request, but such is life with HTTP/1.1.
+ return
+ }
+
+ if d := c.server.idleTimeout(); d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ } else {
+ c.rwc.SetReadDeadline(time.Time{})
+ }
+
+ // Wait for the connection to become readable again before trying to
+ // read the next request. This prevents a ReadHeaderTimeout or
+ // ReadTimeout from starting until the first bytes of the next request
+ // have been received.
+ if _, err := c.bufr.Peek(4); err != nil {
+ return
+ }
+
+ c.rwc.SetReadDeadline(time.Time{})
+ }
+}
+
+func (w *response) sendExpectationFailed() {
+ // TODO(bradfitz): let ServeHTTP handlers handle
+ // requests with non-standard expectation[s]? Seems
+ // theoretical at best, and doesn't fit into the
+ // current ServeHTTP model anyway. We'd need to
+ // make the ResponseWriter an optional
+ // "ExpectReplier" interface or something.
+ //
+ // For now we'll just obey RFC 7231 5.1.1 which says
+ // "A server that receives an Expect field-value other
+ // than 100-continue MAY respond with a 417 (Expectation
+ // Failed) status code to indicate that the unexpected
+ // expectation cannot be met."
+ w.Header().Set("Connection", "close")
+ w.WriteHeader(StatusExpectationFailed)
+ w.finishRequest()
+}
+
+// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
+// and a Hijacker.
+func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ if w.handlerDone.Load() {
+ panic("net/http: Hijack called after ServeHTTP finished")
+ }
+ if w.wroteHeader {
+ w.cw.flush()
+ }
+
+ c := w.conn
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Release the bufioWriter that writes to the chunk writer; it is not
+ // used after a connection has been hijacked.
+ rwc, buf, err = c.hijackLocked()
+ if err == nil {
+ putBufioWriter(w.w)
+ w.w = nil
+ }
+ return rwc, buf, err
+}
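+
+// hijackSketch is an illustrative sketch (not part of the original change)
+// showing how a handler typically takes over the connection through the
+// Hijacker interface; after Hijack returns, the caller owns the net.Conn and
+// must close it. The handler name and the raw response bytes are assumptions.
+func hijackSketch(w ResponseWriter, r *Request) {
+ hj, ok := w.(Hijacker)
+ if !ok {
+ Error(w, "hijacking not supported", StatusInternalServerError)
+ return
+ }
+ conn, bufrw, err := hj.Hijack()
+ if err != nil {
+ Error(w, err.Error(), StatusInternalServerError)
+ return
+ }
+ defer conn.Close()
+ // The server no longer manages this connection; speak raw HTTP (or any
+ // other protocol) over bufrw and conn directly.
+ bufrw.WriteString("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nhi")
+ bufrw.Flush()
+}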
+
+func (w *response) CloseNotify() <-chan bool {
+ if w.handlerDone.Load() {
+ panic("net/http: CloseNotify called after ServeHTTP finished")
+ }
+ return w.closeNotifyCh
+}
+
+func registerOnHitEOF(rc io.ReadCloser, fn func()) {
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ registerOnHitEOF(v.readCloser, fn)
+ case *body:
+ v.registerOnHitEOF(fn)
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
+}
+
+// requestBodyRemains reports whether future calls to Read
+// on rc might yield more data.
+func requestBodyRemains(rc io.ReadCloser) bool {
+ if rc == NoBody {
+ return false
+ }
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ return requestBodyRemains(v.readCloser)
+ case *body:
+ return v.bodyRemains()
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
+}
+
+// The HandlerFunc type is an adapter to allow the use of
+// ordinary functions as HTTP handlers. If f is a function
+// with the appropriate signature, HandlerFunc(f) is a
+// Handler that calls f.
+type HandlerFunc func(ResponseWriter, *Request)
+
+// ServeHTTP calls f(w, r).
+func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) {
+ f(w, r)
+}
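+
+// handlerFuncSketch is an illustrative sketch (not part of the original
+// change): an ordinary function with the right signature becomes a Handler
+// simply by converting it to HandlerFunc. The greeting body is an assumption.
+func handlerFuncSketch() Handler {
+ hello := func(w ResponseWriter, r *Request) {
+ fmt.Fprintln(w, "hello")
+ }
+ return HandlerFunc(hello) // hello now satisfies the Handler interface
+}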
+
+// Helper handlers
+
+// Error replies to the request with the specified error message and HTTP code.
+// It does not otherwise end the request; the caller should ensure no further
+// writes are done to w.
+// The error message should be plain text.
+func Error(w ResponseWriter, error string, code int) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.WriteHeader(code)
+ fmt.Fprintln(w, error)
+}
+
+// NotFound replies to the request with an HTTP 404 not found error.
+func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", StatusNotFound) }
+
+// NotFoundHandler returns a simple request handler
+// that replies to each request with a “404 page not found” reply.
+func NotFoundHandler() Handler { return HandlerFunc(NotFound) }
+
+// StripPrefix returns a handler that serves HTTP requests by removing the
+// given prefix from the request URL's Path (and RawPath if set) and invoking
+// the handler h. StripPrefix handles a request for a path that doesn't begin
+// with prefix by replying with an HTTP 404 not found error. The prefix must
+// match exactly: if the prefix in the request contains escaped characters
+// the reply is also an HTTP 404 not found error.
+func StripPrefix(prefix string, h Handler) Handler {
+ if prefix == "" {
+ return h
+ }
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ p := strings.TrimPrefix(r.URL.Path, prefix)
+ rp := strings.TrimPrefix(r.URL.RawPath, prefix)
+ if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) {
+ r2 := new(Request)
+ *r2 = *r
+ r2.URL = new(url.URL)
+ *r2.URL = *r.URL
+ r2.URL.Path = p
+ r2.URL.RawPath = rp
+ h.ServeHTTP(w, r2)
+ } else {
+ NotFound(w, r)
+ }
+ })
+}
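+
+// stripPrefixSketch is an illustrative sketch (not part of the original
+// change) of the common StripPrefix + FileServer pairing; the "/static/"
+// prefix and the ./assets directory are assumptions.
+func stripPrefixSketch(mux *ServeMux) {
+ // A request for /static/css/site.css is served from ./assets/css/site.css.
+ mux.Handle("/static/", StripPrefix("/static/", FileServer(Dir("./assets"))))
+}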
+
+// Redirect replies to the request with a redirect to url,
+// which may be a path relative to the request path.
+//
+// The provided code should be in the 3xx range and is usually
+// StatusMovedPermanently, StatusFound or StatusSeeOther.
+//
+// If the Content-Type header has not been set, Redirect sets it
+// to "text/html; charset=utf-8" and writes a small HTML body.
+// Setting the Content-Type header to any value, including nil,
+// disables that behavior.
+func Redirect(w ResponseWriter, r *Request, url string, code int) {
+ if u, err := urlpkg.Parse(url); err == nil {
+ // If url was relative, make its path absolute by
+ // combining with request path.
+ // The client would probably do this for us,
+ // but doing it ourselves is more reliable.
+ // See RFC 7231, section 7.1.2
+ if u.Scheme == "" && u.Host == "" {
+ oldpath := r.URL.Path
+ if oldpath == "" { // should not happen, but avoid a crash if it does
+ oldpath = "/"
+ }
+
+ // no leading http://server
+ if url == "" || url[0] != '/' {
+ // make relative path absolute
+ olddir, _ := path.Split(oldpath)
+ url = olddir + url
+ }
+
+ var query string
+ if i := strings.Index(url, "?"); i != -1 {
+ url, query = url[:i], url[i:]
+ }
+
+ // clean up but preserve trailing slash
+ trailing := strings.HasSuffix(url, "/")
+ url = path.Clean(url)
+ if trailing && !strings.HasSuffix(url, "/") {
+ url += "/"
+ }
+ url += query
+ }
+ }
+
+ h := w.Header()
+
+ // RFC 7231 notes that a short HTML body is usually included in
+ // the response because older user agents may not understand 301/307.
+ // Do it only if the response doesn't already have a Content-Type header.
+ _, hadCT := h["Content-Type"]
+
+ h.Set("Location", hexEscapeNonASCII(url))
+ if !hadCT && (r.Method == "GET" || r.Method == "HEAD") {
+ h.Set("Content-Type", "text/html; charset=utf-8")
+ }
+ w.WriteHeader(code)
+
+ // Shouldn't send the body for POST or HEAD; that leaves GET.
+ if !hadCT && r.Method == "GET" {
+ body := "<a href=\"" + htmlEscape(url) + "\">" + StatusText(code) + "</a>.\n"
+ fmt.Fprintln(w, body)
+ }
+}
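+
+// redirectSketch is an illustrative sketch (not part of the original change):
+// a handler that permanently redirects an old path to its replacement. Both
+// paths are assumptions.
+func redirectSketch(w ResponseWriter, r *Request) {
+ Redirect(w, r, "/new-location", StatusMovedPermanently)
+}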
+
+var htmlReplacer = strings.NewReplacer(
+ "&", "&amp;",
+ "<", "&lt;",
+ ">", "&gt;",
+ // "&#34;" is shorter than "&quot;".
+ `"`, "&#34;",
+ // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
+ "'", "&#39;",
+)
+
+func htmlEscape(s string) string {
+ return htmlReplacer.Replace(s)
+}
+
+// Redirect to a fixed URL
+type redirectHandler struct {
+ url string
+ code int
+}
+
+func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ Redirect(w, r, rh.url, rh.code)
+}
+
+// RedirectHandler returns a request handler that redirects
+// each request it receives to the given url using the given
+// status code.
+//
+// The provided code should be in the 3xx range and is usually
+// StatusMovedPermanently, StatusFound or StatusSeeOther.
+func RedirectHandler(url string, code int) Handler {
+ return &redirectHandler{url, code}
+}
+
+// ServeMux is an HTTP request multiplexer.
+// It matches the URL of each incoming request against a list of registered
+// patterns and calls the handler for the pattern that
+// most closely matches the URL.
+//
+// Patterns name fixed, rooted paths, like "/favicon.ico",
+// or rooted subtrees, like "/images/" (note the trailing slash).
+// Longer patterns take precedence over shorter ones, so that
+// if there are handlers registered for both "/images/"
+// and "/images/thumbnails/", the latter handler will be
+// called for paths beginning with "/images/thumbnails/" and the
+// former will receive requests for any other paths in the
+// "/images/" subtree.
+//
+// Note that since a pattern ending in a slash names a rooted subtree,
+// the pattern "/" matches all paths not matched by other registered
+// patterns, not just the URL with Path == "/".
+//
+// If a subtree has been registered and a request is received naming the
+// subtree root without its trailing slash, ServeMux redirects that
+// request to the subtree root (adding the trailing slash). This behavior can
+// be overridden with a separate registration for the path without
+// the trailing slash. For example, registering "/images/" causes ServeMux
+// to redirect a request for "/images" to "/images/", unless "/images" has
+// been registered separately.
+//
+// Patterns may optionally begin with a host name, restricting matches to
+// URLs on that host only. Host-specific patterns take precedence over
+// general patterns, so that a handler might register for the two patterns
+// "/codesearch" and "codesearch.google.com/" without also taking over
+// requests for "http://www.google.com/".
+//
+// ServeMux also takes care of sanitizing the URL request path and the Host
+// header, stripping the port number and redirecting any request containing . or
+// .. elements or repeated slashes to an equivalent, cleaner URL.
+type ServeMux struct {
+ mu sync.RWMutex
+ m map[string]muxEntry
+ es []muxEntry // slice of entries sorted from longest to shortest.
+ hosts bool // whether any patterns contain hostnames
+}
+
+type muxEntry struct {
+ h Handler
+ pattern string
+}
+
+// NewServeMux allocates and returns a new ServeMux.
+func NewServeMux() *ServeMux { return new(ServeMux) }
+
+// DefaultServeMux is the default ServeMux used by Serve.
+var DefaultServeMux = &defaultServeMux
+
+var defaultServeMux ServeMux
+
+// cleanPath returns the canonical path for p, eliminating . and .. elements.
+func cleanPath(p string) string {
+ if p == "" {
+ return "/"
+ }
+ if p[0] != '/' {
+ p = "/" + p
+ }
+ np := path.Clean(p)
+ // path.Clean removes trailing slash except for root;
+ // put the trailing slash back if necessary.
+ if p[len(p)-1] == '/' && np != "/" {
+ // Fast path for common case of p being the string we want:
+ if len(p) == len(np)+1 && strings.HasPrefix(p, np) {
+ np = p
+ } else {
+ np += "/"
+ }
+ }
+ return np
+}
+
+// stripHostPort returns h without any trailing ":<port>".
+func stripHostPort(h string) string {
+ // If no port on host, return unchanged
+ if !strings.Contains(h, ":") {
+ return h
+ }
+ host, _, err := net.SplitHostPort(h)
+ if err != nil {
+ return h // on error, return unchanged
+ }
+ return host
+}
+
+// Find a handler on a handler map given a path string.
+// Most-specific (longest) pattern wins.
+func (mux *ServeMux) match(path string) (h Handler, pattern string) {
+ // Check for exact match first.
+ v, ok := mux.m[path]
+ if ok {
+ return v.h, v.pattern
+ }
+
+ // Check for longest valid match. mux.es contains all patterns
+ // that end in / sorted from longest to shortest.
+ for _, e := range mux.es {
+ if strings.HasPrefix(path, e.pattern) {
+ return e.h, e.pattern
+ }
+ }
+ return nil, ""
+}
+
+// redirectToPathSlash reports whether the given path needs a trailing "/"
+// appended. That is the case when a handler is registered for path + "/"
+// but not for path itself. If so, it returns a new URL with the trailing
+// slash added and true; otherwise it returns u unchanged and false.
+func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
+ mux.mu.RLock()
+ shouldRedirect := mux.shouldRedirectRLocked(host, path)
+ mux.mu.RUnlock()
+ if !shouldRedirect {
+ return u, false
+ }
+ path = path + "/"
+ u = &url.URL{Path: path, RawQuery: u.RawQuery}
+ return u, true
+}
+
+// shouldRedirectRLocked reports whether the given path and host should be redirected to
+// path+"/". This should happen if a handler is registered for path+"/" but
+// not path -- see comments at ServeMux.
+func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
+ p := []string{path, host + path}
+
+ for _, c := range p {
+ if _, exist := mux.m[c]; exist {
+ return false
+ }
+ }
+
+ n := len(path)
+ if n == 0 {
+ return false
+ }
+ for _, c := range p {
+ if _, exist := mux.m[c+"/"]; exist {
+ return path[n-1] != '/'
+ }
+ }
+
+ return false
+}
+
+// Handler returns the handler to use for the given request,
+// consulting r.Method, r.Host, and r.URL.Path. It always returns
+// a non-nil handler. If the path is not in its canonical form, the
+// handler will be an internally-generated handler that redirects
+// to the canonical path. If the host contains a port, it is ignored
+// when matching handlers.
+//
+// The path and host are used unchanged for CONNECT requests.
+//
+// Handler also returns the registered pattern that matches the
+// request or, in the case of internally-generated redirects,
+// the pattern that will match after following the redirect.
+//
+// If there is no registered handler that applies to the request,
+// Handler returns a “page not found” handler and an empty pattern.
+func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
+
+ // CONNECT requests are not canonicalized.
+ if r.Method == "CONNECT" {
+ // If r.URL.Path is /tree and its handler is not registered,
+ // the /tree -> /tree/ redirect applies to CONNECT requests
+ // but the path canonicalization does not.
+ if u, ok := mux.redirectToPathSlash(r.URL.Host, r.URL.Path, r.URL); ok {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ }
+
+ return mux.handler(r.Host, r.URL.Path)
+ }
+
+ // All other requests have any port stripped and path cleaned
+ // before passing to mux.handler.
+ host := stripHostPort(r.Host)
+ path := cleanPath(r.URL.Path)
+
+ // If the given path is /tree and its handler is not registered,
+ // redirect for /tree/.
+ if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok {
+ return RedirectHandler(u.String(), StatusMovedPermanently), u.Path
+ }
+
+ if path != r.URL.Path {
+ _, pattern = mux.handler(host, path)
+ u := &url.URL{Path: path, RawQuery: r.URL.RawQuery}
+ return RedirectHandler(u.String(), StatusMovedPermanently), pattern
+ }
+
+ return mux.handler(host, r.URL.Path)
+}
+
+// handler is the main implementation of Handler.
+// The path is known to be in canonical form, except for CONNECT methods.
+func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
+ mux.mu.RLock()
+ defer mux.mu.RUnlock()
+
+ // Host-specific pattern takes precedence over generic ones
+ if mux.hosts {
+ h, pattern = mux.match(host + path)
+ }
+ if h == nil {
+ h, pattern = mux.match(path)
+ }
+ if h == nil {
+ h, pattern = NotFoundHandler(), ""
+ }
+ return
+}
+
+// ServeHTTP dispatches the request to the handler whose
+// pattern most closely matches the request URL.
+func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) {
+ if r.RequestURI == "*" {
+ if r.ProtoAtLeast(1, 1) {
+ w.Header().Set("Connection", "close")
+ }
+ w.WriteHeader(StatusBadRequest)
+ return
+ }
+ h, _ := mux.Handler(r)
+ h.ServeHTTP(w, r)
+}
+
+// Handle registers the handler for the given pattern.
+// If a handler already exists for pattern, Handle panics.
+func (mux *ServeMux) Handle(pattern string, handler Handler) {
+ mux.mu.Lock()
+ defer mux.mu.Unlock()
+
+ if pattern == "" {
+ panic("http: invalid pattern")
+ }
+ if handler == nil {
+ panic("http: nil handler")
+ }
+ if _, exist := mux.m[pattern]; exist {
+ panic("http: multiple registrations for " + pattern)
+ }
+
+ if mux.m == nil {
+ mux.m = make(map[string]muxEntry)
+ }
+ e := muxEntry{h: handler, pattern: pattern}
+ mux.m[pattern] = e
+ if pattern[len(pattern)-1] == '/' {
+ mux.es = appendSorted(mux.es, e)
+ }
+
+ if pattern[0] != '/' {
+ mux.hosts = true
+ }
+}
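+
+// muxPatternsSketch is an illustrative sketch (not part of the original
+// change) of pattern precedence: the longer "/images/thumbnails/" subtree
+// wins over "/images/", and the exact "/images/logo.png" entry wins over
+// both. The handler parameters are assumptions.
+func muxPatternsSketch(images, thumbs, logo Handler) *ServeMux {
+ mux := NewServeMux()
+ mux.Handle("/images/", images)
+ mux.Handle("/images/thumbnails/", thumbs)
+ mux.Handle("/images/logo.png", logo)
+ return mux
+}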
+
+func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
+ n := len(es)
+ i := sort.Search(n, func(i int) bool {
+ return len(es[i].pattern) < len(e.pattern)
+ })
+ if i == n {
+ return append(es, e)
+ }
+ // we now know that i points at where we want to insert
+ es = append(es, muxEntry{}) // try to grow the slice in place, any entry works.
+ copy(es[i+1:], es[i:]) // Move shorter entries down
+ es[i] = e
+ return es
+}
+
+// HandleFunc registers the handler function for the given pattern.
+func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ if handler == nil {
+ panic("http: nil handler")
+ }
+ mux.Handle(pattern, HandlerFunc(handler))
+}
+
+// Handle registers the handler for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
+
+// HandleFunc registers the handler function for the given pattern
+// in the DefaultServeMux.
+// The documentation for ServeMux explains how patterns are matched.
+func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ DefaultServeMux.HandleFunc(pattern, handler)
+}
+
+// Serve accepts incoming HTTP connections on the listener l,
+// creating a new service goroutine for each. The service goroutines
+// read requests and then call handler to reply to them.
+//
+// The handler is typically nil, in which case the DefaultServeMux is used.
+//
+// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// connections and they were configured with "h2" in the TLS
+// Config.NextProtos.
+//
+// Serve always returns a non-nil error.
+func Serve(l net.Listener, handler Handler) error {
+ srv := &Server{Handler: handler}
+ return srv.Serve(l)
+}
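+
+// serveListenerSketch is an illustrative sketch (not part of the original
+// change): Serve accepts any net.Listener, for example one the caller
+// created to bind a random free port. The loopback address is an assumption.
+func serveListenerSketch(h Handler) error {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return err
+ }
+ return Serve(ln, h)
+}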
+
+// ServeTLS accepts incoming HTTPS connections on the listener l,
+// creating a new service goroutine for each. The service goroutines
+// read requests and then call handler to reply to them.
+//
+// The handler is typically nil, in which case the DefaultServeMux is used.
+//
+// Additionally, files containing a certificate and matching private key
+// for the server must be provided. If the certificate is signed by a
+// certificate authority, the certFile should be the concatenation
+// of the server's certificate, any intermediates, and the CA's certificate.
+//
+// ServeTLS always returns a non-nil error.
+func ServeTLS(l net.Listener, handler Handler, certFile, keyFile string) error {
+ srv := &Server{Handler: handler}
+ return srv.ServeTLS(l, certFile, keyFile)
+}
+
+// A Server defines parameters for running an HTTP server.
+// The zero value for Server is a valid configuration.
+type Server struct {
+ // Addr optionally specifies the TCP address for the server to listen on,
+ // in the form "host:port". If empty, ":http" (port 80) is used.
+ // The service names are defined in RFC 6335 and assigned by IANA.
+ // See net.Dial for details of the address format.
+ Addr string
+
+ Handler Handler // handler to invoke, http.DefaultServeMux if nil
+
+ // DisableGeneralOptionsHandler, if true, passes "OPTIONS *" requests to the Handler,
+ // otherwise responds with 200 OK and Content-Length: 0.
+ DisableGeneralOptionsHandler bool
+
+ // TLSConfig optionally provides a TLS configuration for use
+ // by ServeTLS and ListenAndServeTLS. Note that this value is
+ // cloned by ServeTLS and ListenAndServeTLS, so it's not
+ // possible to modify the configuration with methods like
+ // tls.Config.SetSessionTicketKeys. To use
+ // SetSessionTicketKeys, use Server.Serve with a TLS Listener
+ // instead.
+ TLSConfig *tls.Config
+
+ // ReadTimeout is the maximum duration for reading the entire
+ // request, including the body. A zero or negative value means
+ // there will be no timeout.
+ //
+ // Because ReadTimeout does not let Handlers make per-request
+ // decisions on each request body's acceptable deadline or
+ // upload rate, most users will prefer to use
+ // ReadHeaderTimeout. It is valid to use them both.
+ ReadTimeout time.Duration
+
+ // ReadHeaderTimeout is the amount of time allowed to read
+ // request headers. The connection's read deadline is reset
+ // after reading the headers and the Handler can decide what
+ // is considered too slow for the body. If ReadHeaderTimeout
+ // is zero, the value of ReadTimeout is used. If both are
+ // zero, there is no timeout.
+ ReadHeaderTimeout time.Duration
+
+ // WriteTimeout is the maximum duration before timing out
+ // writes of the response. It is reset whenever a new
+ // request's header is read. Like ReadTimeout, it does not
+ // let Handlers make decisions on a per-request basis.
+ // A zero or negative value means there will be no timeout.
+ WriteTimeout time.Duration
+
+ // IdleTimeout is the maximum amount of time to wait for the
+ // next request when keep-alives are enabled. If IdleTimeout
+ // is zero, the value of ReadTimeout is used. If both are
+ // zero, there is no timeout.
+ IdleTimeout time.Duration
+
+ // MaxHeaderBytes controls the maximum number of bytes the
+ // server will read parsing the request header's keys and
+ // values, including the request line. It does not limit the
+ // size of the request body.
+ // If zero, DefaultMaxHeaderBytes is used.
+ MaxHeaderBytes int
+
+ // TLSNextProto optionally specifies a function to take over
+ // ownership of the provided TLS connection when an ALPN
+ // protocol upgrade has occurred. The map key is the protocol
+ // name negotiated. The Handler argument should be used to
+ // handle HTTP requests and will initialize the Request's TLS
+ // and RemoteAddr if not already set. The connection is
+ // automatically closed when the function returns.
+ // If TLSNextProto is not nil, HTTP/2 support is not enabled
+ // automatically.
+ TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
+
+ // ConnState specifies an optional callback function that is
+ // called when a client connection changes state. See the
+ // ConnState type and associated constants for details.
+ ConnState func(net.Conn, ConnState)
+
+ // ErrorLog specifies an optional logger for errors accepting
+ // connections, unexpected behavior from handlers, and
+ // underlying FileSystem errors.
+ // If nil, logging is done via the log package's standard logger.
+ ErrorLog *log.Logger
+
+ // BaseContext optionally specifies a function that returns
+ // the base context for incoming requests on this server.
+ // The provided Listener is the specific Listener that's
+ // about to start accepting requests.
+ // If BaseContext is nil, the default is context.Background().
+ // If non-nil, it must return a non-nil context.
+ BaseContext func(net.Listener) context.Context
+
+ // ConnContext optionally specifies a function that modifies
+ // the context used for a new connection c. The provided ctx
+ // is derived from the base context and has a ServerContextKey
+ // value.
+ ConnContext func(ctx context.Context, c net.Conn) context.Context
+
+ inShutdown atomic.Bool // true when server is in shutdown
+
+ disableKeepAlives atomic.Bool
+ nextProtoOnce sync.Once // guards setupHTTP2_* init
+ nextProtoErr error // result of http2.ConfigureServer if used
+
+ mu sync.Mutex
+ listeners map[*net.Listener]struct{}
+ activeConn map[*conn]struct{}
+ onShutdown []func()
+
+ listenerGroup sync.WaitGroup
+}
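+
+// serverConfigSketch is an illustrative sketch (not part of the original
+// change) of a Server with explicit timeouts instead of the zero-value
+// defaults; the address and durations are assumptions.
+func serverConfigSketch(h Handler) *Server {
+ return &Server{
+ Addr: ":8080",
+ Handler: h,
+ ReadHeaderTimeout: 5 * time.Second, // bound slow-header clients
+ ReadTimeout: 30 * time.Second, // whole request, including the body
+ WriteTimeout: 30 * time.Second,
+ IdleTimeout: 120 * time.Second,
+ MaxHeaderBytes: 1 << 20,
+ }
+}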
+
+// Close immediately closes all active net.Listeners and any
+// connections in state StateNew, StateActive, or StateIdle. For a
+// graceful shutdown, use Shutdown.
+//
+// Close does not attempt to close (and does not even know about)
+// any hijacked connections, such as WebSockets.
+//
+// Close returns any error returned from closing the Server's
+// underlying Listener(s).
+func (srv *Server) Close() error {
+ srv.inShutdown.Store(true)
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ err := srv.closeListenersLocked()
+
+ // Unlock srv.mu while waiting for listenerGroup.
+ // The group Add and Done calls are made with srv.mu held,
+ // to avoid adding a new listener in the window between
+ // us setting inShutdown above and waiting here.
+ srv.mu.Unlock()
+ srv.listenerGroup.Wait()
+ srv.mu.Lock()
+
+ for c := range srv.activeConn {
+ c.rwc.Close()
+ delete(srv.activeConn, c)
+ }
+ return err
+}
+
+// shutdownPollIntervalMax is the max polling interval when checking
+// quiescence during Server.Shutdown. Polling starts with a small
+// interval and backs off to the max.
+// Ideally we could find a solution that doesn't involve polling,
+// doesn't have a high runtime cost, and doesn't involve any
+// contentious mutexes, but that is left as an exercise for the reader.
+const shutdownPollIntervalMax = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners, then closing all idle connections, and then waiting
+// indefinitely for connections to return to idle and then shut down.
+// If the provided context expires before the shutdown is complete,
+// Shutdown returns the context's error, otherwise it returns any
+// error returned from closing the Server's underlying Listener(s).
+//
+// When Shutdown is called, Serve, ListenAndServe, and
+// ListenAndServeTLS immediately return ErrServerClosed. Make sure the
+// program doesn't exit and waits instead for Shutdown to return.
+//
+// Shutdown does not attempt to close nor wait for hijacked
+// connections such as WebSockets. The caller of Shutdown should
+// separately notify such long-lived connections of shutdown and wait
+// for them to close, if desired. See RegisterOnShutdown for a way to
+// register shutdown notification functions.
+//
+// Once Shutdown has been called on a server, it may not be reused;
+// future calls to methods such as Serve will return ErrServerClosed.
+func (srv *Server) Shutdown(ctx context.Context) error {
+ srv.inShutdown.Store(true)
+
+ srv.mu.Lock()
+ lnerr := srv.closeListenersLocked()
+ for _, f := range srv.onShutdown {
+ go f()
+ }
+ srv.mu.Unlock()
+ srv.listenerGroup.Wait()
+
+ pollIntervalBase := time.Millisecond
+ nextPollInterval := func() time.Duration {
+ // Add 10% jitter.
+ interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
+ // Double and clamp for next time.
+ pollIntervalBase *= 2
+ if pollIntervalBase > shutdownPollIntervalMax {
+ pollIntervalBase = shutdownPollIntervalMax
+ }
+ return interval
+ }
+
+ timer := time.NewTimer(nextPollInterval())
+ defer timer.Stop()
+ for {
+ if srv.closeIdleConns() {
+ return lnerr
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ timer.Reset(nextPollInterval())
+ }
+ }
+}
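+
+// shutdownSketch is an illustrative sketch (not part of the original change):
+// serve until the caller signals stop, then give in-flight requests a bounded
+// amount of time to finish. The 5-second grace period is an assumption.
+func shutdownSketch(srv *Server, stop <-chan struct{}) error {
+ go srv.ListenAndServe() // returns ErrServerClosed once Shutdown begins
+ <-stop
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ return srv.Shutdown(ctx)
+}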
+
+// RegisterOnShutdown registers a function to call on Shutdown.
+// This can be used to gracefully shutdown connections that have
+// undergone ALPN protocol upgrade or that have been hijacked.
+// This function should start protocol-specific graceful shutdown,
+// but should not wait for shutdown to complete.
+func (srv *Server) RegisterOnShutdown(f func()) {
+ srv.mu.Lock()
+ srv.onShutdown = append(srv.onShutdown, f)
+ srv.mu.Unlock()
+}
+
+// closeIdleConns closes all idle connections and reports whether the
+// server is quiescent.
+func (s *Server) closeIdleConns() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ quiescent := true
+ for c := range s.activeConn {
+ st, unixSec := c.getState()
+ // Issue 22682: treat StateNew connections as if
+ // they're idle if we haven't read the first request's
+ // header in over 5 seconds.
+ if st == StateNew && unixSec < time.Now().Unix()-5 {
+ st = StateIdle
+ }
+ if st != StateIdle || unixSec == 0 {
+ // Assume unixSec == 0 means it's a very new
+ // connection, without state set yet.
+ quiescent = false
+ continue
+ }
+ c.rwc.Close()
+ delete(s.activeConn, c)
+ }
+ return quiescent
+}
+
+func (s *Server) closeListenersLocked() error {
+ var err error
+ for ln := range s.listeners {
+ if cerr := (*ln).Close(); cerr != nil && err == nil {
+ err = cerr
+ }
+ }
+ return err
+}
+
+// A ConnState represents the state of a client connection to a server.
+// It's used by the optional Server.ConnState hook.
+type ConnState int
+
+const (
+ // StateNew represents a new connection that is expected to
+ // send a request immediately. Connections begin at this
+ // state and then transition to either StateActive or
+ // StateClosed.
+ StateNew ConnState = iota
+
+ // StateActive represents a connection that has read 1 or more
+ // bytes of a request. The Server.ConnState hook for
+ // StateActive fires before the request has entered a handler
+ // and doesn't fire again until the request has been
+ // handled. After the request is handled, the state
+ // transitions to StateClosed, StateHijacked, or StateIdle.
+ // For HTTP/2, StateActive fires on the transition from zero
+ // to one active request, and only transitions away once all
+ // active requests are complete. That means that ConnState
+ // cannot be used to do per-request work; ConnState only notes
+ // the overall state of the connection.
+ StateActive
+
+ // StateIdle represents a connection that has finished
+ // handling a request and is in the keep-alive state, waiting
+ // for a new request. Connections transition from StateIdle
+ // to either StateActive or StateClosed.
+ StateIdle
+
+ // StateHijacked represents a hijacked connection.
+ // This is a terminal state. It does not transition to StateClosed.
+ StateHijacked
+
+ // StateClosed represents a closed connection.
+ // This is a terminal state. Hijacked connections do not
+ // transition to StateClosed.
+ StateClosed
+)
+
+var stateName = map[ConnState]string{
+ StateNew: "new",
+ StateActive: "active",
+ StateIdle: "idle",
+ StateHijacked: "hijacked",
+ StateClosed: "closed",
+}
+
+func (c ConnState) String() string {
+ return stateName[c]
+}
+
+// serverHandler delegates to either the server's Handler or
+// DefaultServeMux and also handles "OPTIONS *" requests.
+type serverHandler struct {
+ srv *Server
+}
+
+func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
+ handler := sh.srv.Handler
+ if handler == nil {
+ handler = DefaultServeMux
+ }
+ if !sh.srv.DisableGeneralOptionsHandler && req.RequestURI == "*" && req.Method == "OPTIONS" {
+ handler = globalOptionsHandler{}
+ }
+
+ handler.ServeHTTP(rw, req)
+}
+
+// AllowQuerySemicolons returns a handler that serves requests by converting any
+// unescaped semicolons in the URL query to ampersands, and invoking the handler h.
+//
+// This restores the pre-Go 1.17 behavior of splitting query parameters on both
+// semicolons and ampersands. (See golang.org/issue/25192). Note that this
+// behavior doesn't match that of many proxies, and the mismatch can lead to
+// security issues.
+//
+// AllowQuerySemicolons should be invoked before Request.ParseForm is called.
+func AllowQuerySemicolons(h Handler) Handler {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.RawQuery, ";") {
+ r2 := new(Request)
+ *r2 = *r
+ r2.URL = new(url.URL)
+ *r2.URL = *r.URL
+ r2.URL.RawQuery = strings.ReplaceAll(r.URL.RawQuery, ";", "&")
+ h.ServeHTTP(w, r2)
+ } else {
+ h.ServeHTTP(w, r)
+ }
+ })
+}
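+
+// allowSemicolonsSketch is an illustrative sketch (not part of the original
+// change): wrapping a mux so that legacy clients sending "a=1;b=2" style
+// queries keep working. The listen address is an assumption.
+func allowSemicolonsSketch(mux *ServeMux) error {
+ return ListenAndServe(":8080", AllowQuerySemicolons(mux))
+}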
+
+// ListenAndServe listens on the TCP network address srv.Addr and then
+// calls Serve to handle requests on incoming connections.
+// Accepted connections are configured to enable TCP keep-alives.
+//
+// If srv.Addr is blank, ":http" is used.
+//
+// ListenAndServe always returns a non-nil error. After Shutdown or Close,
+// the returned error is ErrServerClosed.
+func (srv *Server) ListenAndServe() error {
+ if srv.shuttingDown() {
+ return ErrServerClosed
+ }
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":http"
+ }
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
+ }
+ return srv.Serve(ln)
+}
+
+var testHookServerServe func(*Server, net.Listener) // used if non-nil
+
+// shouldConfigureHTTP2ForServe reports whether Server.Serve should configure
+// automatic HTTP/2 (which sets up the srv.TLSNextProto map).
+func (srv *Server) shouldConfigureHTTP2ForServe() bool {
+ if srv.TLSConfig == nil {
+ // Compatibility with Go 1.6:
+ // If there's no TLSConfig, it's possible that the user just
+ // didn't set it on the http.Server, but did pass it to
+ // tls.NewListener and passed that listener to Serve.
+ // So we should configure HTTP/2 (to set up srv.TLSNextProto)
+ // in case the listener returns an "h2" *tls.Conn.
+ return true
+ }
+ // The user specified a TLSConfig on their http.Server.
+ // In this, case, only configure HTTP/2 if their tls.Config
+ // explicitly mentions "h2". Otherwise http2.ConfigureServer
+ // would modify the tls.Config to add it, but they probably already
+ // passed this tls.Config to tls.NewListener. And if they did,
+ // it's too late anyway to fix it. It would only be potentially racy.
+ // See Issue 15908.
+ return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
+}
+
+// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe,
+// and ListenAndServeTLS methods after a call to Shutdown or Close.
+var ErrServerClosed = errors.New("http: Server closed")
+
+// Serve accepts incoming connections on the Listener l, creating a
+// new service goroutine for each. The service goroutines read requests and
+// then call srv.Handler to reply to them.
+//
+// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// connections and they were configured with "h2" in the TLS
+// Config.NextProtos.
+//
+// Serve always returns a non-nil error and closes l.
+// After Shutdown or Close, the returned error is ErrServerClosed.
+func (srv *Server) Serve(l net.Listener) error {
+ if fn := testHookServerServe; fn != nil {
+ fn(srv, l) // call hook with unwrapped listener
+ }
+
+ origListener := l
+ l = &onceCloseListener{Listener: l}
+ defer l.Close()
+
+ if err := srv.setupHTTP2_Serve(); err != nil {
+ return err
+ }
+
+ if !srv.trackListener(&l, true) {
+ return ErrServerClosed
+ }
+ defer srv.trackListener(&l, false)
+
+ baseCtx := context.Background()
+ if srv.BaseContext != nil {
+ baseCtx = srv.BaseContext(origListener)
+ if baseCtx == nil {
+ panic("BaseContext returned a nil context")
+ }
+ }
+
+ var tempDelay time.Duration // how long to sleep on accept failure
+
+ ctx := context.WithValue(baseCtx, ServerContextKey, srv)
+ for {
+ rw, err := l.Accept()
+ if err != nil {
+ if srv.shuttingDown() {
+ return ErrServerClosed
+ }
+ if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if tempDelay == 0 {
+ tempDelay = 5 * time.Millisecond
+ } else {
+ tempDelay *= 2
+ }
+ if max := 1 * time.Second; tempDelay > max {
+ tempDelay = max
+ }
+ srv.logf("http: Accept error: %v; retrying in %v", err, tempDelay)
+ time.Sleep(tempDelay)
+ continue
+ }
+ return err
+ }
+ connCtx := ctx
+ if cc := srv.ConnContext; cc != nil {
+ connCtx = cc(connCtx, rw)
+ if connCtx == nil {
+ panic("ConnContext returned nil")
+ }
+ }
+ tempDelay = 0
+ c := srv.newConn(rw)
+ c.setState(c.rwc, StateNew, runHooks) // before Serve can return
+ go c.serve(connCtx)
+ }
+}
+
+// ServeTLS accepts incoming connections on the Listener l, creating a
+// new service goroutine for each. The service goroutines perform TLS
+// setup and then read requests, calling srv.Handler to reply to them.
+//
+// Files containing a certificate and matching private key for the
+// server must be provided if neither the Server's
+// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
+// If the certificate is signed by a certificate authority, the
+// certFile should be the concatenation of the server's certificate,
+// any intermediates, and the CA's certificate.
+//
+// ServeTLS always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
+func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
+ // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
+ // before we clone it and create the TLS Listener.
+ if err := srv.setupHTTP2_ServeTLS(); err != nil {
+ return err
+ }
+
+ config := cloneTLSConfig(srv.TLSConfig)
+ if !strSliceContains(config.NextProtos, "http/1.1") {
+ config.NextProtos = append(config.NextProtos, "http/1.1")
+ }
+
+ configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil
+ if !configHasCert || certFile != "" || keyFile != "" {
+ var err error
+ config.Certificates = make([]tls.Certificate, 1)
+ config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+ }
+
+ tlsListener := tls.NewListener(l, config)
+ return srv.Serve(tlsListener)
+}
+
+// trackListener adds or removes a net.Listener to the set of tracked
+// listeners.
+//
+// We store a pointer to interface in the map set, in case the
+// net.Listener is not comparable. This is safe because we only call
+// trackListener via Serve and can track+defer untrack the same
+// pointer to a local variable there. We never need to compare a
+// Listener from another caller.
+//
+// It reports whether the server is still up (not Shutdown or Closed).
+func (s *Server) trackListener(ln *net.Listener, add bool) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.listeners == nil {
+ s.listeners = make(map[*net.Listener]struct{})
+ }
+ if add {
+ if s.shuttingDown() {
+ return false
+ }
+ s.listeners[ln] = struct{}{}
+ s.listenerGroup.Add(1)
+ } else {
+ delete(s.listeners, ln)
+ s.listenerGroup.Done()
+ }
+ return true
+}
+
+func (s *Server) trackConn(c *conn, add bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.activeConn == nil {
+ s.activeConn = make(map[*conn]struct{})
+ }
+ if add {
+ s.activeConn[c] = struct{}{}
+ } else {
+ delete(s.activeConn, c)
+ }
+}
+
+func (s *Server) idleTimeout() time.Duration {
+ if s.IdleTimeout != 0 {
+ return s.IdleTimeout
+ }
+ return s.ReadTimeout
+}
+
+func (s *Server) readHeaderTimeout() time.Duration {
+ if s.ReadHeaderTimeout != 0 {
+ return s.ReadHeaderTimeout
+ }
+ return s.ReadTimeout
+}
+
+func (s *Server) doKeepAlives() bool {
+ return !s.disableKeepAlives.Load() && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+ return s.inShutdown.Load()
+}
+
+// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
+// By default, keep-alives are always enabled. Only very
+// resource-constrained environments or servers in the process of
+// shutting down should disable them.
+func (srv *Server) SetKeepAlivesEnabled(v bool) {
+ if v {
+ srv.disableKeepAlives.Store(false)
+ return
+ }
+ srv.disableKeepAlives.Store(true)
+
+ // Close idle HTTP/1 conns:
+ srv.closeIdleConns()
+
+ // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle.
+}
+
+func (s *Server) logf(format string, args ...any) {
+ if s.ErrorLog != nil {
+ s.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+// logf prints to the ErrorLog of the *Server associated with request r
+// via ServerContextKey. If there's no associated server, or if ErrorLog
+// is nil, logging is done via the log package's standard logger.
+func logf(r *Request, format string, args ...any) {
+ s, _ := r.Context().Value(ServerContextKey).(*Server)
+ if s != nil && s.ErrorLog != nil {
+ s.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+// ListenAndServe listens on the TCP network address addr and then calls
+// Serve with handler to handle requests on incoming connections.
+// Accepted connections are configured to enable TCP keep-alives.
+//
+// The handler is typically nil, in which case the DefaultServeMux is used.
+//
+// ListenAndServe always returns a non-nil error.
+func ListenAndServe(addr string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServe()
+}
+
+// ListenAndServeTLS acts identically to ListenAndServe, except that it
+// expects HTTPS connections. Additionally, files containing a certificate and
+// matching private key for the server must be provided. If the certificate
+// is signed by a certificate authority, the certFile should be the concatenation
+// of the server's certificate, any intermediates, and the CA's certificate.
+func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
+ server := &Server{Addr: addr, Handler: handler}
+ return server.ListenAndServeTLS(certFile, keyFile)
+}
+
+// ListenAndServeTLS listens on the TCP network address srv.Addr and
+// then calls ServeTLS to handle requests on incoming TLS connections.
+// Accepted connections are configured to enable TCP keep-alives.
+//
+// Filenames containing a certificate and matching private key for the
+// server must be provided if neither the Server's TLSConfig.Certificates
+// nor TLSConfig.GetCertificate are populated. If the certificate is
+// signed by a certificate authority, the certFile should be the
+// concatenation of the server's certificate, any intermediates, and
+// the CA's certificate.
+//
+// If srv.Addr is blank, ":https" is used.
+//
+// ListenAndServeTLS always returns a non-nil error. After Shutdown or
+// Close, the returned error is ErrServerClosed.
+func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
+ if srv.shuttingDown() {
+ return ErrServerClosed
+ }
+ addr := srv.Addr
+ if addr == "" {
+ addr = ":https"
+ }
+
+ ln, err := net.Listen("tcp", addr)
+ if err != nil {
+ return err
+ }
+
+ defer ln.Close()
+
+ return srv.ServeTLS(ln, certFile, keyFile)
+}
+
+// setupHTTP2_ServeTLS conditionally configures HTTP/2 on
+// srv and reports whether there was an error setting it up. If it is
+// not configured for policy reasons, nil is returned.
+func (srv *Server) setupHTTP2_ServeTLS() error {
+ srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults)
+ return srv.nextProtoErr
+}
+
+// setupHTTP2_Serve is called from (*Server).Serve and conditionally
+// configures HTTP/2 on srv using a more conservative policy than
+// setupHTTP2_ServeTLS because Serve is called after tls.Listen,
+// and may be called concurrently. See shouldConfigureHTTP2ForServe.
+//
+// The tests named TestTransportAutomaticHTTP2* and
+// TestConcurrentServerServe in server_test.go demonstrate some
+// of the supported use cases and motivations.
+func (srv *Server) setupHTTP2_Serve() error {
+ srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults_Serve)
+ return srv.nextProtoErr
+}
+
+func (srv *Server) onceSetNextProtoDefaults_Serve() {
+ if srv.shouldConfigureHTTP2ForServe() {
+ srv.onceSetNextProtoDefaults()
+ }
+}
+
+var http2server = godebug.New("http2server")
+
+// onceSetNextProtoDefaults configures HTTP/2 if the user hasn't configured
+// it otherwise (by setting srv.TLSNextProto non-nil).
+// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*).
+func (srv *Server) onceSetNextProtoDefaults() {
+ if omitBundledHTTP2 {
+ return
+ }
+ if http2server.Value() == "0" {
+ http2server.IncNonDefault()
+ return
+ }
+ // Enable HTTP/2 by default if the user hasn't otherwise
+ // configured their TLSNextProto map.
+ if srv.TLSNextProto == nil {
+ conf := &http2Server{
+ NewWriteScheduler: func() http2WriteScheduler { return http2NewPriorityWriteScheduler(nil) },
+ }
+ srv.nextProtoErr = http2ConfigureServer(srv, conf)
+ }
+}
+
+// TimeoutHandler returns a Handler that runs h with the given time limit.
+//
+// The new Handler calls h.ServeHTTP to handle each request, but if a
+// call runs for longer than its time limit, the handler responds with
+// a 503 Service Unavailable error and the given message in its body.
+// (If msg is empty, a suitable default message will be sent.)
+// After such a timeout, writes by h to its ResponseWriter will return
+// ErrHandlerTimeout.
+//
+// TimeoutHandler supports the Pusher interface but does not support
+// the Hijacker or Flusher interfaces.
+func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
+ return &timeoutHandler{
+ handler: h,
+ body: msg,
+ dt: dt,
+ }
+}
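+
+// timeoutHandlerSketch is an illustrative sketch (not part of the original
+// change): a deliberately slow handler wrapped so that clients receive a 503
+// after one second. The delay, limit, and message are assumptions.
+func timeoutHandlerSketch() Handler {
+ slow := HandlerFunc(func(w ResponseWriter, r *Request) {
+ time.Sleep(5 * time.Second)
+ fmt.Fprintln(w, "finally done")
+ })
+ return TimeoutHandler(slow, 1*time.Second, "request timed out")
+}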
+
+// ErrHandlerTimeout is returned on ResponseWriter Write calls
+// in handlers which have timed out.
+var ErrHandlerTimeout = errors.New("http: Handler timeout")
+
+type timeoutHandler struct {
+ handler Handler
+ body string
+ dt time.Duration
+
+ // When set, no context will be created and this context will
+ // be used instead.
+ testContext context.Context
+}
+
+func (h *timeoutHandler) errorBody() string {
+ if h.body != "" {
+ return h.body
+ }
+ return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>"
+}
+
+func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ ctx := h.testContext
+ if ctx == nil {
+ var cancelCtx context.CancelFunc
+ ctx, cancelCtx = context.WithTimeout(r.Context(), h.dt)
+ defer cancelCtx()
+ }
+ r = r.WithContext(ctx)
+ done := make(chan struct{})
+ tw := &timeoutWriter{
+ w: w,
+ h: make(Header),
+ req: r,
+ }
+ panicChan := make(chan any, 1)
+ go func() {
+ defer func() {
+ if p := recover(); p != nil {
+ panicChan <- p
+ }
+ }()
+ h.handler.ServeHTTP(tw, r)
+ close(done)
+ }()
+ select {
+ case p := <-panicChan:
+ panic(p)
+ case <-done:
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ dst := w.Header()
+ for k, vv := range tw.h {
+ dst[k] = vv
+ }
+ if !tw.wroteHeader {
+ tw.code = StatusOK
+ }
+ w.WriteHeader(tw.code)
+ w.Write(tw.wbuf.Bytes())
+ case <-ctx.Done():
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ switch err := ctx.Err(); err {
+ case context.DeadlineExceeded:
+ w.WriteHeader(StatusServiceUnavailable)
+ io.WriteString(w, h.errorBody())
+ tw.err = ErrHandlerTimeout
+ default:
+ w.WriteHeader(StatusServiceUnavailable)
+ tw.err = err
+ }
+ }
+}
+
+type timeoutWriter struct {
+ w ResponseWriter
+ h Header
+ wbuf bytes.Buffer
+ req *Request
+
+ mu sync.Mutex
+ err error
+ wroteHeader bool
+ code int
+}
+
+var _ Pusher = (*timeoutWriter)(nil)
+
+// Push implements the Pusher interface.
+func (tw *timeoutWriter) Push(target string, opts *PushOptions) error {
+ if pusher, ok := tw.w.(Pusher); ok {
+ return pusher.Push(target, opts)
+ }
+ return ErrNotSupported
+}
+
+func (tw *timeoutWriter) Header() Header { return tw.h }
+
+func (tw *timeoutWriter) Write(p []byte) (int, error) {
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ if tw.err != nil {
+ return 0, tw.err
+ }
+ if !tw.wroteHeader {
+ tw.writeHeaderLocked(StatusOK)
+ }
+ return tw.wbuf.Write(p)
+}
+
+func (tw *timeoutWriter) writeHeaderLocked(code int) {
+ checkWriteHeaderCode(code)
+
+ switch {
+ case tw.err != nil:
+ return
+ case tw.wroteHeader:
+ if tw.req != nil {
+ caller := relevantCaller()
+ logf(tw.req, "http: superfluous response.WriteHeader call from %s (%s:%d)", caller.Function, path.Base(caller.File), caller.Line)
+ }
+ default:
+ tw.wroteHeader = true
+ tw.code = code
+ }
+}
+
+func (tw *timeoutWriter) WriteHeader(code int) {
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ tw.writeHeaderLocked(code)
+}
+
+// onceCloseListener wraps a net.Listener, protecting it from
+// multiple Close calls.
+type onceCloseListener struct {
+ net.Listener
+ once sync.Once
+ closeErr error
+}
+
+func (oc *onceCloseListener) Close() error {
+ oc.once.Do(oc.close)
+ return oc.closeErr
+}
+
+func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
+
+// globalOptionsHandler responds to "OPTIONS *" requests.
+type globalOptionsHandler struct{}
+
+func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "0")
+ if r.ContentLength != 0 {
+ // Read up to 4KB of OPTIONS body (as mentioned in the
+ // spec as being reserved for future use), but anything
+ // over that is considered a waste of server resources
+ // (or an attack) and we abort and close the connection,
+ // courtesy of MaxBytesReader's EOF behavior.
+ mb := MaxBytesReader(w, r.Body, 4<<10)
+ io.Copy(io.Discard, mb)
+ }
+}
+
+// initALPNRequest is an HTTP handler that initializes certain
+// uninitialized fields in its *Request. Such partially-initialized
+// Requests come from ALPN protocol handlers.
+type initALPNRequest struct {
+ ctx context.Context
+ c *tls.Conn
+ h serverHandler
+}
+
+// BaseContext is an exported but unadvertised http.Handler method
+// recognized by x/net/http2 to pass down a context; the TLSNextProto
+// API predates context support so we shoehorn through the only
+// interface we have available.
+func (h initALPNRequest) BaseContext() context.Context { return h.ctx }
+
+func (h initALPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
+ if req.TLS == nil {
+ req.TLS = &tls.ConnectionState{}
+ *req.TLS = h.c.ConnectionState()
+ }
+ if req.Body == nil {
+ req.Body = NoBody
+ }
+ if req.RemoteAddr == "" {
+ req.RemoteAddr = h.c.RemoteAddr().String()
+ }
+ h.h.ServeHTTP(rw, req)
+}
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+ name string
+ net.Conn
+}
+
+var (
+ uniqNameMu sync.Mutex
+ uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+ uniqNameMu.Lock()
+ defer uniqNameMu.Unlock()
+ uniqNameNext[baseName]++
+ return &loggingConn{
+ name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+ Conn: c,
+ }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+ log.Printf("%s.Write(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Write(p)
+ log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+ log.Printf("%s.Read(%d) = ....", c.name, len(p))
+ n, err = c.Conn.Read(p)
+ log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+ return
+}
+
+func (c *loggingConn) Close() (err error) {
+ log.Printf("%s.Close() = ...", c.name)
+ err = c.Conn.Close()
+ log.Printf("%s.Close() = %v", c.name, err)
+ return
+}
+
+// checkConnErrorWriter writes to c.rwc and records any write errors to c.werr.
+// It only contains one field (and a pointer field at that), so it
+// fits in an interface value without an extra allocation.
+type checkConnErrorWriter struct {
+ c *conn
+}
+
+func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
+ n, err = w.c.rwc.Write(p)
+ if err != nil && w.c.werr == nil {
+ w.c.werr = err
+ w.c.cancelCtx()
+ }
+ return
+}
+
+func numLeadingCRorLF(v []byte) (n int) {
+ for _, b := range v {
+ if b == '\r' || b == '\n' {
+ n++
+ continue
+ }
+ break
+ }
+ return
+}
+
+func strSliceContains(ss []string, s string) bool {
+ for _, v := range ss {
+ if v == s {
+ return true
+ }
+ }
+ return false
+}
+
+// tlsRecordHeaderLooksLikeHTTP reports whether a TLS record header
+// looks like it might've been a misdirected plaintext HTTP request.
+func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
+ switch string(hdr[:]) {
+ case "GET /", "HEAD ", "POST ", "PUT /", "OPTIO":
+ return true
+ }
+ return false
+}
+
+// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
+func MaxBytesHandler(h Handler, n int64) Handler {
+ return HandlerFunc(func(w ResponseWriter, r *Request) {
+ r2 := *r
+ r2.Body = MaxBytesReader(w, r.Body, n)
+ h.ServeHTTP(w, &r2)
+ })
+}
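+
+// maxBytesSketch is an illustrative sketch (not part of the original change):
+// cap request bodies at 1 MiB before they reach the wrapped handler; reads
+// past the limit fail and the connection is closed. The limit is an
+// assumption.
+func maxBytesSketch(h Handler) Handler {
+ return MaxBytesHandler(h, 1<<20)
+}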
diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go
new file mode 100644
index 0000000..d17c5c1
--- /dev/null
+++ b/src/net/http/server_test.go
@@ -0,0 +1,98 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Server unit tests
+
+package http
+
+import (
+ "fmt"
+ "testing"
+ "time"
+)
+
+func TestServerTLSHandshakeTimeout(t *testing.T) {
+ tests := []struct {
+ s *Server
+ want time.Duration
+ }{
+ {
+ s: &Server{},
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: -1,
+ },
+ want: 0,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: -1,
+ },
+ want: 5 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 4 * time.Second,
+ },
+ {
+ s: &Server{
+ ReadTimeout: 5 * time.Second,
+ ReadHeaderTimeout: 2 * time.Second,
+ WriteTimeout: 4 * time.Second,
+ },
+ want: 2 * time.Second,
+ },
+ }
+ for i, tt := range tests {
+ got := tt.s.tlsHandshakeTimeout()
+ if got != tt.want {
+ t.Errorf("%d. got %v; want %v", i, got, tt.want)
+ }
+ }
+}
+
+func BenchmarkServerMatch(b *testing.B) {
+ fn := func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "OK")
+ }
+ mux := NewServeMux()
+ mux.HandleFunc("/", fn)
+ mux.HandleFunc("/index", fn)
+ mux.HandleFunc("/home", fn)
+ mux.HandleFunc("/about", fn)
+ mux.HandleFunc("/contact", fn)
+ mux.HandleFunc("/robots.txt", fn)
+ mux.HandleFunc("/products/", fn)
+ mux.HandleFunc("/products/1", fn)
+ mux.HandleFunc("/products/2", fn)
+ mux.HandleFunc("/products/3", fn)
+ mux.HandleFunc("/products/3/image.jpg", fn)
+ mux.HandleFunc("/admin", fn)
+ mux.HandleFunc("/admin/products/", fn)
+ mux.HandleFunc("/admin/products/create", fn)
+ mux.HandleFunc("/admin/products/update", fn)
+ mux.HandleFunc("/admin/products/delete", fn)
+
+ paths := []string{"/", "/notfound", "/admin/", "/admin/foo", "/contact", "/products",
+ "/products/", "/products/3/image.jpg"}
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ if h, p := mux.match(paths[i%len(paths)]); h != nil && p == "" {
+ b.Error("impossible")
+ }
+ }
+ b.StopTimer()
+}
diff --git a/src/net/http/sniff.go b/src/net/http/sniff.go
new file mode 100644
index 0000000..ac18ab9
--- /dev/null
+++ b/src/net/http/sniff.go
@@ -0,0 +1,304 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bytes"
+ "encoding/binary"
+)
+
+// The algorithm uses at most sniffLen bytes to make its decision.
+const sniffLen = 512
+
+// DetectContentType implements the algorithm described
+// at https://mimesniff.spec.whatwg.org/ to determine the
+// Content-Type of the given data. It considers at most the
+// first 512 bytes of data. DetectContentType always returns
+// a valid MIME type: if it cannot determine a more specific one, it
+// returns "application/octet-stream".
+func DetectContentType(data []byte) string {
+ if len(data) > sniffLen {
+ data = data[:sniffLen]
+ }
+
+ // Index of the first non-whitespace byte in data.
+ firstNonWS := 0
+ for ; firstNonWS < len(data) && isWS(data[firstNonWS]); firstNonWS++ {
+ }
+
+ for _, sig := range sniffSignatures {
+ if ct := sig.match(data, firstNonWS); ct != "" {
+ return ct
+ }
+ }
+
+ return "application/octet-stream" // fallback
+}
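+
+// detectSketch is an illustrative sketch (not part of the original change):
+// sniffing the first bytes of a payload yields a MIME type suitable for a
+// Content-Type header. The truncated PNG signature below is an assumption.
+func detectSketch() string {
+ payload := []byte("\x89PNG\x0D\x0A\x1A\x0A")
+ return DetectContentType(payload) // "image/png"
+}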
+
+// isWS reports whether the provided byte is a whitespace byte (0xWS)
+// as defined in https://mimesniff.spec.whatwg.org/#terminology.
+func isWS(b byte) bool {
+ switch b {
+ case '\t', '\n', '\x0c', '\r', ' ':
+ return true
+ }
+ return false
+}
+
+// isTT reports whether the provided byte is a tag-terminating byte (0xTT)
+// as defined in https://mimesniff.spec.whatwg.org/#terminology.
+func isTT(b byte) bool {
+ switch b {
+ case ' ', '>':
+ return true
+ }
+ return false
+}
+
+type sniffSig interface {
+ // match returns the MIME type of the data, or "" if unknown.
+ match(data []byte, firstNonWS int) string
+}
+
+// Data matching the table in section 6.
+var sniffSignatures = []sniffSig{
+ htmlSig("<!DOCTYPE HTML"),
+ htmlSig("<HTML"),
+ htmlSig("<HEAD"),
+ htmlSig("<SCRIPT"),
+ htmlSig("<IFRAME"),
+ htmlSig("<H1"),
+ htmlSig("<DIV"),
+ htmlSig("<FONT"),
+ htmlSig("<TABLE"),
+ htmlSig("<A"),
+ htmlSig("<STYLE"),
+ htmlSig("<TITLE"),
+ htmlSig("<B"),
+ htmlSig("<BODY"),
+ htmlSig("<BR"),
+ htmlSig("<P"),
+ htmlSig("<!--"),
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("<?xml"),
+ skipWS: true,
+ ct: "text/xml; charset=utf-8"},
+ &exactSig{[]byte("%PDF-"), "application/pdf"},
+ &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
+
+ // UTF BOMs.
+ &maskedSig{
+ mask: []byte("\xFF\xFF\x00\x00"),
+ pat: []byte("\xFE\xFF\x00\x00"),
+ ct: "text/plain; charset=utf-16be",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\x00\x00"),
+ pat: []byte("\xFF\xFE\x00\x00"),
+ ct: "text/plain; charset=utf-16le",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\x00"),
+ pat: []byte("\xEF\xBB\xBF\x00"),
+ ct: "text/plain; charset=utf-8",
+ },
+
+ // Image types
+ // For posterity, we originally returned "image/vnd.microsoft.icon" from
+ // https://tools.ietf.org/html/draft-ietf-websec-mime-sniff-03#section-7
+ // https://codereview.appspot.com/4746042
+ // but that has since been replaced with "image/x-icon" in Section 6.2
+ // of https://mimesniff.spec.whatwg.org/#matching-an-image-type-pattern
+ &exactSig{[]byte("\x00\x00\x01\x00"), "image/x-icon"},
+ &exactSig{[]byte("\x00\x00\x02\x00"), "image/x-icon"},
+ &exactSig{[]byte("BM"), "image/bmp"},
+ &exactSig{[]byte("GIF87a"), "image/gif"},
+ &exactSig{[]byte("GIF89a"), "image/gif"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
+ ct: "image/webp",
+ },
+ &exactSig{[]byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"},
+ &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
+
+ // Audio and Video types
+ // Enforce the pattern match ordering as prescribed in
+ // https://mimesniff.spec.whatwg.org/#matching-an-audio-or-video-type-pattern
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("FORM\x00\x00\x00\x00AIFF"),
+ ct: "audio/aiff",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF"),
+ pat: []byte("ID3"),
+ ct: "audio/mpeg",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("OggS\x00"),
+ ct: "application/ogg",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("MThd\x00\x00\x00\x06"),
+ ct: "audio/midi",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00AVI "),
+ ct: "video/avi",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
+ ct: "audio/wave",
+ },
+ // 6.2.0.2. video/mp4
+ mp4Sig{},
+ // 6.2.0.3. video/webm
+ &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
+
+ // Font types
+ &maskedSig{
+ // 34 NULL bytes followed by the string "LP"
+ pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00LP"),
+ // 34 NULL bytes followed by \xFF\xFF
+ mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"),
+ ct: "application/vnd.ms-fontobject",
+ },
+ &exactSig{[]byte("\x00\x01\x00\x00"), "font/ttf"},
+ &exactSig{[]byte("OTTO"), "font/otf"},
+ &exactSig{[]byte("ttcf"), "font/collection"},
+ &exactSig{[]byte("wOFF"), "font/woff"},
+ &exactSig{[]byte("wOF2"), "font/woff2"},
+
+ // Archive types
+ &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
+ &exactSig{[]byte("PK\x03\x04"), "application/zip"},
+ // RAR's signatures are incorrectly defined by the MIME spec as per
+ // https://github.com/whatwg/mimesniff/issues/63
+ // However, RAR Labs correctly defines it at:
+ // https://www.rarlab.com/technote.htm#rarsign
+ // so we use the definition from RAR Labs.
+ // TODO: do whatever the spec ends up doing.
+ &exactSig{[]byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"}, // RAR v1.5-v4.0
+ &exactSig{[]byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"}, // RAR v5+
+
+ &exactSig{[]byte("\x00\x61\x73\x6D"), "application/wasm"},
+
+ textSig{}, // should be last
+}
+
+type exactSig struct {
+ sig []byte
+ ct string
+}
+
+func (e *exactSig) match(data []byte, firstNonWS int) string {
+ if bytes.HasPrefix(data, e.sig) {
+ return e.ct
+ }
+ return ""
+}
+
+type maskedSig struct {
+ mask, pat []byte
+ skipWS bool
+ ct string
+}
+
+func (m *maskedSig) match(data []byte, firstNonWS int) string {
+ // pattern matching algorithm section 6
+ // https://mimesniff.spec.whatwg.org/#pattern-matching-algorithm
+
+ if m.skipWS {
+ data = data[firstNonWS:]
+ }
+ if len(m.pat) != len(m.mask) {
+ return ""
+ }
+ if len(data) < len(m.pat) {
+ return ""
+ }
+ for i, pb := range m.pat {
+ maskedData := data[i] & m.mask[i]
+ if maskedData != pb {
+ return ""
+ }
+ }
+ return m.ct
+}
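
The mask-and-compare loop above is easiest to follow with a concrete entry from the signature table. The following standalone sketch (a hypothetical helper, not part of this patch) replays the WebP rule: the four RIFF chunk-size bytes are masked to zero, so any file length matches.

package main

import "fmt"

// matchMasked mirrors the per-byte AND-then-compare loop in maskedSig.match.
func matchMasked(data, mask, pat []byte) bool {
	if len(pat) != len(mask) || len(data) < len(pat) {
		return false
	}
	for i, pb := range pat {
		if data[i]&mask[i] != pb {
			return false
		}
	}
	return true
}

func main() {
	mask := []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF")
	pat := []byte("RIFF\x00\x00\x00\x00WEBPVP")
	// Bytes 4-7 hold the RIFF chunk size and differ per file, but the mask zeroes them out.
	webp := []byte("RIFF\x24\x10\x00\x00WEBPVP8 ")
	fmt.Println(matchMasked(webp, mask, pat)) // true
}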
+
+type htmlSig []byte
+
+func (h htmlSig) match(data []byte, firstNonWS int) string {
+ data = data[firstNonWS:]
+ if len(data) < len(h)+1 {
+ return ""
+ }
+ for i, b := range h {
+ db := data[i]
+ if 'A' <= b && b <= 'Z' {
+ db &= 0xDF
+ }
+ if b != db {
+ return ""
+ }
+ }
+ // Next byte must be a tag-terminating byte (0xTT).
+ if !isTT(data[len(h)]) {
+ return ""
+ }
+ return "text/html; charset=utf-8"
+}
+
+var mp4ftype = []byte("ftyp")
+var mp4 = []byte("mp4")
+
+type mp4Sig struct{}
+
+func (mp4Sig) match(data []byte, firstNonWS int) string {
+ // https://mimesniff.spec.whatwg.org/#signature-for-mp4
+ // cf. section 6.2.1
+ if len(data) < 12 {
+ return ""
+ }
+ boxSize := int(binary.BigEndian.Uint32(data[:4]))
+ if len(data) < boxSize || boxSize%4 != 0 {
+ return ""
+ }
+ if !bytes.Equal(data[4:8], mp4ftype) {
+ return ""
+ }
+ for st := 8; st < boxSize; st += 4 {
+ if st == 12 {
+ // Ignores the four bytes that correspond to the version number of the "major brand".
+ continue
+ }
+ if bytes.Equal(data[st:st+3], mp4) {
+ return "video/mp4"
+ }
+ }
+ return ""
+}
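
Concretely, the scan above reads a big-endian box size, requires the box type "ftyp" at offset 4, and then looks for an "mp4" prefix in each four-byte brand, skipping the version field at offset 12. A standalone sketch that decodes the MP4 header used later in sniff_test.go:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// size=0x18 (24), type "ftyp", major brand "mp42", version, compatible brands.
	data := []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom")
	boxSize := int(binary.BigEndian.Uint32(data[:4]))
	fmt.Println(boxSize, string(data[4:8])) // 24 ftyp
	fmt.Println(string(data[8:11]))         // "mp4": the major brand matches, so this sniffs as video/mp4
}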
+
+type textSig struct{}
+
+func (textSig) match(data []byte, firstNonWS int) string {
+ // cf. section 5, step 4.
+ for _, b := range data[firstNonWS:] {
+ switch {
+ case b <= 0x08,
+ b == 0x0B,
+ 0x0E <= b && b <= 0x1A,
+ 0x1C <= b && b <= 0x1F:
+ return ""
+ }
+ }
+ return "text/plain; charset=utf-8"
+}
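
These signatures back the exported DetectContentType function defined earlier in this file; it inspects at most the first 512 bytes and always returns a valid MIME type, falling back to application/octet-stream. A short usage sketch against the standard net/http API:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	samples := [][]byte{
		[]byte("\x89PNG\x0D\x0A\x1A\x0A...."), // PNG magic number
		[]byte("  <!DOCTYPE HTML><html>"),     // leading whitespace is skipped before matching
		{0x01, 0x02, 0x03},                    // unrecognized binary data
	}
	for _, b := range samples {
		fmt.Println(http.DetectContentType(b))
	}
	// Output:
	// image/png
	// text/html; charset=utf-8
	// application/octet-stream
}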
diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go
new file mode 100644
index 0000000..d6ef409
--- /dev/null
+++ b/src/net/http/sniff_test.go
@@ -0,0 +1,282 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log"
+ . "net/http"
+ "reflect"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+var sniffTests = []struct {
+ desc string
+ data []byte
+ contentType string
+}{
+ // Some nonsense.
+ {"Empty", []byte{}, "text/plain; charset=utf-8"},
+ {"Binary", []byte{1, 2, 3}, "application/octet-stream"},
+
+ {"HTML document #1", []byte(`<HtMl><bOdY>blah blah blah</body></html>`), "text/html; charset=utf-8"},
+ {"HTML document #2", []byte(`<HTML></HTML>`), "text/html; charset=utf-8"},
+ {"HTML document #3 (leading whitespace)", []byte(` <!DOCTYPE HTML>...`), "text/html; charset=utf-8"},
+ {"HTML document #4 (leading CRLF)", []byte("\r\n<html>..."), "text/html; charset=utf-8"},
+
+ {"Plain text", []byte(`This is not HTML. It has ☃ though.`), "text/plain; charset=utf-8"},
+
+ {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"},
+
+ // Image types.
+ {"Windows icon", []byte("\x00\x00\x01\x00"), "image/x-icon"},
+ {"Windows cursor", []byte("\x00\x00\x02\x00"), "image/x-icon"},
+ {"BMP image", []byte("BM..."), "image/bmp"},
+ {"GIF 87a", []byte(`GIF87a`), "image/gif"},
+ {"GIF 89a", []byte(`GIF89a...`), "image/gif"},
+ {"WEBP image", []byte("RIFF\x00\x00\x00\x00WEBPVP"), "image/webp"},
+ {"PNG image", []byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"},
+ {"JPEG image", []byte("\xFF\xD8\xFF"), "image/jpeg"},
+
+ // Audio types.
+ {"MIDI audio", []byte("MThd\x00\x00\x00\x06\x00\x01"), "audio/midi"},
+ {"MP3 audio/MPEG audio", []byte("ID3\x03\x00\x00\x00\x00\x0f"), "audio/mpeg"},
+ {"WAV audio #1", []byte("RIFFb\xb8\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"},
+ {"WAV audio #2", []byte("RIFF,\x00\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"},
+ {"AIFF audio #1", []byte("FORM\x00\x00\x00\x00AIFFCOMM\x00\x00\x00\x12\x00\x01\x00\x00\x57\x55\x00\x10\x40\x0d\xf3\x34"), "audio/aiff"},
+
+ {"OGG audio", []byte("OggS\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x7e\x46\x00\x00\x00\x00\x00\x00\x1f\xf6\xb4\xfc\x01\x1e\x01\x76\x6f\x72"), "application/ogg"},
+ {"Must not match OGG", []byte("owow\x00"), "application/octet-stream"},
+ {"Must not match OGG", []byte("oooS\x00"), "application/octet-stream"},
+ {"Must not match OGG", []byte("oggS\x00"), "application/octet-stream"},
+
+ // Video types.
+ {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"},
+ {"AVI video #1", []byte("RIFF,O\n\x00AVI LISTÀ"), "video/avi"},
+ {"AVI video #2", []byte("RIFF,\n\x00\x00AVI LISTÀ"), "video/avi"},
+
+ // Font types.
+ // {"MS.FontObject", []byte("\x00\x00")},
+ {"TTF sample I", []byte("\x00\x01\x00\x00\x00\x17\x01\x00\x00\x04\x01\x60\x4f"), "font/ttf"},
+ {"TTF sample II", []byte("\x00\x01\x00\x00\x00\x0e\x00\x80\x00\x03\x00\x60\x46"), "font/ttf"},
+
+ {"OTTO sample I", []byte("\x4f\x54\x54\x4f\x00\x0e\x00\x80\x00\x03\x00\x60\x42\x41\x53\x45"), "font/otf"},
+
+ {"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "font/woff"},
+ {"woff2 sample", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "font/woff2"},
+ {"wasm sample", []byte("\x00\x61\x73\x6d\x01\x00"), "application/wasm"},
+
+ // Archive types
+ {"RAR v1.5-v4.0", []byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"},
+ {"RAR v5+", []byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"},
+ {"Incorrect RAR v1.5-v4.0", []byte("Rar \x1A\x07\x00"), "application/octet-stream"},
+ {"Incorrect RAR v5+", []byte("Rar \x1A\x07\x01\x00"), "application/octet-stream"},
+}
+
+func TestDetectContentType(t *testing.T) {
+ for _, tt := range sniffTests {
+ ct := DetectContentType(tt.data)
+ if ct != tt.contentType {
+ t.Errorf("%v: DetectContentType = %q, want %q", tt.desc, ct, tt.contentType)
+ }
+ }
+}
+
+func TestServerContentTypeSniff(t *testing.T) { run(t, testServerContentTypeSniff) }
+func testServerContentTypeSniff(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ i, _ := strconv.Atoi(r.FormValue("i"))
+ tt := sniffTests[i]
+ n, err := w.Write(tt.data)
+ if n != len(tt.data) || err != nil {
+ log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
+ }
+ }))
+ defer cst.close()
+
+ for i, tt := range sniffTests {
+ resp, err := cst.c.Get(cst.ts.URL + "/?i=" + strconv.Itoa(i))
+ if err != nil {
+ t.Errorf("%v: %v", tt.desc, err)
+ continue
+ }
+ // DetectContentType is defined to return
+ // text/plain; charset=utf-8 for an empty body,
+ // but as of Go 1.10 the HTTP server has been changed
+ // to return no content-type at all for an empty body.
+ // Adjust the expectation here.
+ wantContentType := tt.contentType
+ if len(tt.data) == 0 {
+ wantContentType = ""
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != wantContentType {
+ t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, wantContentType)
+ }
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("%v: reading body: %v", tt.desc, err)
+ } else if !bytes.Equal(data, tt.data) {
+ t.Errorf("%v: data is %q, want %q", tt.desc, data, tt.data)
+ }
+ resp.Body.Close()
+ }
+}
+
+// Issue 5953: shouldn't sniff if the handler set a Content-Type header,
+// even if it's the empty string.
+func TestServerIssue5953(t *testing.T) { run(t, testServerIssue5953) }
+func testServerIssue5953(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Type"] = []string{""}
+ fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
+ }))
+
+ resp, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ got := resp.Header["Content-Type"]
+ want := []string{""}
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("Content-Type = %q; want %q", got, want)
+ }
+ resp.Body.Close()
+}
+
+type byteAtATimeReader struct {
+ buf []byte
+}
+
+func (b *byteAtATimeReader) Read(p []byte) (n int, err error) {
+ if len(p) < 1 {
+ return 0, nil
+ }
+ if len(b.buf) == 0 {
+ return 0, io.EOF
+ }
+ p[0] = b.buf[0]
+ b.buf = b.buf[1:]
+ return 1, nil
+}
+
+func TestContentTypeWithVariousSources(t *testing.T) { run(t, testContentTypeWithVariousSources) }
+func testContentTypeWithVariousSources(t *testing.T, mode testMode) {
+ const (
+ input = "\n<html>\n\t<head>\n"
+ expected = "text/html; charset=utf-8"
+ )
+
+ for _, test := range []struct {
+ name string
+ handler func(ResponseWriter, *Request)
+ }{{
+ name: "write",
+ handler: func(w ResponseWriter, r *Request) {
+ // Write the whole input at once.
+ n, err := w.Write([]byte(input))
+ if int(n) != len(input) || err != nil {
+ t.Errorf("w.Write(%q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ },
+ }, {
+ name: "write one byte at a time",
+ handler: func(w ResponseWriter, r *Request) {
+ // Write the input one byte at a time.
+ buf := []byte(input)
+ for i := range buf {
+ n, err := w.Write(buf[i : i+1])
+ if n != 1 || err != nil {
+ t.Errorf("w.Write(%q) = %v, %v want 1, nil", input, n, err)
+ }
+ }
+ },
+ }, {
+ name: "copy from Reader",
+ handler: func(w ResponseWriter, r *Request) {
+ // Use io.Copy from a plain Reader.
+ type readerOnly struct{ io.Reader }
+ buf := bytes.NewBuffer([]byte(input))
+ n, err := io.Copy(w, readerOnly{buf})
+ if int(n) != len(input) || err != nil {
+ t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ },
+ }, {
+ name: "copy from bytes.Buffer",
+ handler: func(w ResponseWriter, r *Request) {
+ // Use io.Copy from a bytes.Buffer to trigger ReadFrom.
+ buf := bytes.NewBuffer([]byte(input))
+ n, err := io.Copy(w, buf)
+ if int(n) != len(input) || err != nil {
+ t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ },
+ }, {
+ name: "copy one byte at a time",
+ handler: func(w ResponseWriter, r *Request) {
+ // Use io.Copy from a Reader that returns one byte at a time.
+ n, err := io.Copy(w, &byteAtATimeReader{[]byte(input)})
+ if int(n) != len(input) || err != nil {
+ t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
+ }
+ },
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ cst := newClientServerTest(t, mode, HandlerFunc(test.handler))
+
+ resp, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if ct := resp.Header.Get("Content-Type"); ct != expected {
+ t.Errorf("Content-Type = %q, want %q", ct, expected)
+ }
+ if got, want := resp.Header.Get("Content-Length"), fmt.Sprint(len(input)); got != want {
+ t.Errorf("Content-Length = %q, want %q", got, want)
+ }
+ data, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("reading body: %v", err)
+ } else if !bytes.Equal(data, []byte(input)) {
+ t.Errorf("data is %q, want %q", data, input)
+ }
+ resp.Body.Close()
+
+ })
+
+ }
+}
+
+func TestSniffWriteSize(t *testing.T) { run(t, testSniffWriteSize) }
+func testSniffWriteSize(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ size, _ := strconv.Atoi(r.FormValue("size"))
+ written, err := io.WriteString(w, strings.Repeat("a", size))
+ if err != nil {
+ t.Errorf("write of %d bytes: %v", size, err)
+ return
+ }
+ if written != size {
+ t.Errorf("write of %d bytes wrote %d bytes", size, written)
+ }
+ }))
+ for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
+ res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size))
+ if err != nil {
+ t.Fatalf("size %d: %v", size, err)
+ }
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatalf("size %d: io.Copy of body = %v", size, err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatalf("size %d: body Close = %v", size, err)
+ }
+ }
+}
diff --git a/src/net/http/socks_bundle.go b/src/net/http/socks_bundle.go
new file mode 100644
index 0000000..776b03d
--- /dev/null
+++ b/src/net/http/socks_bundle.go
@@ -0,0 +1,473 @@
+// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
+//go:generate bundle -o socks_bundle.go -prefix socks golang.org/x/net/internal/socks
+
+// Package socks provides a SOCKS version 5 client implementation.
+//
+// SOCKS protocol version 5 is defined in RFC 1928.
+// Username/Password authentication for SOCKS version 5 is defined in
+// RFC 1929.
+//
+
+package http
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "strconv"
+ "time"
+)
+
+var (
+ socksnoDeadline = time.Time{}
+ socksaLongTimeAgo = time.Unix(1, 0)
+)
+
+func (d *socksDialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
+ host, port, err := sockssplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+ c.SetDeadline(deadline)
+ defer c.SetDeadline(socksnoDeadline)
+ }
+ if ctx != context.Background() {
+ errCh := make(chan error, 1)
+ done := make(chan struct{})
+ defer func() {
+ close(done)
+ if ctxErr == nil {
+ ctxErr = <-errCh
+ }
+ }()
+ go func() {
+ select {
+ case <-ctx.Done():
+ c.SetDeadline(socksaLongTimeAgo)
+ errCh <- ctx.Err()
+ case <-done:
+ errCh <- nil
+ }
+ }()
+ }
+
+ b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
+ b = append(b, socksVersion5)
+ if len(d.AuthMethods) == 0 || d.Authenticate == nil {
+ b = append(b, 1, byte(socksAuthMethodNotRequired))
+ } else {
+ ams := d.AuthMethods
+ if len(ams) > 255 {
+ return nil, errors.New("too many authentication methods")
+ }
+ b = append(b, byte(len(ams)))
+ for _, am := range ams {
+ b = append(b, byte(am))
+ }
+ }
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
+ return
+ }
+ if b[0] != socksVersion5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ am := socksAuthMethod(b[1])
+ if am == socksAuthMethodNoAcceptableMethods {
+ return nil, errors.New("no acceptable authentication methods")
+ }
+ if d.Authenticate != nil {
+ if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
+ return
+ }
+ }
+
+ b = b[:0]
+ b = append(b, socksVersion5, byte(d.cmd), 0)
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil {
+ b = append(b, socksAddrTypeIPv4)
+ b = append(b, ip4...)
+ } else if ip6 := ip.To16(); ip6 != nil {
+ b = append(b, socksAddrTypeIPv6)
+ b = append(b, ip6...)
+ } else {
+ return nil, errors.New("unknown address type")
+ }
+ } else {
+ if len(host) > 255 {
+ return nil, errors.New("FQDN too long")
+ }
+ b = append(b, socksAddrTypeFQDN)
+ b = append(b, byte(len(host)))
+ b = append(b, host...)
+ }
+ b = append(b, byte(port>>8), byte(port))
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
+ return
+ }
+ if b[0] != socksVersion5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ if cmdErr := socksReply(b[1]); cmdErr != socksStatusSucceeded {
+ return nil, errors.New("unknown error " + cmdErr.String())
+ }
+ if b[2] != 0 {
+ return nil, errors.New("non-zero reserved field")
+ }
+ l := 2
+ var a socksAddr
+ switch b[3] {
+ case socksAddrTypeIPv4:
+ l += net.IPv4len
+ a.IP = make(net.IP, net.IPv4len)
+ case socksAddrTypeIPv6:
+ l += net.IPv6len
+ a.IP = make(net.IP, net.IPv6len)
+ case socksAddrTypeFQDN:
+ if _, err := io.ReadFull(c, b[:1]); err != nil {
+ return nil, err
+ }
+ l += int(b[0])
+ default:
+ return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
+ }
+ if cap(b) < l {
+ b = make([]byte, l)
+ } else {
+ b = b[:l]
+ }
+ if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
+ return
+ }
+ if a.IP != nil {
+ copy(a.IP, b)
+ } else {
+ a.Name = string(b[:len(b)-2])
+ }
+ a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
+ return &a, nil
+}
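
The request buffer that connect assembles follows RFC 1928 section 4: VER, CMD, RSV, ATYP, the destination address, then the port in network byte order. A standalone sketch of those bytes for a hypothetical CONNECT to "example.com:443" with no authentication:

package main

import "fmt"

func main() {
	// SOCKS5 CONNECT request: VER, CMD, RSV, ATYP(FQDN), length-prefixed host, port.
	host, port := "example.com", 443
	b := []byte{0x05, 0x01, 0x00, 0x03, byte(len(host))}
	b = append(b, host...)
	b = append(b, byte(port>>8), byte(port))
	fmt.Printf("% x\n", b)
	// 05 01 00 03 0b 65 78 61 6d 70 6c 65 2e 63 6f 6d 01 bb
}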
+
+func sockssplitHostPort(address string) (string, int, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return "", 0, err
+ }
+ portnum, err := strconv.Atoi(port)
+ if err != nil {
+ return "", 0, err
+ }
+ if 1 > portnum || portnum > 0xffff {
+ return "", 0, errors.New("port number out of range " + port)
+ }
+ return host, portnum, nil
+}
+
+// A Command represents a SOCKS command.
+type socksCommand int
+
+func (cmd socksCommand) String() string {
+ switch cmd {
+ case socksCmdConnect:
+ return "socks connect"
+ case sockscmdBind:
+ return "socks bind"
+ default:
+ return "socks " + strconv.Itoa(int(cmd))
+ }
+}
+
+// An AuthMethod represents a SOCKS authentication method.
+type socksAuthMethod int
+
+// A Reply represents a SOCKS command reply code.
+type socksReply int
+
+func (code socksReply) String() string {
+ switch code {
+ case socksStatusSucceeded:
+ return "succeeded"
+ case 0x01:
+ return "general SOCKS server failure"
+ case 0x02:
+ return "connection not allowed by ruleset"
+ case 0x03:
+ return "network unreachable"
+ case 0x04:
+ return "host unreachable"
+ case 0x05:
+ return "connection refused"
+ case 0x06:
+ return "TTL expired"
+ case 0x07:
+ return "command not supported"
+ case 0x08:
+ return "address type not supported"
+ default:
+ return "unknown code: " + strconv.Itoa(int(code))
+ }
+}
+
+// Wire protocol constants.
+const (
+ socksVersion5 = 0x05
+
+ socksAddrTypeIPv4 = 0x01
+ socksAddrTypeFQDN = 0x03
+ socksAddrTypeIPv6 = 0x04
+
+ socksCmdConnect socksCommand = 0x01 // establishes an active-open forward proxy connection
+ sockscmdBind socksCommand = 0x02 // establishes a passive-open forward proxy connection
+
+ socksAuthMethodNotRequired socksAuthMethod = 0x00 // no authentication required
+ socksAuthMethodUsernamePassword socksAuthMethod = 0x02 // use username/password
+ socksAuthMethodNoAcceptableMethods socksAuthMethod = 0xff // no acceptable authentication methods
+
+ socksStatusSucceeded socksReply = 0x00
+)
+
+// An Addr represents a SOCKS-specific address.
+// Either Name or IP is used exclusively.
+type socksAddr struct {
+ Name string // fully-qualified domain name
+ IP net.IP
+ Port int
+}
+
+func (a *socksAddr) Network() string { return "socks" }
+
+func (a *socksAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ port := strconv.Itoa(a.Port)
+ if a.IP == nil {
+ return net.JoinHostPort(a.Name, port)
+ }
+ return net.JoinHostPort(a.IP.String(), port)
+}
+
+// A Conn represents a forward proxy connection.
+type socksConn struct {
+ net.Conn
+
+ boundAddr net.Addr
+}
+
+// BoundAddr returns the address that the proxy server bound for
+// connecting to the command target address.
+func (c *socksConn) BoundAddr() net.Addr {
+ if c == nil {
+ return nil
+ }
+ return c.boundAddr
+}
+
+// A Dialer holds SOCKS-specific options.
+type socksDialer struct {
+ cmd socksCommand // either CmdConnect or cmdBind
+ proxyNetwork string // network between a proxy server and a client
+ proxyAddress string // proxy server address
+
+ // ProxyDial specifies the optional dial function for
+ // establishing the transport connection.
+ ProxyDial func(context.Context, string, string) (net.Conn, error)
+
+ // AuthMethods specifies the list of request authentication
+ // methods.
+ // If empty, SOCKS client requests only AuthMethodNotRequired.
+ AuthMethods []socksAuthMethod
+
+ // Authenticate specifies the optional authentication
+ // function. It must be non-nil when AuthMethods is not empty.
+ // It must return an error when authentication fails.
+ Authenticate func(context.Context, io.ReadWriter, socksAuthMethod) error
+}
+
+// DialContext connects to the provided address on the provided
+// network.
+//
+// The returned error value may be a net.OpError. When the Op field of
+// net.OpError contains "socks", the Source field contains a proxy
+// server address and the Addr field contains a command target
+// address.
+//
+// See func Dial of the net package of standard library for a
+// description of the network and address parameters.
+func (d *socksDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
+ } else {
+ var dd net.Dialer
+ c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ c.Close()
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return &socksConn{Conn: c, boundAddr: a}, nil
+}
+
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *socksDialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return a, nil
+}
+
+// Dial connects to the provided address on the provided network.
+//
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
+func (d *socksDialer) Dial(network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+ } else {
+ c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+ c.Close()
+ return nil, err
+ }
+ return c, nil
+}
+
+func (d *socksDialer) validateTarget(network, address string) error {
+ switch network {
+ case "tcp", "tcp6", "tcp4":
+ default:
+ return errors.New("network not implemented")
+ }
+ switch d.cmd {
+ case socksCmdConnect, sockscmdBind:
+ default:
+ return errors.New("command not implemented")
+ }
+ return nil
+}
+
+func (d *socksDialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
+ for i, s := range []string{d.proxyAddress, address} {
+ host, port, err := sockssplitHostPort(s)
+ if err != nil {
+ return nil, nil, err
+ }
+ a := &socksAddr{Port: port}
+ a.IP = net.ParseIP(host)
+ if a.IP == nil {
+ a.Name = host
+ }
+ if i == 0 {
+ proxy = a
+ } else {
+ dst = a
+ }
+ }
+ return
+}
+
+// NewDialer returns a new Dialer that dials through the provided
+// proxy server's network and address.
+func socksNewDialer(network, address string) *socksDialer {
+ return &socksDialer{proxyNetwork: network, proxyAddress: address, cmd: socksCmdConnect}
+}
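
Because the bundle keeps every identifier unexported, these types are only usable from inside package http, where the Transport wires them up for SOCKS5 proxies. The following is an illustrative, non-standalone sketch (it assumes it lives in package http and uses placeholder credentials) of how a dialer would be assembled from the pieces above:

// dialThroughSocks is a hypothetical helper, shown only to illustrate how the
// bundled API fits together; context and net are already imported by the bundle.
func dialThroughSocks(ctx context.Context, proxyAddr, target string) (net.Conn, error) {
	d := socksNewDialer("tcp", proxyAddr)
	up := socksUsernamePassword{Username: "user", Password: "secret"} // placeholder credentials
	d.AuthMethods = []socksAuthMethod{
		socksAuthMethodNotRequired,
		socksAuthMethodUsernamePassword,
	}
	d.Authenticate = up.Authenticate
	return d.DialContext(ctx, "tcp", target)
}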
+
+const (
+ socksauthUsernamePasswordVersion = 0x01
+ socksauthStatusSucceeded = 0x00
+)
+
+// UsernamePassword are the credentials for the username/password
+// authentication method.
+type socksUsernamePassword struct {
+ Username string
+ Password string
+}
+
+// Authenticate authenticates a pair of username and password with the
+// proxy server.
+func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth socksAuthMethod) error {
+ switch auth {
+ case socksAuthMethodNotRequired:
+ return nil
+ case socksAuthMethodUsernamePassword:
+ if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
+ return errors.New("invalid username/password")
+ }
+ b := []byte{socksauthUsernamePasswordVersion}
+ b = append(b, byte(len(up.Username)))
+ b = append(b, up.Username...)
+ b = append(b, byte(len(up.Password)))
+ b = append(b, up.Password...)
+ // TODO(mikio): handle IO deadlines and cancelation if
+ // necessary
+ if _, err := rw.Write(b); err != nil {
+ return err
+ }
+ if _, err := io.ReadFull(rw, b[:2]); err != nil {
+ return err
+ }
+ if b[0] != socksauthUsernamePasswordVersion {
+ return errors.New("invalid username/password version")
+ }
+ if b[1] != socksauthStatusSucceeded {
+ return errors.New("username/password authentication failed")
+ }
+ return nil
+ }
+ return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
+}
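
The username/password sub-negotiation performed by Authenticate is defined in RFC 1929: a one-byte version, the length-prefixed username, then the length-prefixed password, answered by a two-byte version/status reply. A standalone sketch of the request bytes for a hypothetical user "u" with password "pw":

package main

import "fmt"

func main() {
	// RFC 1929 request: VER(0x01), ULEN, UNAME, PLEN, PASSWD,
	// matching the buffer built in (*socksUsernamePassword).Authenticate.
	user, pass := "u", "pw"
	b := []byte{0x01, byte(len(user))}
	b = append(b, user...)
	b = append(b, byte(len(pass)))
	b = append(b, pass...)
	fmt.Printf("% x\n", b) // 01 01 75 02 70 77
	// A success reply from the server is 0x01 0x00 (version, status).
}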
diff --git a/src/net/http/status.go b/src/net/http/status.go
new file mode 100644
index 0000000..cd90877
--- /dev/null
+++ b/src/net/http/status.go
@@ -0,0 +1,210 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// HTTP status codes as registered with IANA.
+// See: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
+const (
+ StatusContinue = 100 // RFC 9110, 15.2.1
+ StatusSwitchingProtocols = 101 // RFC 9110, 15.2.2
+ StatusProcessing = 102 // RFC 2518, 10.1
+ StatusEarlyHints = 103 // RFC 8297
+
+ StatusOK = 200 // RFC 9110, 15.3.1
+ StatusCreated = 201 // RFC 9110, 15.3.2
+ StatusAccepted = 202 // RFC 9110, 15.3.3
+ StatusNonAuthoritativeInfo = 203 // RFC 9110, 15.3.4
+ StatusNoContent = 204 // RFC 9110, 15.3.5
+ StatusResetContent = 205 // RFC 9110, 15.3.6
+ StatusPartialContent = 206 // RFC 9110, 15.3.7
+ StatusMultiStatus = 207 // RFC 4918, 11.1
+ StatusAlreadyReported = 208 // RFC 5842, 7.1
+ StatusIMUsed = 226 // RFC 3229, 10.4.1
+
+ StatusMultipleChoices = 300 // RFC 9110, 15.4.1
+ StatusMovedPermanently = 301 // RFC 9110, 15.4.2
+ StatusFound = 302 // RFC 9110, 15.4.3
+ StatusSeeOther = 303 // RFC 9110, 15.4.4
+ StatusNotModified = 304 // RFC 9110, 15.4.5
+ StatusUseProxy = 305 // RFC 9110, 15.4.6
+ _ = 306 // RFC 9110, 15.4.7 (Unused)
+ StatusTemporaryRedirect = 307 // RFC 9110, 15.4.8
+ StatusPermanentRedirect = 308 // RFC 9110, 15.4.9
+
+ StatusBadRequest = 400 // RFC 9110, 15.5.1
+ StatusUnauthorized = 401 // RFC 9110, 15.5.2
+ StatusPaymentRequired = 402 // RFC 9110, 15.5.3
+ StatusForbidden = 403 // RFC 9110, 15.5.4
+ StatusNotFound = 404 // RFC 9110, 15.5.5
+ StatusMethodNotAllowed = 405 // RFC 9110, 15.5.6
+ StatusNotAcceptable = 406 // RFC 9110, 15.5.7
+ StatusProxyAuthRequired = 407 // RFC 9110, 15.5.8
+ StatusRequestTimeout = 408 // RFC 9110, 15.5.9
+ StatusConflict = 409 // RFC 9110, 15.5.10
+ StatusGone = 410 // RFC 9110, 15.5.11
+ StatusLengthRequired = 411 // RFC 9110, 15.5.12
+ StatusPreconditionFailed = 412 // RFC 9110, 15.5.13
+ StatusRequestEntityTooLarge = 413 // RFC 9110, 15.5.14
+ StatusRequestURITooLong = 414 // RFC 9110, 15.5.15
+ StatusUnsupportedMediaType = 415 // RFC 9110, 15.5.16
+ StatusRequestedRangeNotSatisfiable = 416 // RFC 9110, 15.5.17
+ StatusExpectationFailed = 417 // RFC 9110, 15.5.18
+ StatusTeapot = 418 // RFC 9110, 15.5.19 (Unused)
+ StatusMisdirectedRequest = 421 // RFC 9110, 15.5.20
+ StatusUnprocessableEntity = 422 // RFC 9110, 15.5.21
+ StatusLocked = 423 // RFC 4918, 11.3
+ StatusFailedDependency = 424 // RFC 4918, 11.4
+ StatusTooEarly = 425 // RFC 8470, 5.2.
+ StatusUpgradeRequired = 426 // RFC 9110, 15.5.22
+ StatusPreconditionRequired = 428 // RFC 6585, 3
+ StatusTooManyRequests = 429 // RFC 6585, 4
+ StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5
+ StatusUnavailableForLegalReasons = 451 // RFC 7725, 3
+
+ StatusInternalServerError = 500 // RFC 9110, 15.6.1
+ StatusNotImplemented = 501 // RFC 9110, 15.6.2
+ StatusBadGateway = 502 // RFC 9110, 15.6.3
+ StatusServiceUnavailable = 503 // RFC 9110, 15.6.4
+ StatusGatewayTimeout = 504 // RFC 9110, 15.6.5
+ StatusHTTPVersionNotSupported = 505 // RFC 9110, 15.6.6
+ StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1
+ StatusInsufficientStorage = 507 // RFC 4918, 11.5
+ StatusLoopDetected = 508 // RFC 5842, 7.2
+ StatusNotExtended = 510 // RFC 2774, 7
+ StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
+)
+
+// StatusText returns a text for the HTTP status code. It returns the empty
+// string if the code is unknown.
+func StatusText(code int) string {
+ switch code {
+ case StatusContinue:
+ return "Continue"
+ case StatusSwitchingProtocols:
+ return "Switching Protocols"
+ case StatusProcessing:
+ return "Processing"
+ case StatusEarlyHints:
+ return "Early Hints"
+ case StatusOK:
+ return "OK"
+ case StatusCreated:
+ return "Created"
+ case StatusAccepted:
+ return "Accepted"
+ case StatusNonAuthoritativeInfo:
+ return "Non-Authoritative Information"
+ case StatusNoContent:
+ return "No Content"
+ case StatusResetContent:
+ return "Reset Content"
+ case StatusPartialContent:
+ return "Partial Content"
+ case StatusMultiStatus:
+ return "Multi-Status"
+ case StatusAlreadyReported:
+ return "Already Reported"
+ case StatusIMUsed:
+ return "IM Used"
+ case StatusMultipleChoices:
+ return "Multiple Choices"
+ case StatusMovedPermanently:
+ return "Moved Permanently"
+ case StatusFound:
+ return "Found"
+ case StatusSeeOther:
+ return "See Other"
+ case StatusNotModified:
+ return "Not Modified"
+ case StatusUseProxy:
+ return "Use Proxy"
+ case StatusTemporaryRedirect:
+ return "Temporary Redirect"
+ case StatusPermanentRedirect:
+ return "Permanent Redirect"
+ case StatusBadRequest:
+ return "Bad Request"
+ case StatusUnauthorized:
+ return "Unauthorized"
+ case StatusPaymentRequired:
+ return "Payment Required"
+ case StatusForbidden:
+ return "Forbidden"
+ case StatusNotFound:
+ return "Not Found"
+ case StatusMethodNotAllowed:
+ return "Method Not Allowed"
+ case StatusNotAcceptable:
+ return "Not Acceptable"
+ case StatusProxyAuthRequired:
+ return "Proxy Authentication Required"
+ case StatusRequestTimeout:
+ return "Request Timeout"
+ case StatusConflict:
+ return "Conflict"
+ case StatusGone:
+ return "Gone"
+ case StatusLengthRequired:
+ return "Length Required"
+ case StatusPreconditionFailed:
+ return "Precondition Failed"
+ case StatusRequestEntityTooLarge:
+ return "Request Entity Too Large"
+ case StatusRequestURITooLong:
+ return "Request URI Too Long"
+ case StatusUnsupportedMediaType:
+ return "Unsupported Media Type"
+ case StatusRequestedRangeNotSatisfiable:
+ return "Requested Range Not Satisfiable"
+ case StatusExpectationFailed:
+ return "Expectation Failed"
+ case StatusTeapot:
+ return "I'm a teapot"
+ case StatusMisdirectedRequest:
+ return "Misdirected Request"
+ case StatusUnprocessableEntity:
+ return "Unprocessable Entity"
+ case StatusLocked:
+ return "Locked"
+ case StatusFailedDependency:
+ return "Failed Dependency"
+ case StatusTooEarly:
+ return "Too Early"
+ case StatusUpgradeRequired:
+ return "Upgrade Required"
+ case StatusPreconditionRequired:
+ return "Precondition Required"
+ case StatusTooManyRequests:
+ return "Too Many Requests"
+ case StatusRequestHeaderFieldsTooLarge:
+ return "Request Header Fields Too Large"
+ case StatusUnavailableForLegalReasons:
+ return "Unavailable For Legal Reasons"
+ case StatusInternalServerError:
+ return "Internal Server Error"
+ case StatusNotImplemented:
+ return "Not Implemented"
+ case StatusBadGateway:
+ return "Bad Gateway"
+ case StatusServiceUnavailable:
+ return "Service Unavailable"
+ case StatusGatewayTimeout:
+ return "Gateway Timeout"
+ case StatusHTTPVersionNotSupported:
+ return "HTTP Version Not Supported"
+ case StatusVariantAlsoNegotiates:
+ return "Variant Also Negotiates"
+ case StatusInsufficientStorage:
+ return "Insufficient Storage"
+ case StatusLoopDetected:
+ return "Loop Detected"
+ case StatusNotExtended:
+ return "Not Extended"
+ case StatusNetworkAuthenticationRequired:
+ return "Network Authentication Required"
+ default:
+ return ""
+ }
+}
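
A quick sketch of how the constants and StatusText are used through the public net/http API; unregistered codes yield an empty string:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	fmt.Println(http.StatusNotFound, http.StatusText(http.StatusNotFound)) // 404 Not Found
	fmt.Println(http.StatusTeapot, http.StatusText(http.StatusTeapot))     // 418 I'm a teapot
	fmt.Printf("%q\n", http.StatusText(499))                               // "" for codes not in the table
}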
diff --git a/src/net/http/testdata/file b/src/net/http/testdata/file
new file mode 100644
index 0000000..11f11f9
--- /dev/null
+++ b/src/net/http/testdata/file
@@ -0,0 +1 @@
+0123456789
diff --git a/src/net/http/testdata/index.html b/src/net/http/testdata/index.html
new file mode 100644
index 0000000..da8e1e9
--- /dev/null
+++ b/src/net/http/testdata/index.html
@@ -0,0 +1 @@
+index.html says hello
diff --git a/src/net/http/testdata/style.css b/src/net/http/testdata/style.css
new file mode 100644
index 0000000..208d16d
--- /dev/null
+++ b/src/net/http/testdata/style.css
@@ -0,0 +1 @@
+body {}
diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go
new file mode 100644
index 0000000..d6f26a7
--- /dev/null
+++ b/src/net/http/transfer.go
@@ -0,0 +1,1124 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net/http/httptrace"
+ "net/http/internal"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// ErrLineTooLong is returned when reading request or response bodies
+// with malformed chunked encoding.
+var ErrLineTooLong = internal.ErrLineTooLong
+
+type errorReader struct {
+ err error
+}
+
+func (r errorReader) Read(p []byte) (n int, err error) {
+ return 0, r.err
+}
+
+type byteReader struct {
+ b byte
+ done bool
+}
+
+func (br *byteReader) Read(p []byte) (n int, err error) {
+ if br.done {
+ return 0, io.EOF
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+ br.done = true
+ p[0] = br.b
+ return 1, io.EOF
+}
+
+// transferWriter inspects the fields of a user-supplied Request or Response,
+// sanitizes them without changing the user object and provides methods for
+// writing the respective header, body and trailer in wire format.
+type transferWriter struct {
+ Method string
+ Body io.Reader
+ BodyCloser io.Closer
+ ResponseToHEAD bool
+ ContentLength int64 // -1 means unknown, 0 means exactly none
+ Close bool
+ TransferEncoding []string
+ Header Header
+ Trailer Header
+ IsResponse bool
+ bodyReadError error // any non-EOF error from reading Body
+
+ FlushHeaders bool // flush headers to network before body
+ ByteReadCh chan readResult // non-nil if probeRequestBody called
+}
+
+func newTransferWriter(r any) (t *transferWriter, err error) {
+ t = &transferWriter{}
+
+ // Extract relevant fields
+ atLeastHTTP11 := false
+ switch rr := r.(type) {
+ case *Request:
+ if rr.ContentLength != 0 && rr.Body == nil {
+ return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
+ }
+ t.Method = valueOrDefault(rr.Method, "GET")
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Header = rr.Header
+ t.Trailer = rr.Trailer
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.outgoingLength()
+ if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() {
+ t.TransferEncoding = []string{"chunked"}
+ }
+ // If there's a body, conservatively flush the headers
+ // to any bufio.Writer we're writing to, just in case
+ // the server needs the headers early, before we copy
+ // the body and possibly block. We make an exception
+ // for the common standard library in-memory types,
+ // though, to avoid unnecessary TCP packets on the
+ // wire. (Issue 22088.)
+ if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) {
+ t.FlushHeaders = true
+ }
+
+ atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0
+ case *Response:
+ t.IsResponse = true
+ if rr.Request != nil {
+ t.Method = rr.Request.Method
+ }
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.ContentLength
+ t.Close = rr.Close
+ t.TransferEncoding = rr.TransferEncoding
+ t.Header = rr.Header
+ t.Trailer = rr.Trailer
+ atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
+ t.ResponseToHEAD = noResponseBodyExpected(t.Method)
+ }
+
+ // Sanitize Body,ContentLength,TransferEncoding
+ if t.ResponseToHEAD {
+ t.Body = nil
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
+ }
+ } else {
+ if !atLeastHTTP11 || t.Body == nil {
+ t.TransferEncoding = nil
+ }
+ if chunked(t.TransferEncoding) {
+ t.ContentLength = -1
+ } else if t.Body == nil { // no chunking, no body
+ t.ContentLength = 0
+ }
+ }
+
+ // Sanitize Trailer
+ if !chunked(t.TransferEncoding) {
+ t.Trailer = nil
+ }
+
+ return t, nil
+}
+
+// shouldSendChunkedRequestBody reports whether we should try to send a
+// chunked request body to the server. In particular, the case we really
+// want to prevent is sending a GET or other typically-bodyless request to a
+// server with a chunked body when the body has zero bytes, since GETs with
+// bodies (while acceptable according to specs), even zero-byte chunked
+// bodies, are approximately never seen in the wild and confuse most
+// servers. See Issue 18257, as one example.
+//
+// The only reason we'd send such a request is if the user set the Body to a
+// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't
+// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume
+// there's bytes to send.
+//
+// This code tries to read a byte from the Request.Body in such cases to see
+// whether the body actually has content (super rare) or is actually just
+// a non-nil content-less ReadCloser (the more common case). In that more
+// common case, we act as if their Body were nil instead, and don't send
+// a body.
+func (t *transferWriter) shouldSendChunkedRequestBody() bool {
+ // Note that t.ContentLength is the corrected content length
+ // from rr.outgoingLength, so 0 actually means zero, not unknown.
+ if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them
+ return false
+ }
+ if t.Method == "CONNECT" {
+ return false
+ }
+ if requestMethodUsuallyLacksBody(t.Method) {
+ // Only probe the Request.Body for GET/HEAD/DELETE/etc
+ // requests, because it's only those types of requests
+ // that confuse servers.
+ t.probeRequestBody() // adjusts t.Body, t.ContentLength
+ return t.Body != nil
+ }
+ // For all other request types (PUT, POST, PATCH, or anything
+ // made-up we've never heard of), assume it's normal and the server
+ // can deal with a chunked request body. Maybe we'll adjust this
+ // later.
+ return true
+}
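
The situation this guards against is a client request whose Body is non-nil but empty and whose length is unknown. A standalone sketch (the example.com URL is illustrative) showing the probe preventing a chunked GET; DumpRequestOut runs the same transfer path the Transport uses:

package main

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httputil"
	"os"
)

func main() {
	// Non-nil, zero-byte body with no known length: without the probe,
	// this GET would be sent with "Transfer-Encoding: chunked".
	req, _ := http.NewRequest("GET", "http://example.com/", io.NopCloser(bytes.NewReader(nil)))
	req.ContentLength = -1 // 0 with a non-nil Body already means unknown for client requests; -1 makes it explicit

	dump, _ := httputil.DumpRequestOut(req, true)
	os.Stdout.Write(dump) // no Transfer-Encoding header: the probe saw EOF immediately
}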
+
+// probeRequestBody reads a byte from t.Body to see whether it's empty
+// (returns io.EOF right away).
+//
+// But because we've had problems with this blocking users in the past
+// (issue 17480) when the body is a pipe (perhaps waiting on the response
+// headers before the pipe is fed data), we need to be careful and bound how
+// long we wait for it. This delay will only affect users if all the following
+// are true:
+// - the request body blocks
+// - the content length is not set (or set to -1)
+// - the method doesn't usually have a body (GET, HEAD, DELETE, ...)
+// - there is no transfer-encoding=chunked already set.
+//
+// In other words, this delay will not normally affect anybody, and there
+// are workarounds if it does.
+func (t *transferWriter) probeRequestBody() {
+ t.ByteReadCh = make(chan readResult, 1)
+ go func(body io.Reader) {
+ var buf [1]byte
+ var rres readResult
+ rres.n, rres.err = body.Read(buf[:])
+ if rres.n == 1 {
+ rres.b = buf[0]
+ }
+ t.ByteReadCh <- rres
+ close(t.ByteReadCh)
+ }(t.Body)
+ timer := time.NewTimer(200 * time.Millisecond)
+ select {
+ case rres := <-t.ByteReadCh:
+ timer.Stop()
+ if rres.n == 0 && rres.err == io.EOF {
+ // It was empty.
+ t.Body = nil
+ t.ContentLength = 0
+ } else if rres.n == 1 {
+ if rres.err != nil {
+ t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err})
+ } else {
+ t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body)
+ }
+ } else if rres.err != nil {
+ t.Body = errorReader{rres.err}
+ }
+ case <-timer.C:
+ // Too slow. Don't wait. Read it later, and keep
+ // assuming that this is ContentLength == -1
+ // (unknown), which means we'll send a
+ // "Transfer-Encoding: chunked" header.
+ t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body)
+ // Request that Request.Write flush the headers to the
+ // network before writing the body, since our body may not
+ // become readable until it's seen the response headers.
+ t.FlushHeaders = true
+ }
+}
+
+func noResponseBodyExpected(requestMethod string) bool {
+ return requestMethod == "HEAD"
+}
+
+func (t *transferWriter) shouldSendContentLength() bool {
+ if chunked(t.TransferEncoding) {
+ return false
+ }
+ if t.ContentLength > 0 {
+ return true
+ }
+ if t.ContentLength < 0 {
+ return false
+ }
+ // Many servers expect a Content-Length for these methods
+ if t.Method == "POST" || t.Method == "PUT" || t.Method == "PATCH" {
+ return true
+ }
+ if t.ContentLength == 0 && isIdentity(t.TransferEncoding) {
+ if t.Method == "GET" || t.Method == "HEAD" {
+ return false
+ }
+ return true
+ }
+
+ return false
+}
+
+func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error {
+ if t.Close && !hasToken(t.Header.get("Connection"), "close") {
+ if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Connection", []string{"close"})
+ }
+ }
+
+ // Write Content-Length and/or Transfer-Encoding whose values are a
+ // function of the sanitized field triple (Body, ContentLength,
+ // TransferEncoding)
+ if t.shouldSendContentLength() {
+ if _, err := io.WriteString(w, "Content-Length: "); err != nil {
+ return err
+ }
+ if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)})
+ }
+ } else if chunked(t.TransferEncoding) {
+ if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"})
+ }
+ }
+
+ // Write Trailer header
+ if t.Trailer != nil {
+ keys := make([]string, 0, len(t.Trailer))
+ for k := range t.Trailer {
+ k = CanonicalHeaderKey(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return badStringError("invalid Trailer key", k)
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ // TODO: could do better allocation-wise here, but trailers are rare,
+ // so being lazy for now.
+ if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
+ return err
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Trailer", keys)
+ }
+ }
+ }
+
+ return nil
+}
+
+// always closes t.BodyCloser
+func (t *transferWriter) writeBody(w io.Writer) (err error) {
+ var ncopy int64
+ closed := false
+ defer func() {
+ if closed || t.BodyCloser == nil {
+ return
+ }
+ if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil {
+ err = closeErr
+ }
+ }()
+
+ // Write body. We "unwrap" the body first if it was wrapped in a
+ // nopCloser or readTrackingBody. This is to ensure that we can take advantage of
+ // OS-level optimizations in the event that the body is an
+ // *os.File.
+ if t.Body != nil {
+ var body = t.unwrapBody()
+ if chunked(t.TransferEncoding) {
+ if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
+ w = &internal.FlushAfterChunkWriter{Writer: bw}
+ }
+ cw := internal.NewChunkedWriter(w)
+ _, err = t.doBodyCopy(cw, body)
+ if err == nil {
+ err = cw.Close()
+ }
+ } else if t.ContentLength == -1 {
+ dst := w
+ if t.Method == "CONNECT" {
+ dst = bufioFlushWriter{dst}
+ }
+ ncopy, err = t.doBodyCopy(dst, body)
+ } else {
+ ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength))
+ if err != nil {
+ return err
+ }
+ var nextra int64
+ nextra, err = t.doBodyCopy(io.Discard, body)
+ ncopy += nextra
+ }
+ if err != nil {
+ return err
+ }
+ }
+ if t.BodyCloser != nil {
+ closed = true
+ if err := t.BodyCloser.Close(); err != nil {
+ return err
+ }
+ }
+
+ if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy {
+ return fmt.Errorf("http: ContentLength=%d with Body length %d",
+ t.ContentLength, ncopy)
+ }
+
+ if chunked(t.TransferEncoding) {
+ // Write Trailer header
+ if t.Trailer != nil {
+ if err := t.Trailer.Write(w); err != nil {
+ return err
+ }
+ }
+ // Last chunk, empty trailer
+ _, err = io.WriteString(w, "\r\n")
+ }
+ return err
+}
+
+// doBodyCopy wraps a copy operation, with any resulting error also
+// being saved in bodyReadError.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) {
+ n, err = io.Copy(dst, src)
+ if err != nil && err != io.EOF {
+ t.bodyReadError = err
+ }
+ return
+}
+
+// unwrapBody unwraps the body's inner reader if it's a
+// nopCloser. This is to ensure that body writes sourced from local
+// files (*os.File types) are properly optimized.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) unwrapBody() io.Reader {
+ if r, ok := unwrapNopCloser(t.Body); ok {
+ return r
+ }
+ if r, ok := t.Body.(*readTrackingBody); ok {
+ r.didRead = true
+ return r.ReadCloser
+ }
+ return t.Body
+}
+
+type transferReader struct {
+ // Input
+ Header Header
+ StatusCode int
+ RequestMethod string
+ ProtoMajor int
+ ProtoMinor int
+ // Output
+ Body io.ReadCloser
+ ContentLength int64
+ Chunked bool
+ Close bool
+ Trailer Header
+}
+
+func (t *transferReader) protoAtLeast(m, n int) bool {
+ return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n)
+}
+
+// bodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC 7230, section 3.3.
+func bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+var (
+ suppressedHeaders304 = []string{"Content-Type", "Content-Length", "Transfer-Encoding"}
+ suppressedHeadersNoBody = []string{"Content-Length", "Transfer-Encoding"}
+ excludedHeadersNoBody = map[string]bool{"Content-Length": true, "Transfer-Encoding": true}
+)
+
+func suppressedHeaders(status int) []string {
+ switch {
+ case status == 304:
+ // RFC 7232 section 4.1
+ return suppressedHeaders304
+ case !bodyAllowedForStatus(status):
+ return suppressedHeadersNoBody
+ }
+ return nil
+}
+
+// msg is *Request or *Response.
+func readTransfer(msg any, r *bufio.Reader) (err error) {
+ t := &transferReader{RequestMethod: "GET"}
+
+ // Unify input
+ isResponse := false
+ switch rr := msg.(type) {
+ case *Response:
+ t.Header = rr.Header
+ t.StatusCode = rr.StatusCode
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ t.Close = shouldClose(t.ProtoMajor, t.ProtoMinor, t.Header, true)
+ isResponse = true
+ if rr.Request != nil {
+ t.RequestMethod = rr.Request.Method
+ }
+ case *Request:
+ t.Header = rr.Header
+ t.RequestMethod = rr.Method
+ t.ProtoMajor = rr.ProtoMajor
+ t.ProtoMinor = rr.ProtoMinor
+ // Transfer semantics for Requests are exactly like those for
+ // Responses with status code 200, responding to a GET method
+ t.StatusCode = 200
+ t.Close = rr.Close
+ default:
+ panic("unexpected type")
+ }
+
+ // Default to HTTP/1.1
+ if t.ProtoMajor == 0 && t.ProtoMinor == 0 {
+ t.ProtoMajor, t.ProtoMinor = 1, 1
+ }
+
+ // Transfer-Encoding: chunked, and overriding Content-Length.
+ if err := t.parseTransferEncoding(); err != nil {
+ return err
+ }
+
+ realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked)
+ if err != nil {
+ return err
+ }
+ if isResponse && t.RequestMethod == "HEAD" {
+ if n, err := parseContentLength(t.Header.get("Content-Length")); err != nil {
+ return err
+ } else {
+ t.ContentLength = n
+ }
+ } else {
+ t.ContentLength = realLength
+ }
+
+ // Trailer
+ t.Trailer, err = fixTrailer(t.Header, t.Chunked)
+ if err != nil {
+ return err
+ }
+
+ // If there is no Content-Length or chunked Transfer-Encoding on a *Response
+ // and the status is not 1xx, 204 or 304, then the body is unbounded.
+ // See RFC 7230, section 3.3.
+ switch msg.(type) {
+ case *Response:
+ if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) {
+ // Unbounded body.
+ t.Close = true
+ }
+ }
+
+ // Prepare body reader. ContentLength < 0 means chunked encoding
+ // or close connection when finished, since multipart is not supported yet
+ switch {
+ case t.Chunked:
+ if isResponse && (noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode)) {
+ t.Body = NoBody
+ } else {
+ t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close}
+ }
+ case realLength == 0:
+ t.Body = NoBody
+ case realLength > 0:
+ t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
+ default:
+ // realLength < 0, i.e. "Content-Length" not mentioned in header
+ if t.Close {
+ // Close semantics (i.e. HTTP/1.0)
+ t.Body = &body{src: r, closing: t.Close}
+ } else {
+ // Persistent connection (i.e. HTTP/1.1)
+ t.Body = NoBody
+ }
+ }
+
+ // Unify output
+ switch rr := msg.(type) {
+ case *Request:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ if t.Chunked {
+ rr.TransferEncoding = []string{"chunked"}
+ }
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ case *Response:
+ rr.Body = t.Body
+ rr.ContentLength = t.ContentLength
+ if t.Chunked {
+ rr.TransferEncoding = []string{"chunked"}
+ }
+ rr.Close = t.Close
+ rr.Trailer = t.Trailer
+ }
+
+ return nil
+}
+
+// Checks whether chunked is part of the encodings stack.
+func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
+
+// Checks whether the encoding is explicitly "identity".
+func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
+
+// unsupportedTEError reports unsupported transfer-encodings.
+type unsupportedTEError struct {
+ err string
+}
+
+func (uste *unsupportedTEError) Error() string {
+ return uste.err
+}
+
+// isUnsupportedTEError checks if the error is of type
+// unsupportedTEError. It is usually invoked with a non-nil err.
+func isUnsupportedTEError(err error) bool {
+ _, ok := err.(*unsupportedTEError)
+ return ok
+}
+
+// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header.
+func (t *transferReader) parseTransferEncoding() error {
+ raw, present := t.Header["Transfer-Encoding"]
+ if !present {
+ return nil
+ }
+ delete(t.Header, "Transfer-Encoding")
+
+ // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests.
+ if !t.protoAtLeast(1, 1) {
+ return nil
+ }
+
+ // Like nginx, we only support a single Transfer-Encoding header field, and
+ // only if set to "chunked". This is one of the most security sensitive
+ // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it
+ // strict and simple.
+ if len(raw) != 1 {
+ return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)}
+ }
+ if !ascii.EqualFold(raw[0], "chunked") {
+ return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])}
+ }
+
+ // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field
+ // in any message that contains a Transfer-Encoding header field."
+ //
+ // but also: "If a message is received with both a Transfer-Encoding and a
+ // Content-Length header field, the Transfer-Encoding overrides the
+ // Content-Length. Such a message might indicate an attempt to perform
+ // request smuggling (Section 9.5) or response splitting (Section 9.4) and
+ // ought to be handled as an error. A sender MUST remove the received
+ // Content-Length field prior to forwarding such a message downstream."
+ //
+ // Reportedly, these appear in the wild.
+ delete(t.Header, "Content-Length")
+
+ t.Chunked = true
+ return nil
+}
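
The effect of this policy is visible when parsing a message that carries both headers: the Content-Length field is dropped and the chunked framing wins. A standalone sketch using ReadRequest, which goes through readTransfer:

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Both Transfer-Encoding and Content-Length present: chunked wins,
	// and the Content-Length header is removed during parsing.
	raw := "POST / HTTP/1.1\r\n" +
		"Host: example.com\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"Content-Length: 4\r\n" +
		"\r\n" +
		"3\r\nabc\r\n0\r\n\r\n"
	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(raw)))
	if err != nil {
		panic(err)
	}
	fmt.Println(req.TransferEncoding)                    // [chunked]
	fmt.Printf("%q\n", req.Header.Get("Content-Length")) // "" (removed)
	fmt.Println(req.ContentLength)                       // -1: the length comes from the chunk framing
}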
+
+// Determine the expected body length, using RFC 7230 Section 3.3. This
+// function is not a method, because ultimately it should be shared by
+// ReadResponse and ReadRequest.
+func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) {
+ isRequest := !isResponse
+ contentLens := header["Content-Length"]
+
+ // Hardening against HTTP request smuggling
+ if len(contentLens) > 1 {
+ // Per RFC 7230 Section 3.3.2, prevent multiple
+ // Content-Length headers if they differ in value.
+ // If there are dups of the value, remove the dups.
+ // See Issue 16490.
+ first := textproto.TrimString(contentLens[0])
+ for _, ct := range contentLens[1:] {
+ if first != textproto.TrimString(ct) {
+ return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens)
+ }
+ }
+
+ // deduplicate Content-Length
+ header.Del("Content-Length")
+ header.Add("Content-Length", first)
+
+ contentLens = header["Content-Length"]
+ }
+
+ // Logic based on response type or status
+ if isResponse && noResponseBodyExpected(requestMethod) {
+ return 0, nil
+ }
+ if status/100 == 1 {
+ return 0, nil
+ }
+ switch status {
+ case 204, 304:
+ return 0, nil
+ }
+
+ // Logic based on Transfer-Encoding
+ if chunked {
+ return -1, nil
+ }
+
+ // Logic based on Content-Length
+ var cl string
+ if len(contentLens) == 1 {
+ cl = textproto.TrimString(contentLens[0])
+ }
+ if cl != "" {
+ n, err := parseContentLength(cl)
+ if err != nil {
+ return -1, err
+ }
+ return n, nil
+ }
+ header.Del("Content-Length")
+
+ if isRequest {
+ // RFC 7230 neither explicitly permits nor forbids an
+ // entity-body on a GET request so we permit one if
+ // declared, but we default to 0 here (not -1 below)
+ // if there's no mention of a body.
+ // Likewise, all other request methods are assumed to have
+ // no body if neither Transfer-Encoding chunked nor a
+ // Content-Length are set.
+ return 0, nil
+ }
+
+ // Body-EOF logic based on other methods (like closing, or chunked coding)
+ return -1, nil
+}
+
+// Determine whether to hang up after sending a request and body, or
+// receiving a response and body.
+// 'header' is the request headers.
+func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
+ if major < 1 {
+ return true
+ }
+
+ conv := header["Connection"]
+ hasClose := httpguts.HeaderValuesContainsToken(conv, "close")
+ if major == 1 && minor == 0 {
+ return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive")
+ }
+
+ if hasClose && removeCloseHeader {
+ header.Del("Connection")
+ }
+
+ return hasClose
+}
+
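Editorial sketch: the result of this check ends up as the Close field on the parsed message, so the behavior is easy to observe through ReadResponse.

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// HTTP/1.0 without "Connection: keep-alive" implies closing the connection.
	res10, _ := http.ReadResponse(bufio.NewReader(strings.NewReader(
		"HTTP/1.0 200 OK\r\nContent-Length: 0\r\n\r\n")), nil)
	fmt.Println(res10.Close) // true

	// HTTP/1.1 defaults to keep-alive unless "Connection: close" is sent.
	res11, _ := http.ReadResponse(bufio.NewReader(strings.NewReader(
		"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")), nil)
	fmt.Println(res11.Close) // false
}
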
+// Parse the trailer header.
+func fixTrailer(header Header, chunked bool) (Header, error) {
+ vv, ok := header["Trailer"]
+ if !ok {
+ return nil, nil
+ }
+ if !chunked {
+ // Trailer and no chunking:
+ // this is an invalid use case for trailer header.
+ // Nevertheless, no error will be returned and we
+ // let users decide if this is a valid HTTP message.
+ // The Trailer header will be kept in Response.Header
+ // but not populate Response.Trailer.
+ // See issue #27197.
+ return nil, nil
+ }
+ header.Del("Trailer")
+
+ trailer := make(Header)
+ var err error
+ for _, v := range vv {
+ foreachHeaderElement(v, func(key string) {
+ key = CanonicalHeaderKey(key)
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ if err == nil {
+ err = badStringError("bad trailer key", key)
+ return
+ }
+ }
+ trailer[key] = nil
+ })
+ }
+ if err != nil {
+ return nil, err
+ }
+ if len(trailer) == 0 {
+ return nil, nil
+ }
+ return trailer, nil
+}
+
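Editorial sketch: combined with readTrailer below, announced trailer keys appear in Trailer immediately with nil values and are filled in only after the chunked body has been read to EOF. The X-Checksum name is just an example.

package main

import (
	"bufio"
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	raw := "HTTP/1.1 200 OK\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"Trailer: X-Checksum\r\n" +
		"\r\n" +
		"5\r\nhello\r\n" +
		"0\r\n" +
		"X-Checksum: abc123\r\n" +
		"\r\n"
	res, _ := http.ReadResponse(bufio.NewReader(strings.NewReader(raw)), nil)
	fmt.Println(res.Trailer)                   // map[X-Checksum:[]] (announced, not yet populated)
	io.Copy(io.Discard, res.Body)              // reading to EOF parses the trailer block
	fmt.Println(res.Trailer.Get("X-Checksum")) // abc123
	res.Body.Close()
}
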
+// body turns a Reader into a ReadCloser.
+// Close ensures that the body has been fully read
+// and then reads the trailer if necessary.
+type body struct {
+ src io.Reader
+ hdr any // non-nil (Response or Request) value means read trailer
+ r *bufio.Reader // underlying wire-format reader for the trailer
+ closing bool // is the connection to be closed after reading body?
+ doEarlyClose bool // whether Close should stop early
+
+ mu sync.Mutex // guards following, and calls to Read and Close
+ sawEOF bool
+ closed bool
+ earlyClose bool // Close called and we didn't read to the end of src
+ onHitEOF func() // if non-nil, func to call when EOF is Read
+}
+
+// ErrBodyReadAfterClose is returned when reading a Request or Response
+// Body after the body has been closed. This typically happens when the body is
+// read after an HTTP Handler calls WriteHeader or Write on its
+// ResponseWriter.
+var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
+
+func (b *body) Read(p []byte) (n int, err error) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ if b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ return b.readLocked(p)
+}
+
+// Must hold b.mu.
+func (b *body) readLocked(p []byte) (n int, err error) {
+ if b.sawEOF {
+ return 0, io.EOF
+ }
+ n, err = b.src.Read(p)
+
+ if err == io.EOF {
+ b.sawEOF = true
+ // Chunked case. Read the trailer.
+ if b.hdr != nil {
+ if e := b.readTrailer(); e != nil {
+ err = e
+ // Something went wrong in the trailer, we must not allow any
+ // further reads of any kind to succeed from body, nor any
+ // subsequent requests on the server connection. See
+ // golang.org/issue/12027
+ b.sawEOF = false
+ b.closed = true
+ }
+ b.hdr = nil
+ } else {
+ // If the server declared the Content-Length, our body is a LimitedReader
+ // and we need to check whether this EOF arrived early.
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
+ err = io.ErrUnexpectedEOF
+ }
+ }
+ }
+
+ // If we can return an EOF here along with the read data, do
+ // so. This is optional per the io.Reader contract, but doing
+ // so helps the HTTP transport code recycle its connection
+ // earlier (since it will see this EOF itself), even if the
+ // client doesn't do future reads or Close.
+ if err == nil && n > 0 {
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N == 0 {
+ err = io.EOF
+ b.sawEOF = true
+ }
+ }
+
+ if b.sawEOF && b.onHitEOF != nil {
+ b.onHitEOF()
+ }
+
+ return n, err
+}
+
+var (
+ singleCRLF = []byte("\r\n")
+ doubleCRLF = []byte("\r\n\r\n")
+)
+
+func seeUpcomingDoubleCRLF(r *bufio.Reader) bool {
+ for peekSize := 4; ; peekSize++ {
+ // This loop stops when Peek returns an error,
+ // which it does when r's buffer has been filled.
+ buf, err := r.Peek(peekSize)
+ if bytes.HasSuffix(buf, doubleCRLF) {
+ return true
+ }
+ if err != nil {
+ break
+ }
+ }
+ return false
+}
+
+var errTrailerEOF = errors.New("http: unexpected EOF reading trailer")
+
+func (b *body) readTrailer() error {
+ // The common case, since nobody uses trailers.
+ buf, err := b.r.Peek(2)
+ if bytes.Equal(buf, singleCRLF) {
+ b.r.Discard(2)
+ return nil
+ }
+ if len(buf) < 2 {
+ return errTrailerEOF
+ }
+ if err != nil {
+ return err
+ }
+
+ // Make sure there's a header terminator coming up, to prevent
+ // a DoS with an unbounded size Trailer. It's not easy to
+ // slip in a LimitReader here, as textproto.NewReader requires
+ // a concrete *bufio.Reader. Also, we can't get all the way
+ // back up to our conn's LimitedReader that *might* be backing
+ // this bufio.Reader. Instead, a hack: we iteratively Peek up
+ // to the bufio.Reader's max size, looking for a double CRLF.
+ // This limits the trailer to the underlying buffer size, typically 4kB.
+ if !seeUpcomingDoubleCRLF(b.r) {
+ return errors.New("http: suspiciously long trailer after chunked body")
+ }
+
+ hdr, err := textproto.NewReader(b.r).ReadMIMEHeader()
+ if err != nil {
+ if err == io.EOF {
+ return errTrailerEOF
+ }
+ return err
+ }
+ switch rr := b.hdr.(type) {
+ case *Request:
+ mergeSetHeader(&rr.Trailer, Header(hdr))
+ case *Response:
+ mergeSetHeader(&rr.Trailer, Header(hdr))
+ }
+ return nil
+}
+
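Editorial sketch of the bounded-Peek idea used above: Peek grows until either the blank line is in view or the bufio.Reader's buffer is exhausted, which is what caps the trailer at the buffer size.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"strings"
)

func main() {
	// A 16-byte buffer stands in for the connection's bufio.Reader.
	r := bufio.NewReaderSize(strings.NewReader("X-T: v\r\n\r\nrest of stream"), 16)
	found := false
	for peekSize := 4; ; peekSize++ {
		buf, err := r.Peek(peekSize)
		if bytes.HasSuffix(buf, []byte("\r\n\r\n")) {
			found = true
			break
		}
		if err != nil {
			// Buffer filled (or EOF) without seeing a blank line: give up.
			break
		}
	}
	fmt.Println(found) // true
}
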
+func mergeSetHeader(dst *Header, src Header) {
+ if *dst == nil {
+ *dst = src
+ return
+ }
+ for k, vv := range src {
+ (*dst)[k] = vv
+ }
+}
+
+// unreadDataSizeLocked returns the number of bytes of unread input.
+// It returns -1 if unknown.
+// b.mu must be held.
+func (b *body) unreadDataSizeLocked() int64 {
+ if lr, ok := b.src.(*io.LimitedReader); ok {
+ return lr.N
+ }
+ return -1
+}
+
+func (b *body) Close() error {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ if b.closed {
+ return nil
+ }
+ var err error
+ switch {
+ case b.sawEOF:
+ // Already saw EOF, so no need to go looking for it.
+ case b.hdr == nil && b.closing:
+ // no trailer and closing the connection next.
+ // no point in reading to EOF.
+ case b.doEarlyClose:
+ // Read up to maxPostHandlerReadBytes bytes of the body, looking
+ // for EOF (and trailers), so we can re-use this connection.
+ if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > maxPostHandlerReadBytes {
+ // There was a declared Content-Length, and we have more bytes remaining
+ // than our maxPostHandlerReadBytes tolerance. So, give up.
+ b.earlyClose = true
+ } else {
+ var n int64
+ // Consume the body up to maxPostHandlerReadBytes, which will also
+ // lead to us reading the trailer headers after the body, if present.
+ n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes)
+ if err == io.EOF {
+ err = nil
+ }
+ if n == maxPostHandlerReadBytes {
+ b.earlyClose = true
+ }
+ }
+ default:
+ // Fully consume the body, which will also lead to us reading
+ // the trailer headers after the body, if present.
+ _, err = io.Copy(io.Discard, bodyLocked{b})
+ }
+ b.closed = true
+ return err
+}
+
+func (b *body) didEarlyClose() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ return b.earlyClose
+}
+
+// bodyRemains reports whether future Read calls might
+// yield data.
+func (b *body) bodyRemains() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ return !b.sawEOF
+}
+
+func (b *body) registerOnHitEOF(fn func()) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ b.onHitEOF = fn
+}
+
+// bodyLocked is an io.Reader reading from a *body when its mutex is
+// already held.
+type bodyLocked struct {
+ b *body
+}
+
+func (bl bodyLocked) Read(p []byte) (n int, err error) {
+ if bl.b.closed {
+ return 0, ErrBodyReadAfterClose
+ }
+ return bl.b.readLocked(p)
+}
+
+// parseContentLength trims whitespace from cl and returns -1 if no value
+// is set, or the value if it's >= 0.
+func parseContentLength(cl string) (int64, error) {
+ cl = textproto.TrimString(cl)
+ if cl == "" {
+ return -1, nil
+ }
+ n, err := strconv.ParseUint(cl, 10, 63)
+ if err != nil {
+ return 0, badStringError("bad Content-Length", cl)
+ }
+ return int64(n), nil
+}
+
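Editorial note: strconv.ParseUint with a bit size of 63 is what rejects signs and overflow in one step (see the issue 39017 tests below): no leading "+" or "-" is accepted, and anything above 2^63-1 fails before the int64 conversion. A standalone sketch:

package main

import (
	"fmt"
	"strconv"
)

func main() {
	for _, cl := range []string{"3", "+3", "-3", "9223372036854775807", "9223372036854775808"} {
		n, err := strconv.ParseUint(cl, 10, 63)
		fmt.Printf("%-22q n=%d err=%v\n", cl, n, err)
	}
	// Only "3" and "9223372036854775807" parse; the others fail with a syntax
	// or range error, matching the badStringError cases above.
}
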
+// finishAsyncByteRead finishes reading the 1-byte sniff
+// from the ContentLength==0, Body!=nil case.
+type finishAsyncByteRead struct {
+ tw *transferWriter
+}
+
+func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ rres := <-fr.tw.ByteReadCh
+ n, err = rres.n, rres.err
+ if n == 1 {
+ p[0] = rres.b
+ }
+ if err == nil {
+ err = io.EOF
+ }
+ return
+}
+
+var nopCloserType = reflect.TypeOf(io.NopCloser(nil))
+var nopCloserWriterToType = reflect.TypeOf(io.NopCloser(struct {
+ io.Reader
+ io.WriterTo
+}{}))
+
+// unwrapNopCloser returns the underlying reader and true if r is a NopCloser;
+// otherwise it returns nil and false.
+func unwrapNopCloser(r io.Reader) (underlyingReader io.Reader, isNopCloser bool) {
+ switch reflect.TypeOf(r) {
+ case nopCloserType, nopCloserWriterToType:
+ return reflect.ValueOf(r).Field(0).Interface().(io.Reader), true
+ default:
+ return nil, false
+ }
+}
+
+// isKnownInMemoryReader reports whether r is a type known to not
+// block on Read. Its caller uses this as an optional optimization to
+// send fewer TCP packets.
+func isKnownInMemoryReader(r io.Reader) bool {
+ switch r.(type) {
+ case *bytes.Reader, *bytes.Buffer, *strings.Reader:
+ return true
+ }
+ if r, ok := unwrapNopCloser(r); ok {
+ return isKnownInMemoryReader(r)
+ }
+ if r, ok := r.(*readTrackingBody); ok {
+ return isKnownInMemoryReader(r.ReadCloser)
+ }
+ return false
+}
+
+// bufioFlushWriter is an io.Writer wrapper that flushes all writes
+// on its wrapped writer if it's a *bufio.Writer.
+type bufioFlushWriter struct{ w io.Writer }
+
+func (fw bufioFlushWriter) Write(p []byte) (n int, err error) {
+ n, err = fw.w.Write(p)
+ if bw, ok := fw.w.(*bufio.Writer); n > 0 && ok {
+ ferr := bw.Flush()
+ if ferr != nil && err == nil {
+ err = ferr
+ }
+ }
+ return
+}
diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go
new file mode 100644
index 0000000..5e0df89
--- /dev/null
+++ b/src/net/http/transfer_test.go
@@ -0,0 +1,363 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/rand"
+ "fmt"
+ "io"
+ "os"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestBodyReadBadTrailer(t *testing.T) {
+ b := &body{
+ src: strings.NewReader("foobar"),
+ hdr: true, // force reading the trailer
+ r: bufio.NewReader(strings.NewReader("")),
+ }
+ buf := make([]byte, 7)
+ n, err := b.Read(buf[:3])
+ got := string(buf[:n])
+ if got != "foo" || err != nil {
+ t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if got != "bar" || err != nil {
+ t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
+ }
+
+ n, err = b.Read(buf[:])
+ got = string(buf[:n])
+ if err == nil {
+ t.Errorf("final Read was successful (%q), expected error from trailer read", got)
+ }
+}
+
+func TestFinalChunkedBodyReadEOF(t *testing.T) {
+ res, err := ReadResponse(bufio.NewReader(strings.NewReader(
+ "HTTP/1.1 200 OK\r\n"+
+ "Transfer-Encoding: chunked\r\n"+
+ "\r\n"+
+ "0a\r\n"+
+ "Body here\n\r\n"+
+ "09\r\n"+
+ "continued\r\n"+
+ "0\r\n"+
+ "\r\n")), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "Body here\ncontinued"
+ buf := make([]byte, len(want))
+ n, err := res.Body.Read(buf)
+ if n != len(want) || err != io.EOF {
+ t.Logf("body = %#v", res.Body)
+ t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want))
+ }
+ if string(buf) != want {
+ t.Errorf("buf = %q; want %q", buf, want)
+ }
+}
+
+func TestDetectInMemoryReaders(t *testing.T) {
+ pr, _ := io.Pipe()
+ tests := []struct {
+ r io.Reader
+ want bool
+ }{
+ {pr, false},
+
+ {bytes.NewReader(nil), true},
+ {bytes.NewBuffer(nil), true},
+ {strings.NewReader(""), true},
+
+ {io.NopCloser(pr), false},
+
+ {io.NopCloser(bytes.NewReader(nil)), true},
+ {io.NopCloser(bytes.NewBuffer(nil)), true},
+ {io.NopCloser(strings.NewReader("")), true},
+ }
+ for i, tt := range tests {
+ got := isKnownInMemoryReader(tt.r)
+ if got != tt.want {
+ t.Errorf("%d: got = %v; want %v", i, got, tt.want)
+ }
+ }
+}
+
+type mockTransferWriter struct {
+ CalledReader io.Reader
+ WriteCalled bool
+}
+
+var _ io.ReaderFrom = (*mockTransferWriter)(nil)
+
+func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) {
+ w.CalledReader = r
+ return io.Copy(io.Discard, r)
+}
+
+func (w *mockTransferWriter) Write(p []byte) (int, error) {
+ w.WriteCalled = true
+ return io.Discard.Write(p)
+}
+
+func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
+ fileType := reflect.TypeOf(&os.File{})
+ bufferType := reflect.TypeOf(&bytes.Buffer{})
+
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := os.CreateTemp("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ bodyFunc func() (io.Reader, func(), error)
+ method string
+ contentLength int64
+ transferEncoding []string
+ limitedReader bool
+ expectedReader reflect.Type
+ expectedWrite bool
+ }{
+ {
+ name: "file, non-chunked, size set",
+ bodyFunc: newFileFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newFileFunc()
+ return io.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, chunked",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, size set",
+ bodyFunc: newBufferFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newBufferFunc()
+ return io.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, chunked",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ body, cleanup, err := tc.bodyFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ mw := &mockTransferWriter{}
+ tw := &transferWriter{
+ Body: body,
+ ContentLength: tc.contentLength,
+ TransferEncoding: tc.transferEncoding,
+ }
+
+ if err := tw.writeBody(mw); err != nil {
+ t.Fatal(err)
+ }
+
+ if tc.expectedReader != nil {
+ if mw.CalledReader == nil {
+ t.Fatal("did not call ReadFrom")
+ }
+
+ var actualReader reflect.Type
+ lr, ok := mw.CalledReader.(*io.LimitedReader)
+ if ok && tc.limitedReader {
+ actualReader = reflect.TypeOf(lr.R)
+ } else {
+ actualReader = reflect.TypeOf(mw.CalledReader)
+ }
+
+ if tc.expectedReader != actualReader {
+ t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader)
+ }
+ }
+
+ if tc.expectedWrite && !mw.WriteCalled {
+ t.Fatal("did not invoke Write")
+ }
+ })
+ }
+}
+
+func TestParseTransferEncoding(t *testing.T) {
+ tests := []struct {
+ hdr Header
+ wantErr error
+ }{
+ {
+ hdr: Header{"Transfer-Encoding": {"fugazi"}},
+ wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}},
+ wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {""}},
+ wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked, identity"}},
+ wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked", "identity"}},
+ wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"\x0bchunked"}},
+ wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked"}},
+ wantErr: nil,
+ },
+ }
+
+ for i, tt := range tests {
+ tr := &transferReader{
+ Header: tt.hdr,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ }
+ gotErr := tr.parseTransferEncoding()
+ if !reflect.DeepEqual(gotErr, tt.wantErr) {
+ t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr)
+ }
+ }
+}
+
+// issue 39017 - disallow Content-Length values such as "+3"
+func TestParseContentLength(t *testing.T) {
+ tests := []struct {
+ cl string
+ wantErr error
+ }{
+ {
+ cl: "3",
+ wantErr: nil,
+ },
+ {
+ cl: "+3",
+ wantErr: badStringError("bad Content-Length", "+3"),
+ },
+ {
+ cl: "-3",
+ wantErr: badStringError("bad Content-Length", "-3"),
+ },
+ {
+ // max int64, for safe conversion before returning
+ cl: "9223372036854775807",
+ wantErr: nil,
+ },
+ {
+ cl: "9223372036854775808",
+ wantErr: badStringError("bad Content-Length", "9223372036854775808"),
+ },
+ }
+
+ for _, tt := range tests {
+ if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) {
+ t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr)
+ }
+ }
+}
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
new file mode 100644
index 0000000..c07352b
--- /dev/null
+++ b/src/net/http/transport.go
@@ -0,0 +1,2942 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP client implementation. See RFC 7230 through 7235.
+//
+// This is the low-level Transport implementation of RoundTripper.
+// The high-level interface is in client.go.
+
+package http
+
+import (
+ "bufio"
+ "compress/gzip"
+ "container/list"
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "internal/godebug"
+ "io"
+ "log"
+ "net"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ "reflect"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http/httpproxy"
+)
+
+// DefaultTransport is the default implementation of Transport and is
+// used by DefaultClient. It establishes network connections as needed
+// and caches them for reuse by subsequent calls. It uses HTTP proxies
+// as directed by the environment variables HTTP_PROXY, HTTPS_PROXY
+// and NO_PROXY (or the lowercase versions thereof).
+var DefaultTransport RoundTripper = &Transport{
+ Proxy: ProxyFromEnvironment,
+ DialContext: defaultTransportDialContext(&net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }),
+ ForceAttemptHTTP2: true,
+ MaxIdleConns: 100,
+ IdleConnTimeout: 90 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+}
+
+// DefaultMaxIdleConnsPerHost is the default value of Transport's
+// MaxIdleConnsPerHost.
+const DefaultMaxIdleConnsPerHost = 2
+
+// Transport is an implementation of RoundTripper that supports HTTP,
+// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
+//
+// By default, Transport caches connections for future re-use.
+// This may leave many open connections when accessing many hosts.
+// This behavior can be managed using Transport's CloseIdleConnections method
+// and the MaxIdleConnsPerHost and DisableKeepAlives fields.
+//
+// Transports should be reused instead of created as needed.
+// Transports are safe for concurrent use by multiple goroutines.
+//
+// A Transport is a low-level primitive for making HTTP and HTTPS requests.
+// For high-level functionality, such as cookies and redirects, see Client.
+//
+// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
+// for HTTPS URLs, depending on whether the server supports HTTP/2,
+// and how the Transport is configured. The DefaultTransport supports HTTP/2.
+// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2
+// and call ConfigureTransport. See the package docs for more about HTTP/2.
+//
+// Responses with status codes in the 1xx range are either handled
+// automatically (100 expect-continue) or ignored. The one
+// exception is HTTP status code 101 (Switching Protocols), which is
+// considered a terminal status and returned by RoundTrip. To see the
+// ignored 1xx responses, use the httptrace package's
+// ClientTrace.Got1xxResponse.
+//
+// Transport only retries a request upon encountering a network error
+// if the connection has already been used successfully and if the
+// request is idempotent and either has no body or has its Request.GetBody
+// defined. HTTP requests are considered idempotent if they have HTTP methods
+// GET, HEAD, OPTIONS, or TRACE; or if their Header map contains an
+// "Idempotency-Key" or "X-Idempotency-Key" entry. If the idempotency key
+// value is a zero-length slice, the request is treated as idempotent but the
+// header is not sent on the wire.
+type Transport struct {
+ idleMu sync.Mutex
+ closeIdle bool // user has requested to close all idle conns
+ idleConn map[connectMethodKey][]*persistConn // most recently used at end
+ idleConnWait map[connectMethodKey]wantConnQueue // waiting getConns
+ idleLRU connLRU
+
+ reqMu sync.Mutex
+ reqCanceler map[cancelKey]func(error)
+
+ altMu sync.Mutex // guards changing altProto only
+ altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
+
+ connsPerHostMu sync.Mutex
+ connsPerHost map[connectMethodKey]int
+ connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns
+
+ // Proxy specifies a function to return a proxy for a given
+ // Request. If the function returns a non-nil error, the
+ // request is aborted with the provided error.
+ //
+ // The proxy type is determined by the URL scheme. "http",
+ // "https", and "socks5" are supported. If the scheme is empty,
+ // "http" is assumed.
+ //
+ // If Proxy is nil or returns a nil *URL, no proxy is used.
+ Proxy func(*Request) (*url.URL, error)
+
+ // OnProxyConnectResponse is called when the Transport gets an HTTP response from
+ // a proxy for a CONNECT request. It's called before the check for a 200 OK response.
+ // If it returns an error, the request fails with that error.
+ OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error
+
+ // DialContext specifies the dial function for creating unencrypted TCP connections.
+ // If DialContext is nil (and the deprecated Dial below is also nil),
+ // then the transport dials using package net.
+ //
+ // DialContext runs concurrently with calls to RoundTrip.
+ // A RoundTrip call that initiates a dial may end up using
+ // a connection dialed previously when the earlier connection
+ // becomes idle before the later DialContext completes.
+ DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // Dial specifies the dial function for creating unencrypted TCP connections.
+ //
+ // Dial runs concurrently with calls to RoundTrip.
+ // A RoundTrip call that initiates a dial may end up using
+ // a connection dialed previously when the earlier connection
+ // becomes idle before the later Dial completes.
+ //
+ // Deprecated: Use DialContext instead, which allows the transport
+ // to cancel dials as soon as they are no longer needed.
+ // If both are set, DialContext takes priority.
+ Dial func(network, addr string) (net.Conn, error)
+
+ // DialTLSContext specifies an optional dial function for creating
+ // TLS connections for non-proxied HTTPS requests.
+ //
+ // If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
+ // DialContext and TLSClientConfig are used.
+ //
+ // If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
+ // requests and the TLSClientConfig and TLSHandshakeTimeout
+ // are ignored. The returned net.Conn is assumed to already be
+ // past the TLS handshake.
+ DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
+ // DialTLS specifies an optional dial function for creating
+ // TLS connections for non-proxied HTTPS requests.
+ //
+ // Deprecated: Use DialTLSContext instead, which allows the transport
+ // to cancel dials as soon as they are no longer needed.
+ // If both are set, DialTLSContext takes priority.
+ DialTLS func(network, addr string) (net.Conn, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client.
+ // If nil, the default configuration is used.
+ // If non-nil, HTTP/2 support may not be enabled by default.
+ TLSClientConfig *tls.Config
+
+ // TLSHandshakeTimeout specifies the maximum amount of time to
+ // wait for a TLS handshake. Zero means no timeout.
+ TLSHandshakeTimeout time.Duration
+
+ // DisableKeepAlives, if true, disables HTTP keep-alives and
+ // will only use the connection to the server for a single
+ // HTTP request.
+ //
+ // This is unrelated to the similarly named TCP keep-alives.
+ DisableKeepAlives bool
+
+ // DisableCompression, if true, prevents the Transport from
+ // requesting compression with an "Accept-Encoding: gzip"
+ // request header when the Request contains no existing
+ // Accept-Encoding value. If the Transport requests gzip on
+ // its own and gets a gzipped response, it's transparently
+ // decoded in the Response.Body. However, if the user
+ // explicitly requested gzip it is not automatically
+ // uncompressed.
+ DisableCompression bool
+
+ // MaxIdleConns controls the maximum number of idle (keep-alive)
+ // connections across all hosts. Zero means no limit.
+ MaxIdleConns int
+
+ // MaxIdleConnsPerHost, if non-zero, controls the maximum idle
+ // (keep-alive) connections to keep per-host. If zero,
+ // DefaultMaxIdleConnsPerHost is used.
+ MaxIdleConnsPerHost int
+
+ // MaxConnsPerHost optionally limits the total number of
+ // connections per host, including connections in the dialing,
+ // active, and idle states. On limit violation, dials will block.
+ //
+ // Zero means no limit.
+ MaxConnsPerHost int
+
+ // IdleConnTimeout is the maximum amount of time an idle
+ // (keep-alive) connection will remain idle before closing
+ // itself.
+ // Zero means no limit.
+ IdleConnTimeout time.Duration
+
+ // ResponseHeaderTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's response headers after fully
+ // writing the request (including its body, if any). This
+ // time does not include the time to read the response body.
+ ResponseHeaderTimeout time.Duration
+
+ // ExpectContinueTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's first response headers after fully
+ // writing the request headers if the request has an
+ // "Expect: 100-continue" header. Zero means no timeout and
+ // causes the body to be sent immediately, without
+ // waiting for the server to approve.
+ // This time does not include the time to send the request header.
+ ExpectContinueTimeout time.Duration
+
+ // TLSNextProto specifies how the Transport switches to an
+ // alternate protocol (such as HTTP/2) after a TLS ALPN
+ // protocol negotiation. If Transport dials a TLS connection
+ // with a non-empty protocol name and TLSNextProto contains a
+ // map entry for that key (such as "h2"), then the func is
+ // called with the request's authority (such as "example.com"
+ // or "example.com:1234") and the TLS connection. The function
+ // must return a RoundTripper that then handles the request.
+ // If TLSNextProto is not nil, HTTP/2 support is not enabled
+ // automatically.
+ TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
+
+ // ProxyConnectHeader optionally specifies headers to send to
+ // proxies during CONNECT requests.
+ // To set the header dynamically, see GetProxyConnectHeader.
+ ProxyConnectHeader Header
+
+ // GetProxyConnectHeader optionally specifies a func to return
+ // headers to send to proxyURL during a CONNECT request to the
+ // ip:port target.
+ // If it returns an error, the Transport's RoundTrip fails with
+ // that error. It can return (nil, nil) to not add headers.
+ // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
+ // ignored.
+ GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)
+
+ // MaxResponseHeaderBytes specifies a limit on how many
+ // response bytes are allowed in the server's response
+ // header.
+ //
+ // Zero means to use a default limit.
+ MaxResponseHeaderBytes int64
+
+ // WriteBufferSize specifies the size of the write buffer used
+ // when writing to the transport.
+ // If zero, a default (currently 4KB) is used.
+ WriteBufferSize int
+
+ // ReadBufferSize specifies the size of the read buffer used
+ // when reading from the transport.
+ // If zero, a default (currently 4KB) is used.
+ ReadBufferSize int
+
+ // nextProtoOnce guards initialization of TLSNextProto and
+ // h2transport (via onceSetNextProtoDefaults)
+ nextProtoOnce sync.Once
+ h2transport h2Transport // non-nil if http2 wired up
+ tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired
+
+ // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
+ // Dial, DialTLS, or DialContext func or TLSClientConfig is provided.
+ // By default, use of any of those fields conservatively disables HTTP/2.
+ // To use a custom dialer or TLS config and still attempt HTTP/2
+ // upgrades, set this to true.
+ ForceAttemptHTTP2 bool
+}
+
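Editorial usage sketch for the struct above: build one Transport with the knobs you care about, reuse it for all requests, and pair it with a Client for the high-level features. The URL and timeout values here are illustrative only.

package main

import (
	"fmt"
	"io"
	"net/http"
	"time"
)

func main() {
	tr := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		ForceAttemptHTTP2:     true,
	}
	client := &http.Client{Transport: tr, Timeout: 30 * time.Second}

	resp, err := client.Get("https://example.com/")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	io.Copy(io.Discard, resp.Body) // drain the body so the connection can be reused
	fmt.Println(resp.Status, resp.Proto)
}
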
+// A cancelKey is the key of the reqCanceler map.
+// We wrap the *Request in this type since we want to use the original request,
+// not any transient one created by roundTrip.
+type cancelKey struct {
+ req *Request
+}
+
+func (t *Transport) writeBufferSize() int {
+ if t.WriteBufferSize > 0 {
+ return t.WriteBufferSize
+ }
+ return 4 << 10
+}
+
+func (t *Transport) readBufferSize() int {
+ if t.ReadBufferSize > 0 {
+ return t.ReadBufferSize
+ }
+ return 4 << 10
+}
+
+// Clone returns a deep copy of t's exported fields.
+func (t *Transport) Clone() *Transport {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+ t2 := &Transport{
+ Proxy: t.Proxy,
+ OnProxyConnectResponse: t.OnProxyConnectResponse,
+ DialContext: t.DialContext,
+ Dial: t.Dial,
+ DialTLS: t.DialTLS,
+ DialTLSContext: t.DialTLSContext,
+ TLSHandshakeTimeout: t.TLSHandshakeTimeout,
+ DisableKeepAlives: t.DisableKeepAlives,
+ DisableCompression: t.DisableCompression,
+ MaxIdleConns: t.MaxIdleConns,
+ MaxIdleConnsPerHost: t.MaxIdleConnsPerHost,
+ MaxConnsPerHost: t.MaxConnsPerHost,
+ IdleConnTimeout: t.IdleConnTimeout,
+ ResponseHeaderTimeout: t.ResponseHeaderTimeout,
+ ExpectContinueTimeout: t.ExpectContinueTimeout,
+ ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
+ GetProxyConnectHeader: t.GetProxyConnectHeader,
+ MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
+ ForceAttemptHTTP2: t.ForceAttemptHTTP2,
+ WriteBufferSize: t.WriteBufferSize,
+ ReadBufferSize: t.ReadBufferSize,
+ }
+ if t.TLSClientConfig != nil {
+ t2.TLSClientConfig = t.TLSClientConfig.Clone()
+ }
+ if !t.tlsNextProtoWasNil {
+ npm := map[string]func(authority string, c *tls.Conn) RoundTripper{}
+ for k, v := range t.TLSNextProto {
+ npm[k] = v
+ }
+ t2.TLSNextProto = npm
+ }
+ return t2
+}
+
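A common pattern, sketched here rather than prescribed by the patch, is to Clone the default transport and override a field or two instead of mutating DefaultTransport in place. The values below are illustrative.

package main

import (
	"net/http"
	"time"
)

// newClient returns a Client backed by a copy of DefaultTransport with a
// couple of fields overridden.
func newClient() *http.Client {
	tr := http.DefaultTransport.(*http.Transport).Clone()
	tr.MaxIdleConnsPerHost = 32
	tr.ResponseHeaderTimeout = 15 * time.Second
	return &http.Client{Transport: tr}
}

func main() {
	client := newClient()
	_ = client
}
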
+// h2Transport is the interface we expect to be able to call from
+// net/http against an *http2.Transport that's either bundled into
+// h2_bundle.go or supplied by the user via x/net/http2.
+//
+// We name it with the "h2" prefix to stay out of the "http2" prefix
+// namespace used by x/tools/cmd/bundle for h2_bundle.go.
+type h2Transport interface {
+ CloseIdleConnections()
+}
+
+func (t *Transport) hasCustomTLSDialer() bool {
+ return t.DialTLS != nil || t.DialTLSContext != nil
+}
+
+var http2client = godebug.New("http2client")
+
+// onceSetNextProtoDefaults initializes TLSNextProto.
+// It must be called via t.nextProtoOnce.Do.
+func (t *Transport) onceSetNextProtoDefaults() {
+ t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
+ if http2client.Value() == "0" {
+ http2client.IncNonDefault()
+ return
+ }
+
+ // If they've already configured http2 with
+ // golang.org/x/net/http2 instead of the bundled copy, try to
+ // get at its http2.Transport value (via the "https"
+ // altproto map) so we can call CloseIdleConnections on it if
+ // requested. (Issue 22891)
+ altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+ if rv := reflect.ValueOf(altProto["https"]); rv.IsValid() && rv.Type().Kind() == reflect.Struct && rv.Type().NumField() == 1 {
+ if v := rv.Field(0); v.CanInterface() {
+ if h2i, ok := v.Interface().(h2Transport); ok {
+ t.h2transport = h2i
+ return
+ }
+ }
+ }
+
+ if t.TLSNextProto != nil {
+ // This is the documented way to disable http2 on a
+ // Transport.
+ return
+ }
+ if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
+ // Be conservative and don't automatically enable
+ // http2 if they've specified a custom TLS config or
+ // custom dialers. Let them opt-in themselves via
+ // http2.ConfigureTransport so we don't surprise them
+ // by modifying their tls.Config. Issue 14275.
+ // However, if ForceAttemptHTTP2 is true, it overrides the above checks.
+ return
+ }
+ if omitBundledHTTP2 {
+ return
+ }
+ t2, err := http2configureTransports(t)
+ if err != nil {
+ log.Printf("Error enabling Transport HTTP/2 support: %v", err)
+ return
+ }
+ t.h2transport = t2
+
+ // Auto-configure the http2.Transport's MaxHeaderListSize from
+ // the http.Transport's MaxResponseHeaderBytes. They don't
+ // exactly mean the same thing, but they're close.
+ //
+ // TODO: also add this to x/net/http2.ConfigureTransport, behind
+ // a +build go1.7 build tag:
+ if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
+ const h2max = 1<<32 - 1
+ if limit1 >= h2max {
+ t2.MaxHeaderListSize = h2max
+ } else {
+ t2.MaxHeaderListSize = uint32(limit1)
+ }
+ }
+}
+
+// ProxyFromEnvironment returns the URL of the proxy to use for a
+// given request, as indicated by the environment variables
+// HTTP_PROXY, HTTPS_PROXY and NO_PROXY (or the lowercase versions
+// thereof). Requests use the proxy from the environment variable
+// matching their scheme, unless excluded by NO_PROXY.
+//
+// The environment values may be either a complete URL or a
+// "host[:port]", in which case the "http" scheme is assumed.
+// The schemes "http", "https", and "socks5" are supported.
+// An error is returned if the value is a different form.
+//
+// A nil URL and nil error are returned if no proxy is defined in the
+// environment, or a proxy should not be used for the given request,
+// as defined by NO_PROXY.
+//
+// As a special case, if req.URL.Host is "localhost" (with or without
+// a port number), then a nil URL and nil error will be returned.
+func ProxyFromEnvironment(req *Request) (*url.URL, error) {
+ return envProxyFunc()(req.URL)
+}
+
+// ProxyURL returns a proxy function (for use in a Transport)
+// that always returns the same URL.
+func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
+ return func(*Request) (*url.URL, error) {
+ return fixedURL, nil
+ }
+}
+
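Editorial sketch: ProxyURL is the usual hook for a fixed proxy instead of the environment. The proxy address below is a placeholder.

package main

import (
	"net/http"
	"net/url"
)

func main() {
	proxyURL, err := url.Parse("http://proxy.internal:3128") // hypothetical proxy
	if err != nil {
		panic(err)
	}
	client := &http.Client{
		Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
	}
	_ = client // every request made with this client goes through the fixed proxy
}
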
+// transportRequest is a wrapper around a *Request that adds
+// optional extra headers to write and stores any error to return
+// from roundTrip.
+type transportRequest struct {
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+ trace *httptrace.ClientTrace // optional
+ cancelKey cancelKey
+
+ mu sync.Mutex // guards err
+ err error // first setError value for mapRoundTripError to consider
+}
+
+func (tr *transportRequest) extraHeaders() Header {
+ if tr.extra == nil {
+ tr.extra = make(Header)
+ }
+ return tr.extra
+}
+
+func (tr *transportRequest) setError(err error) {
+ tr.mu.Lock()
+ if tr.err == nil {
+ tr.err = err
+ }
+ tr.mu.Unlock()
+}
+
+// useRegisteredProtocol reports whether an alternate protocol (as registered
+// with Transport.RegisterProtocol) should be respected for this request.
+func (t *Transport) useRegisteredProtocol(req *Request) bool {
+ if req.URL.Scheme == "https" && req.requiresHTTP1() {
+ // If this request requires HTTP/1, don't use the
+ // "https" alternate protocol, which is used by the
+ // HTTP/2 code to take over requests if there's an
+ // existing cached HTTP/2 connection.
+ return false
+ }
+ return true
+}
+
+// alternateRoundTripper returns the alternate RoundTripper to use
+// for this request if the Request's URL scheme requires one,
+// or nil for the normal case of using the Transport.
+func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
+ if !t.useRegisteredProtocol(req) {
+ return nil
+ }
+ altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+ return altProto[req.URL.Scheme]
+}
+
+// roundTrip implements a RoundTripper over HTTP.
+func (t *Transport) roundTrip(req *Request) (*Response, error) {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+ ctx := req.Context()
+ trace := httptrace.ContextClientTrace(ctx)
+
+ if req.URL == nil {
+ req.closeBody()
+ return nil, errors.New("http: nil Request.URL")
+ }
+ if req.Header == nil {
+ req.closeBody()
+ return nil, errors.New("http: nil Request.Header")
+ }
+ scheme := req.URL.Scheme
+ isHTTP := scheme == "http" || scheme == "https"
+ if isHTTP {
+ for k, vv := range req.Header {
+ if !httpguts.ValidHeaderFieldName(k) {
+ req.closeBody()
+ return nil, fmt.Errorf("net/http: invalid header field name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ req.closeBody()
+ // Don't include the value in the error, because it may be sensitive.
+ return nil, fmt.Errorf("net/http: invalid header field value for %q", k)
+ }
+ }
+ }
+ }
+
+ origReq := req
+ cancelKey := cancelKey{origReq}
+ req = setupRewindBody(req)
+
+ if altRT := t.alternateRoundTripper(req); altRT != nil {
+ if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
+ return resp, err
+ }
+ var err error
+ req, err = rewindBody(req)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if !isHTTP {
+ req.closeBody()
+ return nil, badStringError("unsupported protocol scheme", scheme)
+ }
+ if req.Method != "" && !validMethod(req.Method) {
+ req.closeBody()
+ return nil, fmt.Errorf("net/http: invalid method %q", req.Method)
+ }
+ if req.URL.Host == "" {
+ req.closeBody()
+ return nil, errors.New("http: no Host in request URL")
+ }
+
+ for {
+ select {
+ case <-ctx.Done():
+ req.closeBody()
+ return nil, ctx.Err()
+ default:
+ }
+
+ // treq gets modified by roundTrip, so we need to recreate for each retry.
+ treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
+ cm, err := t.connectMethodForRequest(treq)
+ if err != nil {
+ req.closeBody()
+ return nil, err
+ }
+
+ // Get the cached or newly-created connection to either the
+ // host (for http or https), the http proxy, or the http proxy
+ // pre-CONNECTed to https server. In any case, we'll be ready
+ // to send it requests.
+ pconn, err := t.getConn(treq, cm)
+ if err != nil {
+ t.setReqCanceler(cancelKey, nil)
+ req.closeBody()
+ return nil, err
+ }
+
+ var resp *Response
+ if pconn.alt != nil {
+ // HTTP/2 path.
+ t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
+ resp, err = pconn.alt.RoundTrip(req)
+ } else {
+ resp, err = pconn.roundTrip(treq)
+ }
+ if err == nil {
+ resp.Request = origReq
+ return resp, nil
+ }
+
+ // Failed. Clean up and determine whether to retry.
+ if http2isNoCachedConnError(err) {
+ if t.removeIdleConn(pconn) {
+ t.decConnsPerHost(pconn.cacheKey)
+ }
+ } else if !pconn.shouldRetryRequest(req, err) {
+ // Issue 16465: return underlying net.Conn.Read error from peek,
+ // as we've historically done.
+ if e, ok := err.(nothingWrittenError); ok {
+ err = e.error
+ }
+ if e, ok := err.(transportReadFromServerError); ok {
+ err = e.err
+ }
+ if b, ok := req.Body.(*readTrackingBody); ok && !b.didClose {
+ // Issue 49621: Close the request body if pconn.roundTrip
+ // didn't do so already. This can happen if the pconn
+ // write loop exits without reading the write request.
+ req.closeBody()
+ }
+ return nil, err
+ }
+ testHookRoundTripRetried()
+
+ // Rewind the body if we're able to.
+ req, err = rewindBody(req)
+ if err != nil {
+ return nil, err
+ }
+ }
+}
+
+var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss")
+
+type readTrackingBody struct {
+ io.ReadCloser
+ didRead bool
+ didClose bool
+}
+
+func (r *readTrackingBody) Read(data []byte) (int, error) {
+ r.didRead = true
+ return r.ReadCloser.Read(data)
+}
+
+func (r *readTrackingBody) Close() error {
+ r.didClose = true
+ return r.ReadCloser.Close()
+}
+
+// setupRewindBody returns a new request with a custom body wrapper
+// that can report whether the body needs rewinding.
+// This lets rewindBody avoid an error result when the request
+// does not have GetBody but the body hasn't been read at all yet.
+func setupRewindBody(req *Request) *Request {
+ if req.Body == nil || req.Body == NoBody {
+ return req
+ }
+ newReq := *req
+ newReq.Body = &readTrackingBody{ReadCloser: req.Body}
+ return &newReq
+}
+
+// rewindBody returns a new request with the body rewound.
+// It returns req unmodified if the body does not need rewinding.
+// rewindBody takes care of closing req.Body when appropriate
+// (in all cases except when rewindBody returns req unmodified).
+func rewindBody(req *Request) (rewound *Request, err error) {
+ if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) {
+ return req, nil // nothing to rewind
+ }
+ if !req.Body.(*readTrackingBody).didClose {
+ req.closeBody()
+ }
+ if req.GetBody == nil {
+ return nil, errCannotRewind
+ }
+ body, err := req.GetBody()
+ if err != nil {
+ return nil, err
+ }
+ newReq := *req
+ newReq.Body = &readTrackingBody{ReadCloser: body}
+ return &newReq, nil
+}
+
+// shouldRetryRequest reports whether we should retry sending a failed
+// HTTP request on a new connection. The non-nil input error is the
+// error from roundTrip.
+func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
+ if http2isNoCachedConnError(err) {
+ // Issue 16582: if the user started a bunch of
+ // requests at once, they can all pick the same conn
+ // and violate the server's max concurrent streams.
+ // Instead, match the HTTP/1 behavior for now and dial
+ // again to get a new TCP connection, rather than failing
+ // this request.
+ return true
+ }
+ if err == errMissingHost {
+ // User error.
+ return false
+ }
+ if !pc.isReused() {
+ // This was a fresh connection. There's no reason the server
+ // should've hung up on us.
+ //
+ // Also, if we retried now, we could loop forever
+ // creating new connections and retrying if the server
+ // is just hanging up on us because it doesn't like
+ // our request (as opposed to sending an error).
+ return false
+ }
+ if _, ok := err.(nothingWrittenError); ok {
+ // We never wrote anything, so it's safe to retry, if there's no body or we
+ // can "rewind" the body with GetBody.
+ return req.outgoingLength() == 0 || req.GetBody != nil
+ }
+ if !req.isReplayable() {
+ // Don't retry non-idempotent requests.
+ return false
+ }
+ if _, ok := err.(transportReadFromServerError); ok {
+ // We got some non-EOF net.Conn.Read failure reading
+ // the 1st response byte from the server.
+ return true
+ }
+ if err == errServerClosedIdle {
+ // The server replied with io.EOF while we were trying to
+ // read the response. Probably an unfortunately timed keep-alive
+ // timeout, firing just as the client was writing a request.
+ return true
+ }
+ return false // conservatively
+}
+
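Editorial sketch: whether a request is replayable is largely in the caller's hands. NewRequest sets GetBody for *bytes.Buffer, *bytes.Reader, and *strings.Reader bodies so they can be rewound, and, per the Transport documentation above, an Idempotency-Key header marks an otherwise non-idempotent request as safe to retry. The URL, payload, and key are placeholders.

package main

import (
	"bytes"
	"net/http"
)

func main() {
	payload := []byte(`{"amount": 10}`)
	req, err := http.NewRequest("POST", "https://example.com/charge", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	// GetBody was set for the *bytes.Reader body, so the body can be replayed
	// if the transport retries on a reused connection that the server closed.
	req.Header.Set("Idempotency-Key", "charge-2f6d1c")
	resp, err := http.DefaultClient.Do(req)
	if err == nil {
		resp.Body.Close()
	}
}
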
+// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol.
+var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol")
+
+// RegisterProtocol registers a new protocol with scheme.
+// The Transport will pass requests using the given scheme to rt.
+// It is rt's responsibility to simulate HTTP request semantics.
+//
+// RegisterProtocol can be used by other packages to provide
+// implementations of protocol schemes like "ftp" or "file".
+//
+// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will
+// handle the RoundTrip itself for that one request, as if the
+// protocol were not registered.
+func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
+ t.altMu.Lock()
+ defer t.altMu.Unlock()
+ oldMap, _ := t.altProto.Load().(map[string]RoundTripper)
+ if _, exists := oldMap[scheme]; exists {
+ panic("protocol " + scheme + " already registered")
+ }
+ newMap := make(map[string]RoundTripper)
+ for k, v := range oldMap {
+ newMap[k] = v
+ }
+ newMap[scheme] = rt
+ t.altProto.Store(newMap)
+}
+
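Editorial sketch of the documented use: register a RoundTripper for a made-up "echo" scheme and round-trip a request through it. Everything here is illustrative.

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// echoRT is a toy RoundTripper that answers every request itself.
type echoRT struct{}

func (echoRT) RoundTrip(req *http.Request) (*http.Response, error) {
	return &http.Response{
		Status:     "200 OK",
		StatusCode: 200,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("echo: " + req.URL.String())),
		Request:    req,
	}, nil
}

func main() {
	tr := &http.Transport{}
	tr.RegisterProtocol("echo", echoRT{})

	req, _ := http.NewRequest("GET", "echo://anything/at/all", nil)
	resp, err := tr.RoundTrip(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	b, _ := io.ReadAll(resp.Body)
	fmt.Println(string(b)) // echo: echo://anything/at/all
}
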
+// CloseIdleConnections closes any connections which were established by
+// previous requests but are now sitting idle in
+// a "keep-alive" state. It does not interrupt any connections currently
+// in use.
+func (t *Transport) CloseIdleConnections() {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+ t.idleMu.Lock()
+ m := t.idleConn
+ t.idleConn = nil
+ t.closeIdle = true // close newly idle connections
+ t.idleLRU = connLRU{}
+ t.idleMu.Unlock()
+ for _, conns := range m {
+ for _, pconn := range conns {
+ pconn.close(errCloseIdleConns)
+ }
+ }
+ if t2 := t.h2transport; t2 != nil {
+ t2.CloseIdleConnections()
+ }
+}
+
+// CancelRequest cancels an in-flight request by closing its connection.
+// CancelRequest should only be called after RoundTrip has returned.
+//
+// Deprecated: Use Request.WithContext to create a request with a
+// cancelable context instead. CancelRequest cannot cancel HTTP/2
+// requests.
+func (t *Transport) CancelRequest(req *Request) {
+ t.cancelRequest(cancelKey{req}, errRequestCanceled)
+}
+
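Editorial sketch of the recommended replacement: cancel via the request's context, which also covers HTTP/2.

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel() // canceling the context cancels the in-flight request

	req, err := http.NewRequestWithContext(ctx, "GET", "https://example.com/", nil)
	if err != nil {
		panic(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("request failed or was canceled:", err)
		return
	}
	resp.Body.Close()
}
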
+// Cancel an in-flight request, recording the error value.
+// Returns whether the request was canceled.
+func (t *Transport) cancelRequest(key cancelKey, err error) bool {
+ // This function must not return until the cancel func has completed.
+ // See: https://golang.org/issue/34658
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ cancel := t.reqCanceler[key]
+ delete(t.reqCanceler, key)
+ if cancel != nil {
+ cancel(err)
+ }
+
+ return cancel != nil
+}
+
+//
+// Private implementation past this point.
+//
+
+var (
+ envProxyOnce sync.Once
+ envProxyFuncValue func(*url.URL) (*url.URL, error)
+)
+
+// envProxyFunc returns a function that reads the
+// environment variable to determine the proxy address.
+func envProxyFunc() func(*url.URL) (*url.URL, error) {
+ envProxyOnce.Do(func() {
+ envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc()
+ })
+ return envProxyFuncValue
+}
+
+// resetProxyConfig is used by tests.
+func resetProxyConfig() {
+ envProxyOnce = sync.Once{}
+ envProxyFuncValue = nil
+}
+
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
+ cm.targetScheme = treq.URL.Scheme
+ cm.targetAddr = canonicalAddr(treq.URL)
+ if t.Proxy != nil {
+ cm.proxyURL, err = t.Proxy(treq.Request)
+ }
+ cm.onlyH1 = treq.requiresHTTP1()
+ return cm, err
+}
+
+// proxyAuth returns the Proxy-Authorization header to set
+// on requests, if applicable.
+func (cm *connectMethod) proxyAuth() string {
+ if cm.proxyURL == nil {
+ return ""
+ }
+ if u := cm.proxyURL.User; u != nil {
+ username := u.Username()
+ password, _ := u.Password()
+ return "Basic " + basicAuth(username, password)
+ }
+ return ""
+}
+
+// error values for debugging and testing, not seen by users.
+var (
+ errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled")
+ errConnBroken = errors.New("http: putIdleConn: connection is in bad state")
+ errCloseIdle = errors.New("http: putIdleConn: CloseIdleConnections was called")
+ errTooManyIdle = errors.New("http: putIdleConn: too many idle connections")
+ errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host")
+ errCloseIdleConns = errors.New("http: CloseIdleConnections called")
+ errReadLoopExiting = errors.New("http: persistConn.readLoop exiting")
+ errIdleConnTimeout = errors.New("http: idle connection timeout")
+
+ // errServerClosedIdle is not seen by users for idempotent requests, but may be
+ // seen by a user if the server shuts down an idle connection and sends its FIN
+ // in flight with already-written POST body bytes from the client.
+ // See https://github.com/golang/go/issues/19943#issuecomment-355607646
+ errServerClosedIdle = errors.New("http: server closed idle connection")
+)
+
+// transportReadFromServerError is used by Transport.readLoop when the
+// 1 byte peek read fails and we're actually anticipating a response.
+// Usually this is just due to the inherent keep-alive shut down race,
+// where the server closed the connection at the same time the client
+// wrote. The underlying err field is usually io.EOF or some
+// ECONNRESET sort of thing which varies by platform. But it might be
+// the user's custom net.Conn.Read error too, so we carry it along for
+// them to return from Transport.RoundTrip.
+type transportReadFromServerError struct {
+ err error
+}
+
+func (e transportReadFromServerError) Unwrap() error { return e.err }
+
+func (e transportReadFromServerError) Error() string {
+ return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err)
+}
+
+func (t *Transport) putOrCloseIdleConn(pconn *persistConn) {
+ if err := t.tryPutIdleConn(pconn); err != nil {
+ pconn.close(err)
+ }
+}
+
+func (t *Transport) maxIdleConnsPerHost() int {
+ if v := t.MaxIdleConnsPerHost; v != 0 {
+ return v
+ }
+ return DefaultMaxIdleConnsPerHost
+}
+
+// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting
+// a new request.
+// If pconn is no longer needed or not in a good state, tryPutIdleConn returns
+// an error explaining why it wasn't registered.
+// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that.
+func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
+ if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
+ return errKeepAlivesDisabled
+ }
+ if pconn.isBroken() {
+ return errConnBroken
+ }
+ pconn.markReused()
+
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+
+ // HTTP/2 (pconn.alt != nil) connections do not come out of the idle list,
+ // because multiple goroutines can use them simultaneously.
+ // If this is an HTTP/2 connection being “returned,” we're done.
+ if pconn.alt != nil && t.idleLRU.m[pconn] != nil {
+ return nil
+ }
+
+ // Deliver pconn to goroutine waiting for idle connection, if any.
+ // (They may be actively dialing, but this conn is ready first.
+ // Chrome calls this socket late binding.
+ // See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.)
+ key := pconn.cacheKey
+ if q, ok := t.idleConnWait[key]; ok {
+ done := false
+ if pconn.alt == nil {
+ // HTTP/1.
+ // Loop over the waiting list until we find a w that isn't done already, and hand it pconn.
+ for q.len() > 0 {
+ w := q.popFront()
+ if w.tryDeliver(pconn, nil) {
+ done = true
+ break
+ }
+ }
+ } else {
+ // HTTP/2.
+ // Can hand the same pconn to everyone in the waiting list,
+ // and we still won't be done: we want to put it in the idle
+ // list unconditionally, for any future clients too.
+ for q.len() > 0 {
+ w := q.popFront()
+ w.tryDeliver(pconn, nil)
+ }
+ }
+ if q.len() == 0 {
+ delete(t.idleConnWait, key)
+ } else {
+ t.idleConnWait[key] = q
+ }
+ if done {
+ return nil
+ }
+ }
+
+ if t.closeIdle {
+ return errCloseIdle
+ }
+ if t.idleConn == nil {
+ t.idleConn = make(map[connectMethodKey][]*persistConn)
+ }
+ idles := t.idleConn[key]
+ if len(idles) >= t.maxIdleConnsPerHost() {
+ return errTooManyIdleHost
+ }
+ for _, exist := range idles {
+ if exist == pconn {
+ log.Fatalf("dup idle pconn %p in freelist", pconn)
+ }
+ }
+ t.idleConn[key] = append(idles, pconn)
+ t.idleLRU.add(pconn)
+ if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns {
+ oldest := t.idleLRU.removeOldest()
+ oldest.close(errTooManyIdle)
+ t.removeIdleConnLocked(oldest)
+ }
+
+ // Set idle timer, but only for HTTP/1 (pconn.alt == nil).
+ // The HTTP/2 implementation manages the idle timer itself
+ // (see idleConnTimeout in h2_bundle.go).
+ if t.IdleConnTimeout > 0 && pconn.alt == nil {
+ if pconn.idleTimer != nil {
+ pconn.idleTimer.Reset(t.IdleConnTimeout)
+ } else {
+ pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+ }
+ }
+ pconn.idleAt = time.Now()
+ return nil
+}
+
+// queueForIdleConn queues w to receive the next idle connection for w.cm.
+// As an optimization hint to the caller, queueForIdleConn reports whether
+// it successfully delivered an already-idle connection.
+func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
+ if t.DisableKeepAlives {
+ return false
+ }
+
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+
+ // Stop closing connections that become idle - we might want one.
+ // (That is, undo the effect of t.CloseIdleConnections.)
+ t.closeIdle = false
+
+ if w == nil {
+ // Happens in test hook.
+ return false
+ }
+
+ // If IdleConnTimeout is set, calculate the oldest
+ // persistConn.idleAt time we're willing to use a cached idle
+ // conn.
+ var oldTime time.Time
+ if t.IdleConnTimeout > 0 {
+ oldTime = time.Now().Add(-t.IdleConnTimeout)
+ }
+
+ // Look for most recently-used idle connection.
+ if list, ok := t.idleConn[w.key]; ok {
+ stop := false
+ delivered := false
+ for len(list) > 0 && !stop {
+ pconn := list[len(list)-1]
+
+ // See whether this connection has been idle too long, considering
+ // only the wall time (the Round(0)), in case this is a laptop or VM
+ // coming out of suspend with previously cached idle connections.
+ tooOld := !oldTime.IsZero() && pconn.idleAt.Round(0).Before(oldTime)
+ if tooOld {
+ // Async cleanup. Launch in its own goroutine (as if a
+ // time.AfterFunc called it); it acquires idleMu, which we're
+ // holding, and does a synchronous net.Conn.Close.
+ go pconn.closeConnIfStillIdle()
+ }
+ if pconn.isBroken() || tooOld {
+ // If either persistConn.readLoop has marked the connection
+ // broken, but Transport.removeIdleConn has not yet removed it
+ // from the idle list, or if this persistConn is too old (it was
+ // idle too long), then ignore it and look for another. In both
+ // cases it's already in the process of being closed.
+ list = list[:len(list)-1]
+ continue
+ }
+ delivered = w.tryDeliver(pconn, nil)
+ if delivered {
+ if pconn.alt != nil {
+ // HTTP/2: multiple clients can share pconn.
+ // Leave it in the list.
+ } else {
+ // HTTP/1: only one client can use pconn.
+ // Remove it from the list.
+ t.idleLRU.remove(pconn)
+ list = list[:len(list)-1]
+ }
+ }
+ stop = true
+ }
+ if len(list) > 0 {
+ t.idleConn[w.key] = list
+ } else {
+ delete(t.idleConn, w.key)
+ }
+ if stop {
+ return delivered
+ }
+ }
+
+ // Register to receive next connection that becomes idle.
+ if t.idleConnWait == nil {
+ t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
+ }
+ q := t.idleConnWait[w.key]
+ q.cleanFront()
+ q.pushBack(w)
+ t.idleConnWait[w.key] = q
+ return false
+}
+
+// removeIdleConn marks pconn as dead.
+func (t *Transport) removeIdleConn(pconn *persistConn) bool {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return t.removeIdleConnLocked(pconn)
+}
+
+// t.idleMu must be held.
+func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool {
+ if pconn.idleTimer != nil {
+ pconn.idleTimer.Stop()
+ }
+ t.idleLRU.remove(pconn)
+ key := pconn.cacheKey
+ pconns := t.idleConn[key]
+ var removed bool
+ switch len(pconns) {
+ case 0:
+ // Nothing
+ case 1:
+ if pconns[0] == pconn {
+ delete(t.idleConn, key)
+ removed = true
+ }
+ default:
+ for i, v := range pconns {
+ if v != pconn {
+ continue
+ }
+ // Slide down, keeping most recently-used
+ // conns at the end.
+ copy(pconns[i:], pconns[i+1:])
+ t.idleConn[key] = pconns[:len(pconns)-1]
+ removed = true
+ break
+ }
+ }
+ return removed
+}
+
+func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ if t.reqCanceler == nil {
+ t.reqCanceler = make(map[cancelKey]func(error))
+ }
+ if fn != nil {
+ t.reqCanceler[key] = fn
+ } else {
+ delete(t.reqCanceler, key)
+ }
+}
+
+// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
+// for the request, we don't set the function and return false.
+// Since CancelRequest will clear the canceler, we can use the return value to detect if
+// the request was canceled since the last setReqCanceler call.
+func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool {
+ t.reqMu.Lock()
+ defer t.reqMu.Unlock()
+ _, ok := t.reqCanceler[key]
+ if !ok {
+ return false
+ }
+ if fn != nil {
+ t.reqCanceler[key] = fn
+ } else {
+ delete(t.reqCanceler, key)
+ }
+ return true
+}
+
+var zeroDialer net.Dialer
+
+func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
+ if t.DialContext != nil {
+ c, err := t.DialContext(ctx, network, addr)
+ if c == nil && err == nil {
+ err = errors.New("net/http: Transport.DialContext hook returned (nil, nil)")
+ }
+ return c, err
+ }
+ if t.Dial != nil {
+ c, err := t.Dial(network, addr)
+ if c == nil && err == nil {
+ err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
+ }
+ return c, err
+ }
+ return zeroDialer.DialContext(ctx, network, addr)
+}
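+
+// Illustrative sketch (not part of this change): dial prefers the DialContext
+// hook, then the legacy Dial hook, then a zero net.Dialer. A caller could
+// install a custom dialer like this (timeout values are arbitrary):
+//
+//	tr := &Transport{
+//		DialContext: (&net.Dialer{
+//			Timeout:   30 * time.Second,
+//			KeepAlive: 30 * time.Second,
+//		}).DialContext,
+//	}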
+
+// A wantConn records state about a wanted connection
+// (that is, an active call to getConn).
+// The conn may be gotten by dialing or by finding an idle connection,
+// or a cancellation may make the conn no longer wanted.
+// These three options are racing against each other and use
+// wantConn to coordinate and agree about the winning outcome.
+type wantConn struct {
+ cm connectMethod
+ key connectMethodKey // cm.key()
+ ctx context.Context // context for dial
+ ready chan struct{} // closed when pc, err pair is delivered
+
+ // Hooks for testing, to know when dials are done.
+ // beforeDial is called in the getConn goroutine when the dial is queued.
+ // afterDial is called when the dial is completed or canceled.
+ beforeDial func()
+ afterDial func()
+
+ mu sync.Mutex // protects pc, err, close(ready)
+ pc *persistConn
+ err error
+}
+
+// waiting reports whether w is still waiting for an answer (connection or error).
+func (w *wantConn) waiting() bool {
+ select {
+ case <-w.ready:
+ return false
+ default:
+ return true
+ }
+}
+
+// tryDeliver attempts to deliver pc, err to w and reports whether it succeeded.
+func (w *wantConn) tryDeliver(pc *persistConn, err error) bool {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.pc != nil || w.err != nil {
+ return false
+ }
+
+ w.pc = pc
+ w.err = err
+ if w.pc == nil && w.err == nil {
+ panic("net/http: internal error: misuse of tryDeliver")
+ }
+ close(w.ready)
+ return true
+}
+
+// cancel marks w as no longer wanting a result (for example, due to cancellation).
+// If a connection has been delivered already, cancel returns it with t.putOrCloseIdleConn.
+func (w *wantConn) cancel(t *Transport, err error) {
+ w.mu.Lock()
+ if w.pc == nil && w.err == nil {
+ close(w.ready) // catch misbehavior in future delivery
+ }
+ pc := w.pc
+ w.pc = nil
+ w.err = err
+ w.mu.Unlock()
+
+ if pc != nil {
+ t.putOrCloseIdleConn(pc)
+ }
+}
+
+// A wantConnQueue is a queue of wantConns.
+type wantConnQueue struct {
+ // This is a queue, not a deque.
+ // It is split into two stages - head[headPos:] and tail.
+ // popFront is trivial (headPos++) on the first stage, and
+ // pushBack is trivial (append) on the second stage.
+ // If the first stage is empty, popFront can swap the
+ // first and second stages to remedy the situation.
+ //
+ // This two-stage split is analogous to the use of two lists
+ // in Okasaki's purely functional queue but without the
+ // overhead of reversing the list when swapping stages.
+ head []*wantConn
+ headPos int
+ tail []*wantConn
+}
+
+// len returns the number of items in the queue.
+func (q *wantConnQueue) len() int {
+ return len(q.head) - q.headPos + len(q.tail)
+}
+
+// pushBack adds w to the back of the queue.
+func (q *wantConnQueue) pushBack(w *wantConn) {
+ q.tail = append(q.tail, w)
+}
+
+// popFront removes and returns the wantConn at the front of the queue.
+func (q *wantConnQueue) popFront() *wantConn {
+ if q.headPos >= len(q.head) {
+ if len(q.tail) == 0 {
+ return nil
+ }
+ // Pick up tail as new head, clear tail.
+ q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
+ }
+ w := q.head[q.headPos]
+ q.head[q.headPos] = nil
+ q.headPos++
+ return w
+}
+
+// peekFront returns the wantConn at the front of the queue without removing it.
+func (q *wantConnQueue) peekFront() *wantConn {
+ if q.headPos < len(q.head) {
+ return q.head[q.headPos]
+ }
+ if len(q.tail) > 0 {
+ return q.tail[0]
+ }
+ return nil
+}
+
+// cleanFront pops any wantConns that are no longer waiting from the head of the
+// queue, reporting whether any were popped.
+func (q *wantConnQueue) cleanFront() (cleaned bool) {
+ for {
+ w := q.peekFront()
+ if w == nil || w.waiting() {
+ return cleaned
+ }
+ q.popFront()
+ cleaned = true
+ }
+}
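+
+// Illustrative sketch (not part of this change): the two-stage layout keeps both
+// pushBack and popFront amortized O(1). A minimal usage pattern, with w1 and w2
+// standing in for hypothetical *wantConn values:
+//
+//	var q wantConnQueue
+//	q.pushBack(w1)        // appended to q.tail
+//	q.pushBack(w2)
+//	first := q.popFront() // swaps tail into head and returns w1
+//	_ = first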
+
+func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
+ if t.DialTLSContext != nil {
+ conn, err = t.DialTLSContext(ctx, network, addr)
+ } else {
+ conn, err = t.DialTLS(network, addr)
+ }
+ if conn == nil && err == nil {
+ err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
+ }
+ return
+}
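+
+// Illustrative sketch (not part of this change): customDialTLS is only reached
+// when the caller installed DialTLS or DialTLSContext. A hypothetical example
+// using crypto/tls's Dialer:
+//
+//	tr := &Transport{
+//		DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+//			d := &tls.Dialer{Config: &tls.Config{MinVersion: tls.VersionTLS12}}
+//			return d.DialContext(ctx, network, addr)
+//		},
+//	}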
+
+// getConn dials and creates a new persistConn to the target as
+// specified in the connectMethod. This includes doing a proxy CONNECT
+// and/or setting up TLS. If this doesn't return an error, the persistConn
+// is ready to write requests to.
+func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) {
+ req := treq.Request
+ trace := treq.trace
+ ctx := req.Context()
+ if trace != nil && trace.GetConn != nil {
+ trace.GetConn(cm.addr())
+ }
+
+ w := &wantConn{
+ cm: cm,
+ key: cm.key(),
+ ctx: ctx,
+ ready: make(chan struct{}, 1),
+ beforeDial: testHookPrePendingDial,
+ afterDial: testHookPostPendingDial,
+ }
+ defer func() {
+ if err != nil {
+ w.cancel(t, err)
+ }
+ }()
+
+ // Queue for idle connection.
+ if delivered := t.queueForIdleConn(w); delivered {
+ pc := w.pc
+ // Trace only for HTTP/1.
+ // HTTP/2 calls trace.GotConn itself.
+ if pc.alt == nil && trace != nil && trace.GotConn != nil {
+ trace.GotConn(pc.gotIdleConnTrace(pc.idleAt))
+ }
+ // set request canceler to some non-nil function so we
+ // can detect whether it was cleared between now and when
+ // we enter roundTrip
+ t.setReqCanceler(treq.cancelKey, func(error) {})
+ return pc, nil
+ }
+
+ cancelc := make(chan error, 1)
+ t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })
+
+ // Queue for permission to dial.
+ t.queueForDial(w)
+
+ // Wait for completion or cancellation.
+ select {
+ case <-w.ready:
+ // Trace success but only for HTTP/1.
+ // HTTP/2 calls trace.GotConn itself.
+ if w.pc != nil && w.pc.alt == nil && trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{Conn: w.pc.conn, Reused: w.pc.isReused()})
+ }
+ if w.err != nil {
+ // If the request has been canceled, that's probably
+ // what caused w.err; if so, prefer to return the
+ // cancellation error (see golang.org/issue/16049).
+ select {
+ case <-req.Cancel:
+ return nil, errRequestCanceledConn
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+ case err := <-cancelc:
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
+ default:
+ // return below
+ }
+ }
+ return w.pc, w.err
+ case <-req.Cancel:
+ return nil, errRequestCanceledConn
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+ case err := <-cancelc:
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
+ }
+}
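+
+// Illustrative sketch (not part of this change): the GetConn/GotConn hooks fired
+// in getConn can be observed with net/http/httptrace. A hypothetical example:
+//
+//	trace := &httptrace.ClientTrace{
+//		GetConn: func(hostPort string) { log.Println("getting conn for", hostPort) },
+//		GotConn: func(info httptrace.GotConnInfo) { log.Println("reused:", info.Reused) },
+//	}
+//	req, _ := NewRequest("GET", "https://example.com/", nil)
+//	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))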
+
+// queueForDial queues w to wait for permission to begin dialing.
+// Once w receives permission to dial, it will do so in a separate goroutine.
+func (t *Transport) queueForDial(w *wantConn) {
+ w.beforeDial()
+ if t.MaxConnsPerHost <= 0 {
+ go t.dialConnFor(w)
+ return
+ }
+
+ t.connsPerHostMu.Lock()
+ defer t.connsPerHostMu.Unlock()
+
+ if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost {
+ if t.connsPerHost == nil {
+ t.connsPerHost = make(map[connectMethodKey]int)
+ }
+ t.connsPerHost[w.key] = n + 1
+ go t.dialConnFor(w)
+ return
+ }
+
+ if t.connsPerHostWait == nil {
+ t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
+ }
+ q := t.connsPerHostWait[w.key]
+ q.cleanFront()
+ q.pushBack(w)
+ t.connsPerHostWait[w.key] = q
+}
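+
+// Illustrative sketch (not part of this change): queueForDial only queues when a
+// per-host limit is configured; otherwise every wantConn dials immediately.
+// Enabling the limit is a one-field change (the value 2 is arbitrary):
+//
+//	tr := &Transport{MaxConnsPerHost: 2}
+//
+// With that setting, a third concurrent request to the same host:port waits in
+// t.connsPerHostWait until decConnsPerHost releases a slot.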
+
+// dialConnFor dials on behalf of w and delivers the result to w.
+// dialConnFor has received permission to dial w.cm and is counted in t.connsPerHost[w.key].
+// If the dial is canceled or unsuccessful, dialConnFor decrements t.connsPerHost[w.key].
+func (t *Transport) dialConnFor(w *wantConn) {
+ defer w.afterDial()
+
+ pc, err := t.dialConn(w.ctx, w.cm)
+ delivered := w.tryDeliver(pc, err)
+ if err == nil && (!delivered || pc.alt != nil) {
+ // pconn was not passed to w,
+ // or it is HTTP/2 and can be shared.
+ // Add to the idle connection pool.
+ t.putOrCloseIdleConn(pc)
+ }
+ if err != nil {
+ t.decConnsPerHost(w.key)
+ }
+}
+
+// decConnsPerHost decrements the per-host connection count for key,
+// which may in turn give a different waiting goroutine permission to dial.
+func (t *Transport) decConnsPerHost(key connectMethodKey) {
+ if t.MaxConnsPerHost <= 0 {
+ return
+ }
+
+ t.connsPerHostMu.Lock()
+ defer t.connsPerHostMu.Unlock()
+ n := t.connsPerHost[key]
+ if n == 0 {
+ // Shouldn't happen, but if it does, the counting is buggy and could
+ // easily lead to a silent deadlock, so report the problem loudly.
+ panic("net/http: internal error: connCount underflow")
+ }
+
+ // Can we hand this count to a goroutine still waiting to dial?
+ // (Some goroutines on the wait list may have timed out or
+ // gotten a connection another way. If they're all gone,
+ // we don't want to kick off any spurious dial operations.)
+ if q := t.connsPerHostWait[key]; q.len() > 0 {
+ done := false
+ for q.len() > 0 {
+ w := q.popFront()
+ if w.waiting() {
+ go t.dialConnFor(w)
+ done = true
+ break
+ }
+ }
+ if q.len() == 0 {
+ delete(t.connsPerHostWait, key)
+ } else {
+ // q is a value (like a slice), so we have to store
+ // the updated q back into the map.
+ t.connsPerHostWait[key] = q
+ }
+ if done {
+ return
+ }
+ }
+
+ // Otherwise, decrement the recorded count.
+ if n--; n == 0 {
+ delete(t.connsPerHost, key)
+ } else {
+ t.connsPerHost[key] = n
+ }
+}
+
+// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
+// tunnel, this function establishes a nested TLS session inside the encrypted channel.
+// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
+func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error {
+ // Initiate TLS and check remote host name against certificate.
+ cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
+ if cfg.ServerName == "" {
+ cfg.ServerName = name
+ }
+ if pconn.cacheKey.onlyH1 {
+ cfg.NextProtos = nil
+ }
+ plainConn := pconn.conn
+ tlsConn := tls.Client(plainConn, cfg)
+ errc := make(chan error, 2)
+ var timer *time.Timer // for canceling TLS handshake
+ if d := pconn.t.TLSHandshakeTimeout; d != 0 {
+ timer = time.AfterFunc(d, func() {
+ errc <- tlsHandshakeTimeoutError{}
+ })
+ }
+ go func() {
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := tlsConn.HandshakeContext(ctx)
+ if timer != nil {
+ timer.Stop()
+ }
+ errc <- err
+ }()
+ if err := <-errc; err != nil {
+ plainConn.Close()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+ }
+ return err
+ }
+ cs := tlsConn.ConnectionState()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(cs, nil)
+ }
+ pconn.tlsState = &cs
+ pconn.conn = tlsConn
+ return nil
+}
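+
+// Illustrative sketch (not part of this change): the knobs consulted by addTLS
+// are plain Transport fields; the values below are hypothetical examples.
+//
+//	tr := &Transport{
+//		TLSHandshakeTimeout: 10 * time.Second,
+//		TLSClientConfig:     &tls.Config{ServerName: "internal.example"},
+//	}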
+
+type erringRoundTripper interface {
+ RoundTripErr() error
+}
+
+func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
+ pconn = &persistConn{
+ t: t,
+ cacheKey: cm.key(),
+ reqch: make(chan requestAndChan, 1),
+ writech: make(chan writeRequest, 1),
+ closech: make(chan struct{}),
+ writeErrCh: make(chan error, 1),
+ writeLoopDone: make(chan struct{}),
+ }
+ trace := httptrace.ContextClientTrace(ctx)
+ wrapErr := func(err error) error {
+ if cm.proxyURL != nil {
+ // Return a typed error, per Issue 16997
+ return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
+ }
+ return err
+ }
+ if cm.scheme() == "https" && t.hasCustomTLSDialer() {
+ var err error
+ pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
+ if err != nil {
+ return nil, wrapErr(err)
+ }
+ if tc, ok := pconn.conn.(*tls.Conn); ok {
+ // Handshake here, in case DialTLS didn't. TLSNextProto below
+ // depends on it for knowing the connection state.
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ if err := tc.HandshakeContext(ctx); err != nil {
+ go pconn.conn.Close()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+ }
+ return nil, err
+ }
+ cs := tc.ConnectionState()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(cs, nil)
+ }
+ pconn.tlsState = &cs
+ }
+ } else {
+ conn, err := t.dial(ctx, "tcp", cm.addr())
+ if err != nil {
+ return nil, wrapErr(err)
+ }
+ pconn.conn = conn
+ if cm.scheme() == "https" {
+ var firstTLSHost string
+ if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
+ return nil, wrapErr(err)
+ }
+ if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
+ return nil, wrapErr(err)
+ }
+ }
+ }
+
+ // Proxy setup.
+ switch {
+ case cm.proxyURL == nil:
+ // Do nothing. Not using a proxy.
+ case cm.proxyURL.Scheme == "socks5":
+ conn := pconn.conn
+ d := socksNewDialer("tcp", conn.RemoteAddr().String())
+ if u := cm.proxyURL.User; u != nil {
+ auth := &socksUsernamePassword{
+ Username: u.Username(),
+ }
+ auth.Password, _ = u.Password()
+ d.AuthMethods = []socksAuthMethod{
+ socksAuthMethodNotRequired,
+ socksAuthMethodUsernamePassword,
+ }
+ d.Authenticate = auth.Authenticate
+ }
+ if _, err := d.DialWithConn(ctx, conn, "tcp", cm.targetAddr); err != nil {
+ conn.Close()
+ return nil, err
+ }
+ case cm.targetScheme == "http":
+ pconn.isProxy = true
+ if pa := cm.proxyAuth(); pa != "" {
+ pconn.mutateHeaderFunc = func(h Header) {
+ h.Set("Proxy-Authorization", pa)
+ }
+ }
+ case cm.targetScheme == "https":
+ conn := pconn.conn
+ var hdr Header
+ if t.GetProxyConnectHeader != nil {
+ var err error
+ hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+ } else {
+ hdr = t.ProxyConnectHeader
+ }
+ if hdr == nil {
+ hdr = make(Header)
+ }
+ if pa := cm.proxyAuth(); pa != "" {
+ hdr = hdr.Clone()
+ hdr.Set("Proxy-Authorization", pa)
+ }
+ connectReq := &Request{
+ Method: "CONNECT",
+ URL: &url.URL{Opaque: cm.targetAddr},
+ Host: cm.targetAddr,
+ Header: hdr,
+ }
+
+ // If there's no done channel (no deadline or cancellation
+ // from the caller possible), at least set some (long)
+ // timeout here. This will make sure we don't block forever
+ // and leak a goroutine if the connection stops replying
+ // after the TCP connect.
+ connectCtx := ctx
+ if ctx.Done() == nil {
+ newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
+ defer cancel()
+ connectCtx = newCtx
+ }
+
+ didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
+ var (
+ resp *Response
+ err error // write or read error
+ )
+ // Write the CONNECT request & read the response.
+ go func() {
+ defer close(didReadResponse)
+ err = connectReq.Write(conn)
+ if err != nil {
+ return
+ }
+ // Okay to use and discard buffered reader here, because
+ // TLS server will not speak until spoken to.
+ br := bufio.NewReader(conn)
+ resp, err = ReadResponse(br, connectReq)
+ }()
+ select {
+ case <-connectCtx.Done():
+ conn.Close()
+ <-didReadResponse
+ return nil, connectCtx.Err()
+ case <-didReadResponse:
+ // resp or err now set
+ }
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+
+ if t.OnProxyConnectResponse != nil {
+ err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if resp.StatusCode != 200 {
+ _, text, ok := strings.Cut(resp.Status, " ")
+ conn.Close()
+ if !ok {
+ return nil, errors.New("unknown status code")
+ }
+ return nil, errors.New(text)
+ }
+ }
+
+ if cm.proxyURL != nil && cm.targetScheme == "https" {
+ if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil {
+ return nil, err
+ }
+ }
+
+ if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
+ if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
+ alt := next(cm.targetAddr, pconn.conn.(*tls.Conn))
+ if e, ok := alt.(erringRoundTripper); ok {
+ // pconn.conn was closed by next (http2configureTransports.upgradeFn).
+ return nil, e.RoundTripErr()
+ }
+ return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil
+ }
+ }
+
+ pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize())
+ pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize())
+
+ go pconn.readLoop()
+ go pconn.writeLoop()
+ return pconn, nil
+}
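+
+// Illustrative sketch (not part of this change): the https-over-proxy branch in
+// dialConn sends a CONNECT request whose extra headers come from
+// ProxyConnectHeader (or GetProxyConnectHeader). A hypothetical configuration:
+//
+//	proxyURL, _ := url.Parse("http://proxy.example:3128")
+//	tr := &Transport{
+//		Proxy:              ProxyURL(proxyURL),
+//		ProxyConnectHeader: Header{"X-Team": {"infra"}},
+//	}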
+
+// persistConnWriter is the io.Writer written to by pc.bw.
+// It accumulates the number of bytes written to the underlying conn,
+// so the retry logic can determine whether any bytes made it across
+// the wire.
+// This is exactly 1 pointer field wide so it can go into an interface
+// without allocation.
+type persistConnWriter struct {
+ pc *persistConn
+}
+
+func (w persistConnWriter) Write(p []byte) (n int, err error) {
+ n, err = w.pc.conn.Write(p)
+ w.pc.nwrite += int64(n)
+ return
+}
+
+// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if
+// the Conn implements io.ReaderFrom, it can take advantage of optimizations
+// such as sendfile.
+func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) {
+ n, err = io.Copy(w.pc.conn, r)
+ w.pc.nwrite += n
+ return
+}
+
+var _ io.ReaderFrom = (*persistConnWriter)(nil)
+
+// connectMethod is the map key (in its String form) for keeping persistent
+// TCP connections alive for subsequent HTTP requests.
+//
+// A connect method may be of the following types:
+//
+// connectMethod.key().String() Description
+// ------------------------------ -------------------------
+// |http|foo.com http directly to server, no proxy
+// |https|foo.com https directly to server, no proxy
+// |https,h1|foo.com https directly to server w/o HTTP/2, no proxy
+// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com
+// http://proxy.com|http http to proxy, http to anywhere after that
+// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
+// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com
+// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com
+// https://proxy.com|http https to proxy, http to anywhere after that
+type connectMethod struct {
+ _ incomparable
+ proxyURL *url.URL // nil for no proxy, else full proxy URL
+ targetScheme string // "http" or "https"
+ // If proxyURL specifies an http or https proxy, and targetScheme is http (not https),
+ // then targetAddr is not included in the connect method key, because the socket can
+ // be reused for different targetAddr values.
+ targetAddr string
+ onlyH1 bool // whether to disable HTTP/2 and force HTTP/1
+}
+
+func (cm *connectMethod) key() connectMethodKey {
+ proxyStr := ""
+ targetAddr := cm.targetAddr
+ if cm.proxyURL != nil {
+ proxyStr = cm.proxyURL.String()
+ if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" {
+ targetAddr = ""
+ }
+ }
+ return connectMethodKey{
+ proxy: proxyStr,
+ scheme: cm.targetScheme,
+ addr: targetAddr,
+ onlyH1: cm.onlyH1,
+ }
+}
+
+// scheme returns the first hop scheme: http, https, or socks5
+func (cm *connectMethod) scheme() string {
+ if cm.proxyURL != nil {
+ return cm.proxyURL.Scheme
+ }
+ return cm.targetScheme
+}
+
+// addr returns the first hop "host:port" to which we need to TCP connect.
+func (cm *connectMethod) addr() string {
+ if cm.proxyURL != nil {
+ return canonicalAddr(cm.proxyURL)
+ }
+ return cm.targetAddr
+}
+
+// tlsHost returns the host name to match against the peer's
+// TLS certificate.
+func (cm *connectMethod) tlsHost() string {
+ h := cm.targetAddr
+ if hasPort(h) {
+ h = h[:strings.LastIndex(h, ":")]
+ }
+ return h
+}
+
+// connectMethodKey is the map key version of connectMethod, with a
+// stringified proxy URL (or the empty string) instead of a pointer to
+// a URL.
+type connectMethodKey struct {
+ proxy, scheme, addr string
+ onlyH1 bool
+}
+
+func (k connectMethodKey) String() string {
+ // Only used by tests.
+ var h1 string
+ if k.onlyH1 {
+ h1 = ",h1"
+ }
+ return fmt.Sprintf("%s|%s%s|%s", k.proxy, k.scheme, h1, k.addr)
+}
+
+// persistConn wraps a connection, usually a persistent one
+// (but may be used for non-keep-alive requests as well)
+type persistConn struct {
+ // alt optionally specifies the TLS NextProto RoundTripper.
+ // This is used for HTTP/2 today and future protocols later.
+ // If it's non-nil, the rest of the fields are unused.
+ alt RoundTripper
+
+ t *Transport
+ cacheKey connectMethodKey
+ conn net.Conn
+ tlsState *tls.ConnectionState
+ br *bufio.Reader // from conn
+ bw *bufio.Writer // to conn
+ nwrite int64 // bytes written
+ reqch chan requestAndChan // written by roundTrip; read by readLoop
+ writech chan writeRequest // written by roundTrip; read by writeLoop
+ closech chan struct{} // closed when conn closed
+ isProxy bool
+ sawEOF bool // whether we've seen EOF from conn; owned by readLoop
+ readLimit int64 // bytes allowed to be read; owned by readLoop
+ // writeErrCh passes the request write error (usually nil)
+ // from the writeLoop goroutine to the readLoop which passes
+ // it off to the res.Body reader, which then uses it to decide
+ // whether or not a connection can be reused. Issue 7569.
+ writeErrCh chan error
+
+ writeLoopDone chan struct{} // closed when write loop ends
+
+ // Both guarded by Transport.idleMu:
+ idleAt time.Time // time it last became idle
+ idleTimer *time.Timer // holding an AfterFunc to close it
+
+ mu sync.Mutex // guards following fields
+ numExpectedResponses int
+ closed error // set non-nil when conn is closed, before closech is closed
+ canceledErr error // set non-nil if conn is canceled
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
+ reused bool // whether conn has had successful request/response and is being reused.
+ // mutateHeaderFunc is an optional func to modify extra
+ // headers on each outbound request before it's written. (the
+ // original Request given to RoundTrip is not modified)
+ mutateHeaderFunc func(Header)
+}
+
+func (pc *persistConn) maxHeaderResponseSize() int64 {
+ if v := pc.t.MaxResponseHeaderBytes; v != 0 {
+ return v
+ }
+ return 10 << 20 // conservative default; same as http2
+}
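+
+// Illustrative sketch (not part of this change): the 10 MB default above can be
+// overridden per Transport; the value below is an arbitrary example.
+//
+//	tr := &Transport{MaxResponseHeaderBytes: 1 << 20} // allow 1 MB of response headers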
+
+func (pc *persistConn) Read(p []byte) (n int, err error) {
+ if pc.readLimit <= 0 {
+ return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
+ }
+ if int64(len(p)) > pc.readLimit {
+ p = p[:pc.readLimit]
+ }
+ n, err = pc.conn.Read(p)
+ if err == io.EOF {
+ pc.sawEOF = true
+ }
+ pc.readLimit -= int64(n)
+ return
+}
+
+// isBroken reports whether this connection is in a known broken state.
+func (pc *persistConn) isBroken() bool {
+ pc.mu.Lock()
+ b := pc.closed != nil
+ pc.mu.Unlock()
+ return b
+}
+
+// canceled returns non-nil if the connection was closed due to
+// CancelRequest or due to context cancellation.
+func (pc *persistConn) canceled() error {
+ pc.mu.Lock()
+ defer pc.mu.Unlock()
+ return pc.canceledErr
+}
+
+// isReused reports whether this connection has been used before.
+func (pc *persistConn) isReused() bool {
+ pc.mu.Lock()
+ r := pc.reused
+ pc.mu.Unlock()
+ return r
+}
+
+func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) {
+ pc.mu.Lock()
+ defer pc.mu.Unlock()
+ t.Reused = pc.reused
+ t.Conn = pc.conn
+ t.WasIdle = true
+ if !idleAt.IsZero() {
+ t.IdleTime = time.Since(idleAt)
+ }
+ return
+}
+
+func (pc *persistConn) cancelRequest(err error) {
+ pc.mu.Lock()
+ defer pc.mu.Unlock()
+ pc.canceledErr = err
+ pc.closeLocked(errRequestCanceled)
+}
+
+// closeConnIfStillIdle closes the connection if it's still sitting idle.
+// This is what's called by the persistConn's idleTimer, and is run in its
+// own goroutine.
+func (pc *persistConn) closeConnIfStillIdle() {
+ t := pc.t
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ if _, ok := t.idleLRU.m[pc]; !ok {
+ // Not idle.
+ return
+ }
+ t.removeIdleConnLocked(pc)
+ pc.close(errIdleConnTimeout)
+}
+
+// mapRoundTripError returns the appropriate error value for
+// persistConn.roundTrip.
+//
+// The provided err is the first error that (*persistConn).roundTrip
+// happened to receive from its select statement.
+//
+// The startBytesWritten value should be the value of pc.nwrite before the roundTrip
+// started writing the request.
+func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error {
+ if err == nil {
+ return nil
+ }
+
+ // Wait for the writeLoop goroutine to terminate to avoid data
+ // races on callers who mutate the request on failure.
+ //
+ // When resc in pc.roundTrip and hence rc.ch receives a responseAndError
+ // with a non-nil error it implies that the persistConn is either closed
+ // or closing. Waiting on pc.writeLoopDone is hence safe as all callers
+ // close closech which in turn ensures writeLoop returns.
+ <-pc.writeLoopDone
+
+ // If the request was canceled, that's better than network
+ // failures that were likely the result of tearing down the
+ // connection.
+ if cerr := pc.canceled(); cerr != nil {
+ return cerr
+ }
+
+ // See if an error was set explicitly.
+ req.mu.Lock()
+ reqErr := req.err
+ req.mu.Unlock()
+ if reqErr != nil {
+ return reqErr
+ }
+
+ if err == errServerClosedIdle {
+ // Don't decorate
+ return err
+ }
+
+ if _, ok := err.(transportReadFromServerError); ok {
+ if pc.nwrite == startBytesWritten {
+ return nothingWrittenError{err}
+ }
+ // Don't decorate
+ return err
+ }
+ if pc.isBroken() {
+ if pc.nwrite == startBytesWritten {
+ return nothingWrittenError{err}
+ }
+ return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %w", err)
+ }
+ return err
+}
+
+// errCallerOwnsConn is an internal sentinel error used when we hand
+// off a writable response.Body to the caller. We use this to prevent
+// closing a net.Conn that is now owned by the caller.
+var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")
+
+func (pc *persistConn) readLoop() {
+ closeErr := errReadLoopExiting // default value, if not changed below
+ defer func() {
+ pc.close(closeErr)
+ pc.t.removeIdleConn(pc)
+ }()
+
+ tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
+ if err := pc.t.tryPutIdleConn(pc); err != nil {
+ closeErr = err
+ if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
+ trace.PutIdleConn(err)
+ }
+ return false
+ }
+ if trace != nil && trace.PutIdleConn != nil {
+ trace.PutIdleConn(nil)
+ }
+ return true
+ }
+
+ // eofc is used to block caller goroutines reading from Response.Body
+ // at EOF until this goroutine has (potentially) added the connection
+ // back to the idle pool.
+ eofc := make(chan struct{})
+ defer close(eofc) // unblock reader on errors
+
+ // Read this once, before the loop starts, to avoid races in tests.
+ testHookMu.Lock()
+ testHookReadLoopBeforeNextRead := testHookReadLoopBeforeNextRead
+ testHookMu.Unlock()
+
+ alive := true
+ for alive {
+ pc.readLimit = pc.maxHeaderResponseSize()
+ _, err := pc.br.Peek(1)
+
+ pc.mu.Lock()
+ if pc.numExpectedResponses == 0 {
+ pc.readLoopPeekFailLocked(err)
+ pc.mu.Unlock()
+ return
+ }
+ pc.mu.Unlock()
+
+ rc := <-pc.reqch
+ trace := httptrace.ContextClientTrace(rc.req.Context())
+
+ var resp *Response
+ if err == nil {
+ resp, err = pc.readResponse(rc, trace)
+ } else {
+ err = transportReadFromServerError{err}
+ closeErr = err
+ }
+
+ if err != nil {
+ if pc.readLimit <= 0 {
+ err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
+ }
+
+ select {
+ case rc.ch <- responseAndError{err: err}:
+ case <-rc.callerGone:
+ return
+ }
+ return
+ }
+ pc.readLimit = maxInt64 // effectively no limit for response bodies
+
+ pc.mu.Lock()
+ pc.numExpectedResponses--
+ pc.mu.Unlock()
+
+ bodyWritable := resp.bodyIsWritable()
+ hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
+
+ if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
+ // Don't do keep-alive on error if either party requested a close
+ // or we get an unexpected informational (1xx) response.
+ // StatusCode 100 is already handled above.
+ alive = false
+ }
+
+ if !hasBody || bodyWritable {
+ replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil)
+
+ // Put the idle conn back into the pool before we send the response
+ // so if they process it quickly and make another request, they'll
+ // get this same conn. But we use the unbuffered channel rc.ch
+ // to guarantee that persistConn.roundTrip got out of its select
+ // potentially waiting for this persistConn to close.
+ alive = alive &&
+ !pc.sawEOF &&
+ pc.wroteRequest() &&
+ replaced && tryPutIdleConn(trace)
+
+ if bodyWritable {
+ closeErr = errCallerOwnsConn
+ }
+
+ select {
+ case rc.ch <- responseAndError{res: resp}:
+ case <-rc.callerGone:
+ return
+ }
+
+ // Now that they've read from the unbuffered channel, they're safely
+ // out of the select that also waits on this goroutine to die, so
+ // we're allowed to exit now if needed (if alive is false)
+ testHookReadLoopBeforeNextRead()
+ continue
+ }
+
+ waitForBodyRead := make(chan bool, 2)
+ body := &bodyEOFSignal{
+ body: resp.Body,
+ earlyCloseFn: func() error {
+ waitForBodyRead <- false
+ <-eofc // will be closed by deferred call at the end of the function
+ return nil
+
+ },
+ fn: func(err error) error {
+ isEOF := err == io.EOF
+ waitForBodyRead <- isEOF
+ if isEOF {
+ <-eofc // see comment above eofc declaration
+ } else if err != nil {
+ if cerr := pc.canceled(); cerr != nil {
+ return cerr
+ }
+ }
+ return err
+ },
+ }
+
+ resp.Body = body
+ if rc.addedGzip && ascii.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
+ resp.Body = &gzipReader{body: body}
+ resp.Header.Del("Content-Encoding")
+ resp.Header.Del("Content-Length")
+ resp.ContentLength = -1
+ resp.Uncompressed = true
+ }
+
+ select {
+ case rc.ch <- responseAndError{res: resp}:
+ case <-rc.callerGone:
+ return
+ }
+
+ // Before looping back to peek on the bufio.Reader again, wait
+ // for the caller goroutine to finish reading the response body
+ // (or for cancellation or connection closure).
+ select {
+ case bodyEOF := <-waitForBodyRead:
+ replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool
+ alive = alive &&
+ bodyEOF &&
+ !pc.sawEOF &&
+ pc.wroteRequest() &&
+ replaced && tryPutIdleConn(trace)
+ if bodyEOF {
+ eofc <- struct{}{}
+ }
+ case <-rc.req.Cancel:
+ alive = false
+ pc.t.CancelRequest(rc.req)
+ case <-rc.req.Context().Done():
+ alive = false
+ pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err())
+ case <-pc.closech:
+ alive = false
+ }
+
+ testHookReadLoopBeforeNextRead()
+ }
+}
+
+func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
+ if pc.closed != nil {
+ return
+ }
+ if n := pc.br.Buffered(); n > 0 {
+ buf, _ := pc.br.Peek(n)
+ if is408Message(buf) {
+ pc.closeLocked(errServerClosedIdle)
+ return
+ } else {
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr)
+ }
+ }
+ if peekErr == io.EOF {
+ // common case.
+ pc.closeLocked(errServerClosedIdle)
+ } else {
+ pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %w", peekErr))
+ }
+}
+
+// is408Message reports whether buf has the prefix of an
+// HTTP 408 Request Timeout response.
+// See golang.org/issue/32310.
+func is408Message(buf []byte) bool {
+ if len(buf) < len("HTTP/1.x 408") {
+ return false
+ }
+ if string(buf[:7]) != "HTTP/1." {
+ return false
+ }
+ return string(buf[8:12]) == " 408"
+}
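+
+// For example (illustrative):
+//
+//	is408Message([]byte("HTTP/1.1 408 Request Timeout\r\n")) // true
+//	is408Message([]byte("HTTP/1.1 200 OK\r\n"))              // false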
+
+// readResponse reads an HTTP response (or two, in the case of "Expect:
+// 100-continue") from the server. It returns the final non-100 one.
+// trace is optional.
+func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTrace) (resp *Response, err error) {
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ if peek, err := pc.br.Peek(1); err == nil && len(peek) == 1 {
+ trace.GotFirstResponseByte()
+ }
+ }
+ num1xx := 0 // number of informational 1xx headers received
+ const max1xxResponses = 5 // arbitrary bound on number of informational responses
+
+ continueCh := rc.continueCh
+ for {
+ resp, err = ReadResponse(pc.br, rc.req)
+ if err != nil {
+ return
+ }
+ resCode := resp.StatusCode
+ if continueCh != nil {
+ if resCode == 100 {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
+ }
+ continueCh <- struct{}{}
+ continueCh = nil
+ } else if resCode >= 200 {
+ close(continueCh)
+ continueCh = nil
+ }
+ }
+ is1xx := 100 <= resCode && resCode <= 199
+ // treat 101 as a terminal status, see issue 26161
+ is1xxNonTerminal := is1xx && resCode != StatusSwitchingProtocols
+ if is1xxNonTerminal {
+ num1xx++
+ if num1xx > max1xxResponses {
+ return nil, errors.New("net/http: too many 1xx informational responses")
+ }
+ pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
+ if trace != nil && trace.Got1xxResponse != nil {
+ if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil {
+ return nil, err
+ }
+ }
+ continue
+ }
+ break
+ }
+ if resp.isProtocolSwitch() {
+ resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
+ }
+
+ resp.TLS = pc.tlsState
+ return
+}
+
+// waitForContinue returns a function that blocks until any response,
+// timeout, or connection close occurs. The returned function reports
+// whether the request body should be sent.
+func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
+ if continueCh == nil {
+ return nil
+ }
+ return func() bool {
+ timer := time.NewTimer(pc.t.ExpectContinueTimeout)
+ defer timer.Stop()
+
+ select {
+ case _, ok := <-continueCh:
+ return ok
+ case <-timer.C:
+ return true
+ case <-pc.closech:
+ return false
+ }
+ }
+}
+
+func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
+ body := &readWriteCloserBody{ReadWriteCloser: rwc}
+ if br.Buffered() != 0 {
+ body.br = br
+ }
+ return body
+}
+
+// readWriteCloserBody is the Response.Body type used when we want to
+// give users write access to the Body through the underlying
+// connection (TCP, unless using custom dialers). This is then
+// the concrete type for a Response.Body on the 101 Switching
+// Protocols response, as used by WebSockets, h2c, etc.
+type readWriteCloserBody struct {
+ _ incomparable
+ br *bufio.Reader // used until empty
+ io.ReadWriteCloser
+}
+
+func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
+ if b.br != nil {
+ if n := b.br.Buffered(); len(p) > n {
+ p = p[:n]
+ }
+ n, err = b.br.Read(p)
+ if b.br.Buffered() == 0 {
+ b.br = nil
+ }
+ return n, err
+ }
+ return b.ReadWriteCloser.Read(p)
+}
+
+// nothingWrittenError wraps a write error that ended up writing zero bytes.
+type nothingWrittenError struct {
+ error
+}
+
+func (nwe nothingWrittenError) Unwrap() error {
+ return nwe.error
+}
+
+func (pc *persistConn) writeLoop() {
+ defer close(pc.writeLoopDone)
+ for {
+ select {
+ case wr := <-pc.writech:
+ startBytesWritten := pc.nwrite
+ err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
+ if bre, ok := err.(requestBodyReadError); ok {
+ err = bre.error
+ // Errors reading from the user's
+ // Request.Body are high priority.
+ // Set it here before sending on the
+ // channels below or calling
+ // pc.close() which tears down
+ // connections and causes other
+ // errors.
+ wr.req.setError(err)
+ }
+ if err == nil {
+ err = pc.bw.Flush()
+ }
+ if err != nil {
+ if pc.nwrite == startBytesWritten {
+ err = nothingWrittenError{err}
+ }
+ }
+ pc.writeErrCh <- err // to the body reader, which might recycle us
+ wr.ch <- err // to the roundTrip function
+ if err != nil {
+ pc.close(err)
+ return
+ }
+ case <-pc.closech:
+ return
+ }
+ }
+}
+
+// maxWriteWaitBeforeConnReuse is how long a Transport RoundTrip
+// will wait to see the Request's Body.Write result after getting a
+// response from the server. See comments in (*persistConn).wroteRequest.
+//
+// In tests, we set this to a large value to avoid flakiness from inconsistent
+// recycling of connections.
+var maxWriteWaitBeforeConnReuse = 50 * time.Millisecond
+
+// wroteRequest is a check before recycling a connection that the previous write
+// (from writeLoop above) happened and was successful.
+func (pc *persistConn) wroteRequest() bool {
+ select {
+ case err := <-pc.writeErrCh:
+ // Common case: the write happened well before the response, so
+ // avoid creating a timer.
+ return err == nil
+ default:
+ // Rare case: the request was written in writeLoop above but
+ // before it could send to pc.writeErrCh, the reader read it
+ // all, processed it, and called us here. In this case, give the
+ // write goroutine a bit of time to finish its send.
+ //
+ // Less rare case: We also get here in the legitimate case of
+ // Issue 7569, where the writer is still writing (or stalled),
+ // but the server has already replied. In this case, we don't
+ // want to wait too long, and we want to return false so this
+ // connection isn't re-used.
+ t := time.NewTimer(maxWriteWaitBeforeConnReuse)
+ defer t.Stop()
+ select {
+ case err := <-pc.writeErrCh:
+ return err == nil
+ case <-t.C:
+ return false
+ }
+ }
+}
+
+// responseAndError is how the goroutine reading from an HTTP/1 server
+// communicates with the goroutine doing the RoundTrip.
+type responseAndError struct {
+ _ incomparable
+ res *Response // else use this response (see res method)
+ err error
+}
+
+type requestAndChan struct {
+ _ incomparable
+ req *Request
+ cancelKey cancelKey
+ ch chan responseAndError // unbuffered; always send in select on callerGone
+
+ // whether the Transport (as opposed to the user client code)
+ // added the Accept-Encoding gzip header. If the Transport
+ // set it, only then do we transparently decode the gzip.
+ addedGzip bool
+
+ // Optional blocking chan for Expect: 100-continue (for send).
+ // If the request has an "Expect: 100-continue" header and
+ // the server responds 100 Continue, readLoop sends a value
+ // to writeLoop via this chan.
+ continueCh chan<- struct{}
+
+ callerGone <-chan struct{} // closed when roundTrip caller has returned
+}
+
+// A writeRequest is sent by the caller's goroutine to the
+// writeLoop's goroutine to write a request while the read loop
+// concurrently waits on both the write response and the server's
+// reply.
+type writeRequest struct {
+ req *transportRequest
+ ch chan<- error
+
+ // Optional blocking chan for Expect: 100-continue (for receive).
+ // If not nil, writeLoop blocks sending request body until
+ // it receives from this chan.
+ continueCh <-chan struct{}
+}
+
+type httpError struct {
+ err string
+ timeout bool
+}
+
+func (e *httpError) Error() string { return e.err }
+func (e *httpError) Timeout() bool { return e.timeout }
+func (e *httpError) Temporary() bool { return true }
+
+var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
+
+// errRequestCanceled is set to be identical to the one from h2 to facilitate
+// testing.
+var errRequestCanceled = http2errRequestCanceled
+var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
+
+func nop() {}
+
+// testHooks. Always non-nil.
+var (
+ testHookEnterRoundTrip = nop
+ testHookWaitResLoop = nop
+ testHookRoundTripRetried = nop
+ testHookPrePendingDial = nop
+ testHookPostPendingDial = nop
+
+ testHookMu sync.Locker = fakeLocker{} // guards following
+ testHookReadLoopBeforeNextRead = nop
+)
+
+func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
+ testHookEnterRoundTrip()
+ if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) {
+ pc.t.putOrCloseIdleConn(pc)
+ return nil, errRequestCanceled
+ }
+ pc.mu.Lock()
+ pc.numExpectedResponses++
+ headerFn := pc.mutateHeaderFunc
+ pc.mu.Unlock()
+
+ if headerFn != nil {
+ headerFn(req.extraHeaders())
+ }
+
+ // Ask for a compressed version if the caller didn't set their
+ // own value for Accept-Encoding. We only attempt to
+ // uncompress the gzip stream if we were the layer that
+ // requested it.
+ requestedGzip := false
+ if !pc.t.DisableCompression &&
+ req.Header.Get("Accept-Encoding") == "" &&
+ req.Header.Get("Range") == "" &&
+ req.Method != "HEAD" {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: https://zlib.net/zlib_faq.html#faq39
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // https://trac.nginx.org/nginx/ticket/358
+ // https://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See https://golang.org/issue/8923
+ requestedGzip = true
+ req.extraHeaders().Set("Accept-Encoding", "gzip")
+ }
+
+ var continueCh chan struct{}
+ if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() {
+ continueCh = make(chan struct{}, 1)
+ }
+
+ if pc.t.DisableKeepAlives &&
+ !req.wantsClose() &&
+ !isProtocolSwitchHeader(req.Header) {
+ req.extraHeaders().Set("Connection", "close")
+ }
+
+ gone := make(chan struct{})
+ defer close(gone)
+
+ defer func() {
+ if err != nil {
+ pc.t.setReqCanceler(req.cancelKey, nil)
+ }
+ }()
+
+ const debugRoundTrip = false
+
+ // Write the request concurrently with waiting for a response,
+ // in case the server decides to reply before reading our full
+ // request body.
+ startBytesWritten := pc.nwrite
+ writeErrCh := make(chan error, 1)
+ pc.writech <- writeRequest{req, writeErrCh, continueCh}
+
+ resc := make(chan responseAndError)
+ pc.reqch <- requestAndChan{
+ req: req.Request,
+ cancelKey: req.cancelKey,
+ ch: resc,
+ addedGzip: requestedGzip,
+ continueCh: continueCh,
+ callerGone: gone,
+ }
+
+ var respHeaderTimer <-chan time.Time
+ cancelChan := req.Request.Cancel
+ ctxDoneChan := req.Context().Done()
+ pcClosed := pc.closech
+ canceled := false
+ for {
+ testHookWaitResLoop()
+ select {
+ case err := <-writeErrCh:
+ if debugRoundTrip {
+ req.logf("writeErrCh resv: %T/%#v", err, err)
+ }
+ if err != nil {
+ pc.close(fmt.Errorf("write error: %w", err))
+ return nil, pc.mapRoundTripError(req, startBytesWritten, err)
+ }
+ if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ if debugRoundTrip {
+ req.logf("starting timer for %v", d)
+ }
+ timer := time.NewTimer(d)
+ defer timer.Stop() // prevent leaks
+ respHeaderTimer = timer.C
+ }
+ case <-pcClosed:
+ pcClosed = nil
+ if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) {
+ if debugRoundTrip {
+ req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+ }
+ return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
+ }
+ case <-respHeaderTimer:
+ if debugRoundTrip {
+ req.logf("timeout waiting for response headers.")
+ }
+ pc.close(errTimeout)
+ return nil, errTimeout
+ case re := <-resc:
+ if (re.res == nil) == (re.err == nil) {
+ panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
+ }
+ if debugRoundTrip {
+ req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
+ }
+ if re.err != nil {
+ return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
+ }
+ return re.res, nil
+ case <-cancelChan:
+ canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
+ cancelChan = nil
+ case <-ctxDoneChan:
+ canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
+ cancelChan = nil
+ ctxDoneChan = nil
+ }
+ }
+}
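+
+// Illustrative sketch (not part of this change): two of the knobs consulted in
+// roundTrip are set directly on the Transport; the values are arbitrary examples.
+//
+//	tr := &Transport{
+//		ResponseHeaderTimeout: 10 * time.Second, // arms respHeaderTimer above
+//		DisableCompression:    true,             // skips the Accept-Encoding: gzip logic
+//	}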
+
+// tLogKey is a context WithValue key for test debugging contexts containing
+// a t.Logf func. See export_test.go's Request.WithT method.
+type tLogKey struct{}
+
+func (tr *transportRequest) logf(format string, args ...any) {
+ if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok {
+ logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
+ }
+}
+
+// markReused marks this connection as having been successfully used for a
+// request and response.
+func (pc *persistConn) markReused() {
+ pc.mu.Lock()
+ pc.reused = true
+ pc.mu.Unlock()
+}
+
+// close closes the underlying TCP connection and closes
+// the pc.closech channel.
+//
+// The provided err is only for testing and debugging; in normal
+// circumstances it should never be seen by users.
+func (pc *persistConn) close(err error) {
+ pc.mu.Lock()
+ defer pc.mu.Unlock()
+ pc.closeLocked(err)
+}
+
+func (pc *persistConn) closeLocked(err error) {
+ if err == nil {
+ panic("nil error")
+ }
+ pc.broken = true
+ if pc.closed == nil {
+ pc.closed = err
+ pc.t.decConnsPerHost(pc.cacheKey)
+ // Close HTTP/1 (pc.alt == nil) connection.
+ // HTTP/2 closes its connection itself.
+ if pc.alt == nil {
+ if err != errCallerOwnsConn {
+ pc.conn.Close()
+ }
+ close(pc.closech)
+ }
+ }
+ pc.mutateHeaderFunc = nil
+}
+
+var portMap = map[string]string{
+ "http": "80",
+ "https": "443",
+ "socks5": "1080",
+}
+
+func idnaASCIIFromURL(url *url.URL) string {
+ addr := url.Hostname()
+ if v, err := idnaASCII(addr); err == nil {
+ addr = v
+ }
+ return addr
+}
+
+// canonicalAddr returns url.Host but always with a ":port" suffix.
+func canonicalAddr(url *url.URL) string {
+ port := url.Port()
+ if port == "" {
+ port = portMap[url.Scheme]
+ }
+ return net.JoinHostPort(idnaASCIIFromURL(url), port)
+}
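+
+// For example (illustrative):
+//
+//	canonicalAddr(&url.URL{Scheme: "https", Host: "example.com"})     // "example.com:443"
+//	canonicalAddr(&url.URL{Scheme: "http", Host: "example.com:8080"}) // "example.com:8080"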
+
+// bodyEOFSignal is used by the HTTP/1 transport when reading response
+// bodies to make sure we see the end of a response body before
+// proceeding and reading on the connection again.
+//
+// It wraps a ReadCloser but runs fn (if non-nil) at most
+// once, right before its final (error-producing) Read or Close call
+// returns. fn should return the new error to return from Read or Close.
+//
+// If earlyCloseFn is non-nil and Close is called before io.EOF is
+// seen, earlyCloseFn is called instead of fn, and its return value is
+// the return value from Close.
+type bodyEOFSignal struct {
+ body io.ReadCloser
+ mu sync.Mutex // guards following 4 fields
+ closed bool // whether Close has been called
+ rerr error // sticky Read error
+ fn func(error) error // err will be nil on Read io.EOF
+ earlyCloseFn func() error // optional alt Close func used if io.EOF not seen
+}
+
+var errReadOnClosedResBody = errors.New("http: read on closed response body")
+
+func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
+ es.mu.Lock()
+ closed, rerr := es.closed, es.rerr
+ es.mu.Unlock()
+ if closed {
+ return 0, errReadOnClosedResBody
+ }
+ if rerr != nil {
+ return 0, rerr
+ }
+
+ n, err = es.body.Read(p)
+ if err != nil {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.rerr == nil {
+ es.rerr = err
+ }
+ err = es.condfn(err)
+ }
+ return
+}
+
+func (es *bodyEOFSignal) Close() error {
+ es.mu.Lock()
+ defer es.mu.Unlock()
+ if es.closed {
+ return nil
+ }
+ es.closed = true
+ if es.earlyCloseFn != nil && es.rerr != io.EOF {
+ return es.earlyCloseFn()
+ }
+ err := es.body.Close()
+ return es.condfn(err)
+}
+
+// caller must hold es.mu.
+func (es *bodyEOFSignal) condfn(err error) error {
+ if es.fn == nil {
+ return err
+ }
+ err = es.fn(err)
+ es.fn = nil
+ return err
+}
+
+// gzipReader wraps a response body so it can lazily
+// call gzip.NewReader on the first call to Read
+type gzipReader struct {
+ _ incomparable
+ body *bodyEOFSignal // underlying HTTP/1 response body framing
+ zr *gzip.Reader // lazily-initialized gzip reader
+ zerr error // any error from gzip.NewReader; sticky
+}
+
+func (gz *gzipReader) Read(p []byte) (n int, err error) {
+ if gz.zr == nil {
+ if gz.zerr == nil {
+ gz.zr, gz.zerr = gzip.NewReader(gz.body)
+ }
+ if gz.zerr != nil {
+ return 0, gz.zerr
+ }
+ }
+
+ gz.body.mu.Lock()
+ if gz.body.closed {
+ err = errReadOnClosedResBody
+ }
+ gz.body.mu.Unlock()
+
+ if err != nil {
+ return 0, err
+ }
+ return gz.zr.Read(p)
+}
+
+func (gz *gzipReader) Close() error {
+ return gz.body.Close()
+}
+
+type tlsHandshakeTimeoutError struct{}
+
+func (tlsHandshakeTimeoutError) Timeout() bool { return true }
+func (tlsHandshakeTimeoutError) Temporary() bool { return true }
+func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+
+// fakeLocker is a sync.Locker which does nothing. It's used to guard
+// test-only fields when not under test, to avoid runtime atomic
+// overhead.
+type fakeLocker struct{}
+
+func (fakeLocker) Lock() {}
+func (fakeLocker) Unlock() {}
+
+// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if
+// cfg is nil. This is safe to call even if cfg is in active use by a TLS
+// client or server.
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
+ }
+ return cfg.Clone()
+}
+
+type connLRU struct {
+ ll *list.List // list.Element.Value is of type *persistConn
+ m map[*persistConn]*list.Element
+}
+
+// add adds pc to the head of the linked list.
+func (cl *connLRU) add(pc *persistConn) {
+ if cl.ll == nil {
+ cl.ll = list.New()
+ cl.m = make(map[*persistConn]*list.Element)
+ }
+ ele := cl.ll.PushFront(pc)
+ if _, ok := cl.m[pc]; ok {
+ panic("persistConn was already in LRU")
+ }
+ cl.m[pc] = ele
+}
+
+func (cl *connLRU) removeOldest() *persistConn {
+ ele := cl.ll.Back()
+ pc := ele.Value.(*persistConn)
+ cl.ll.Remove(ele)
+ delete(cl.m, pc)
+ return pc
+}
+
+// remove removes pc from cl.
+func (cl *connLRU) remove(pc *persistConn) {
+ if ele, ok := cl.m[pc]; ok {
+ cl.ll.Remove(ele)
+ delete(cl.m, pc)
+ }
+}
+
+// len returns the number of items in the cache.
+func (cl *connLRU) len() int {
+ return len(cl.m)
+}
diff --git a/src/net/http/transport_default_other.go b/src/net/http/transport_default_other.go
new file mode 100644
index 0000000..4f6c5c1
--- /dev/null
+++ b/src/net/http/transport_default_other.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !wasm
+
+package http
+
+import (
+ "context"
+ "net"
+)
+
+func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return dialer.DialContext
+}
diff --git a/src/net/http/transport_default_wasm.go b/src/net/http/transport_default_wasm.go
new file mode 100644
index 0000000..3946812
--- /dev/null
+++ b/src/net/http/transport_default_wasm.go
@@ -0,0 +1,16 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1
+
+package http
+
+import (
+ "context"
+ "net"
+)
+
+func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
+ return nil
+}
diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go
new file mode 100644
index 0000000..2ed637e
--- /dev/null
+++ b/src/net/http/transport_internal_test.go
@@ -0,0 +1,267 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// White-box tests for transport.go (in package http instead of http_test).
+
+package http
+
+import (
+ "bytes"
+ "crypto/tls"
+ "errors"
+ "io"
+ "net"
+ "net/http/internal/testcert"
+ "strings"
+ "testing"
+)
+
+// Issue 15446: incorrect wrapping of errors when server closes an idle connection.
+func TestTransportPersistConnReadLoopEOF(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ connc := make(chan net.Conn, 1)
+ go func() {
+ defer close(connc)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ connc <- c
+ }()
+
+ tr := new(Transport)
+ req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
+ req = req.WithT(t)
+ treq := &transportRequest{Request: req}
+ cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
+ pc, err := tr.getConn(treq, cm)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer pc.close(errors.New("test over"))
+
+ conn := <-connc
+ if conn == nil {
+ // Already called t.Error in the accept goroutine.
+ return
+ }
+ conn.Close() // simulate the server hanging up on the client
+
+ _, err = pc.roundTrip(treq)
+ if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
+ t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
+ }
+
+ <-pc.closech
+ err = pc.closed
+ if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
+ t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
+ }
+}
+
+func isNothingWrittenError(err error) bool {
+ _, ok := err.(nothingWrittenError)
+ return ok
+}
+
+func isTransportReadFromServerError(err error) bool {
+ _, ok := err.(transportReadFromServerError)
+ return ok
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+func dummyRequest(method string) *Request {
+ req, err := NewRequest(method, "http://fake.tld/", nil)
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+func dummyRequestWithBody(method string) *Request {
+ req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+
+func dummyRequestWithBodyNoGetBody(method string) *Request {
+ req := dummyRequestWithBody(method)
+ req.GetBody = nil
+ return req
+}
+
+// issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn.
+type issue22091Error struct{}
+
+func (issue22091Error) IsHTTP2NoCachedConnError() {}
+func (issue22091Error) Error() string { return "issue22091Error" }
+
+func TestTransportShouldRetryRequest(t *testing.T) {
+ tests := []struct {
+ pc *persistConn
+ req *Request
+
+ err error
+ want bool
+ }{
+ 0: {
+ pc: &persistConn{reused: false},
+ req: dummyRequest("POST"),
+ err: nothingWrittenError{},
+ want: false,
+ },
+ 1: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: nothingWrittenError{},
+ want: true,
+ },
+ 2: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: http2ErrNoCachedConn,
+ want: true,
+ },
+ 3: {
+ pc: nil,
+ req: nil,
+ err: issue22091Error{}, // like an external http2ErrNoCachedConn
+ want: true,
+ },
+ 4: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: errMissingHost,
+ want: false,
+ },
+ 5: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: transportReadFromServerError{},
+ want: false,
+ },
+ 6: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("GET"),
+ err: transportReadFromServerError{},
+ want: true,
+ },
+ 7: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("GET"),
+ err: errServerClosedIdle,
+ want: true,
+ },
+ 8: {
+ pc: &persistConn{reused: true},
+ req: dummyRequestWithBody("POST"),
+ err: nothingWrittenError{},
+ want: true,
+ },
+ 9: {
+ pc: &persistConn{reused: true},
+ req: dummyRequestWithBodyNoGetBody("POST"),
+ err: nothingWrittenError{},
+ want: false,
+ },
+ }
+ for i, tt := range tests {
+ got := tt.pc.shouldRetryRequest(tt.req, tt.err)
+ if got != tt.want {
+ t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
+ }
+ }
+}
+
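+// roundTripFunc is an adapter that lets an ordinary function be used as a RoundTripper.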
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
+ return f(r)
+}
+
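+// TestTransportBodyAltRewind checks that the Transport rewinds and re-sends the
+// request body when an alternate (TLSNextProto) RoundTripper reports a
+// "no cached connection" error and the request is retried.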
+// Issue 25009
+func TestTransportBodyAltRewind(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ for i := 0; i < 2; i++ {
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ sc.Close()
+ }
+ }()
+
+ addr := ln.Addr().String()
+ req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
+ roundTripped := false
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ return roundTripFunc(func(r *Request) (*Response, error) {
+ n, _ := io.Copy(io.Discard, r.Body)
+ if n == 0 {
+ t.Error("body length is zero")
+ }
+ if roundTripped {
+ return &Response{
+ Body: NoBody,
+ StatusCode: 200,
+ }, nil
+ }
+ roundTripped = true
+ return nil, http2noCachedConnError{}
+ })
+ },
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+ _, err = c.Do(req)
+ if err != nil {
+ t.Error(err)
+ }
+}
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
new file mode 100644
index 0000000..028fecc
--- /dev/null
+++ b/src/net/http/transport_test.go
@@ -0,0 +1,6752 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for transport.go.
+//
+// More tests are in clientserver_test.go (for things testing both client & server for both
+// HTTP/1 and HTTP/2).
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "go/token"
+ "internal/nettrace"
+ "io"
+ "log"
+ mrand "math/rand"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/httputil"
+ "net/http/internal/testcert"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "testing/iotest"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+// and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.FormValue("close") == "true" {
+ w.Header().Set("Connection", "close")
+ }
+ w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
+ w.Write([]byte(r.RemoteAddr))
+
+ // Include the address of the net.Conn in addition to the RemoteAddr,
+ // in case kernels reuse source ports quickly (see Issue 52450)
+ if c, ok := ResponseWriterConnForTesting(w); ok {
+ fmt.Fprintf(w, ", %T %p", c, c)
+ }
+})
+
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ mu sync.Mutex // guards closed and list
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = true
+}
+
+// makeTestDial returns a testConnSet and a Dial function that records each raw
+// TCP connection it creates in the set, so tests can inspect them later.
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
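+// check reports an error for any tracked connection that was never closed,
+// retrying a few times to give in-flight Closes a chance to finish.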
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ for i := 4; i >= 0; i-- {
+ for i, c := range tcs.list {
+ if tcs.closed[c] {
+ continue
+ }
+ if i != 0 {
+ // TODO(bcmills): What is the Sleep here doing, and why is this
+ // Unlock/Sleep/Lock cycle needed at all?
+ tcs.mu.Unlock()
+ time.Sleep(50 * time.Millisecond)
+ tcs.mu.Lock()
+ continue
+ }
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
+func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
+func testReuseRequest(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("{}"))
+ })).ts
+
+ c := ts.Client()
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ res, err = c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Make two subsequent requests and verify their responses are the same.
+// The response from the server is our own IP:port.
+func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
+func testTransportKeepAlives(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ for _, disableKeepAlive := range []bool{false, true} {
+ c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
+ fetch := func(n int) string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != disableKeepAlive {
+ t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ disableKeepAlive, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ run(t, testTransportConnectionCloseOnResponse)
+}
+func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ connSet, testDial := makeTestDial(t)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = testDial
+
+ for _, connectionClose := range []bool{false, true} {
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+// TestTransportConnectionCloseOnRequest tests that the Transport doesn't reuse
+// an underlying TCP connection after making an http.Request with Request.Close set.
+//
+// It tests the behavior by making an HTTP request to a server which
+// describes the source connection it got (remote port number +
+// address of its net.Conn).
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
+}
+func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ connSet, testDial := makeTestDial(t)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = testDial
+ for _, reqClose := range []bool{false, true} {
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL)
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+ req.Close = reqClose
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
+ }
+ if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
+ t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
+ reqClose, got, want)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ got := 1
+ if body1 != body2 {
+ got++
+ }
+ want := 1
+ if reqClose {
+ want = 2
+ }
+ if got != want {
+ t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
+ reqClose, got, want, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+// If the Transport's DisableKeepAlives is set, all requests should
+// send Connection: close.
+// HTTP/1-only (Connection: close doesn't exist in HTTP/2).
+func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
+ run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
+}
+func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).DisableKeepAlives = true
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.Header.Get("X-Saw-Close") != "true" {
+ t.Errorf("handler didn't see Connection: close ")
+ }
+}
+
+// Test that Transport only sends one "Connection: close", regardless of
+// how "close" was indicated.
+func TestTransportRespectRequestWantsClose(t *testing.T) {
+ run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
+}
+func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
+ tests := []struct {
+ disableKeepAlives bool
+ close bool
+ }{
+ {disableKeepAlives: false, close: false},
+ {disableKeepAlives: false, close: true},
+ {disableKeepAlives: true, close: false},
+ {disableKeepAlives: true, close: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
+ func(t *testing.T) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ count := 0
+ trace := &httptrace.ClientTrace{
+ WroteHeaderField: func(key string, field []string) {
+ if key != "Connection" {
+ return
+ }
+ if httpguts.HeaderValuesContainsToken(field, "close") {
+ count += 1
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ req.Close = tc.close
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
+ t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
+ }
+ })
+ }
+
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+ run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
+}
+func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ io.ReadAll(resp.Body)
+
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+ t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
+ }
+
+ tr.CloseIdleConnections()
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+}
+
+// Tests that the HTTP transport re-uses connections when a client
+// reads to the end of a response Body without closing it.
+func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
+func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
+ const msg = "foobar"
+
+ var addrSeen map[string]int
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrSeen[r.RemoteAddr]++
+ if r.URL.Path == "/chunked/" {
+ w.WriteHeader(200)
+ w.(Flusher).Flush()
+ } else {
+ w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
+ w.WriteHeader(200)
+ }
+ w.Write([]byte(msg))
+ })).ts
+
+ for pi, path := range []string{"/content-length/", "/chunked/"} {
+ wantLen := []int{len(msg), -1}[pi]
+ addrSeen = make(map[string]int)
+ for i := 0; i < 3; i++ {
+ res, err := ts.Client().Get(ts.URL + path)
+ if err != nil {
+ t.Errorf("Get %s: %v", path, err)
+ continue
+ }
+ // We want to close this body eventually (before the
+ // defer afterTest at top runs), but not before the
+ // len(addrSeen) check at the bottom of this test,
+ // since Closing this early in the loop would risk
+ // making connections be re-used for the wrong reason.
+ defer res.Body.Close()
+
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
+ }
+ got, err := io.ReadAll(res.Body)
+ if string(got) != msg || err != nil {
+ t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
+ }
+ }
+ if len(addrSeen) != 1 {
+ t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
+ }
+ }
+}
+
+func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
+}
+func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
+ stop := make(chan struct{}) // stop marks the exit of main Test goroutine
+ defer close(stop)
+
+ resch := make(chan string)
+ gotReq := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReq <- true
+ var msg string
+ select {
+ case <-stop:
+ return
+ case msg = <-resch:
+ }
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ t.Errorf("Write: %v", err)
+ return
+ }
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ maxIdleConnsPerHost := 2
+ tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
+
+ // Start 3 outstanding requests and wait for the server to get them.
+ // Their responses will hang until we write to resch, though.
+ donech := make(chan bool)
+ doReq := func() {
+ defer func() {
+ select {
+ case <-stop:
+ return
+ case donech <- t.Failed():
+ }
+ }()
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := io.ReadAll(resp.Body); err != nil {
+ t.Errorf("ReadAll: %v", err)
+ return
+ }
+ }
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resch <- "res1"
+ <-donech
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
+ }
+ addr := ts.Listener.Addr().String()
+ cacheKey := "|http|" + addr
+ if keys[0] != cacheKey {
+ t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
+ }
+ if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
+ t.Errorf("after first response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res2"
+ <-donech
+ if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
+ t.Errorf("after second response, idle conns = %d; want %d", g, w)
+ }
+
+ resch <- "res3"
+ <-donech
+ if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
+ t.Errorf("after third response, idle conns = %d; want %d", g, w)
+ }
+}
+
+func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
+ run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
+}
+func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ dialStarted := make(chan struct{})
+ stallDial := make(chan struct{})
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ dialStarted <- struct{}{}
+ <-stallDial
+ return net.Dial(network, addr)
+ }
+
+ tr.DisableKeepAlives = true
+ tr.MaxConnsPerHost = 1
+
+ preDial := make(chan struct{})
+ reqComplete := make(chan struct{})
+ doReq := func(reqId string) {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ preDial <- struct{}{}
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ reqComplete <- struct{}{}
+ }
+ // get req1 to dial-in-progress
+ go doReq("req1")
+ <-preDial
+ <-dialStarted
+
+ // get req2 to waiting on conns per host to go down below max
+ go doReq("req2")
+ <-preDial
+ select {
+ case <-dialStarted:
+ t.Error("req2 dial started while req1 dial in progress")
+ return
+ default:
+ }
+
+ // let req1 complete
+ stallDial <- struct{}{}
+ <-reqComplete
+
+ // let req2 complete
+ <-dialStarted
+ stallDial <- struct{}{}
+ <-reqComplete
+}
+
+func TestTransportMaxConnsPerHost(t *testing.T) {
+ run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
+ CondSkipHTTP2(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ ts := newClientServerTest(t, mode, h).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+
+ mu := sync.Mutex{}
+ var conns []net.Conn
+ var dialCnt, gotConnCnt, tlsHandshakeCnt int32
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ atomic.AddInt32(&dialCnt, 1)
+ c, err := net.Dial(network, addr)
+ mu.Lock()
+ defer mu.Unlock()
+ conns = append(conns, c)
+ return c, err
+ }
+
+ doReq := func() {
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ TLSHandshakeStart: func() {
+ atomic.AddInt32(&tlsHandshakeCnt, 1)
+ },
+ }
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("read body failed: %v", err)
+ }
+ }
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+
+ expected := int32(tr.MaxConnsPerHost)
+ if dialCnt != expected {
+ t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ mu.Lock()
+ for _, c := range conns {
+ c.Close()
+ }
+ conns = nil
+ mu.Unlock()
+ tr.CloseIdleConnections()
+
+ doReq()
+ expected++
+ if dialCnt != expected {
+ t.Errorf("round 2: too many dials: %d", dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
+ }
+}
+
+func TestTransportRemovesDeadIdleConnections(t *testing.T) {
+ run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
+}
+func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, r.RemoteAddr)
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ doReq := func(name string) {
+ // Do a POST instead of a GET to prevent the Transport's
+ // idempotent request retry logic from kicking in...
+ res, err := c.Post(ts.URL, "", nil)
+ if err != nil {
+ t.Fatalf("%s: %v", name, err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatalf("%s: %v", name, res.Status)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("%s: %v", name, err)
+ }
+ t.Logf("%s: ok (%q)", name, slurp)
+ }
+
+ doReq("first")
+ keys1 := tr.IdleConnKeysForTesting()
+
+ ts.CloseClientConnections()
+
+ var keys2 []string
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ keys2 = tr.IdleConnKeysForTesting()
+ if len(keys2) != 0 {
+ if d > 0 {
+ t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
+ }
+ return false
+ }
+ return true
+ })
+
+ doReq("second")
+}
+
+// Test that the Transport notices when a server hangs up on it
+// unexpectedly (a keep-alive connection is closed).
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
+}
+func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+ c := ts.Client()
+
+ fetch := func(n, retries int) string {
+ condFatalf := func(format string, arg ...any) {
+ if retries <= 0 {
+ t.Fatalf(format, arg...)
+ }
+ t.Logf("retrying shortly after expected error: "+format, arg...)
+ time.Sleep(time.Second / time.Duration(retries))
+ }
+ for retries >= 0 {
+ retries--
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ condFatalf("error in req #%d, GET: %v", n, err)
+ continue
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ condFatalf("error in req #%d, ReadAll: %v", n, err)
+ continue
+ }
+ res.Body.Close()
+ return string(body)
+ }
+ panic("unreachable")
+ }
+
+ body1 := fetch(1, 0)
+ body2 := fetch(2, 0)
+
+ // Close all the idle connections in a way that's similar to
+ // the server hanging up on us. We don't use
+ // httptest.Server.CloseClientConnections because it's
+ // best-effort and stops blocking after 5 seconds. On a loaded
+ // machine running many tests concurrently it's possible for
+ // that method to be async and cause the body3 fetch below to
+ // run on an old connection. This function is synchronous.
+ ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
+
+ body3 := fetch(3, 5)
+
+ if body1 != body2 {
+ t.Errorf("expected body1 and body2 to be equal")
+ }
+ if body2 == body3 {
+ t.Errorf("expected body2 and body3 to be different")
+ }
+}
+
+// Test for https://golang.org/issue/2616 (appropriate issue number)
+// This fails pretty reliably with GOMAXPROCS=100 or something high.
+func TestStressSurpriseServerCloses(t *testing.T) {
+ run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
+}
+func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in short mode")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "5")
+ w.Header().Set("Content-Type", "text/plain")
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Flush()
+ conn.Close()
+ })).ts
+ c := ts.Client()
+
+ // Do a bunch of traffic from different goroutines. Each goroutine calls
+ // wg.Done after its request completes, regardless of whether it failed.
+ // If these are too high, OS X exhausts its ephemeral ports
+ // and hangs waiting for them to transition TCP states. That's
+ // not what we want to test. TODO(bradfitz): use an io.Pipe
+ // dialer for this test instead?
+ const (
+ numClients = 20
+ reqsPerClient = 25
+ )
+ var wg sync.WaitGroup
+ wg.Add(numClients * reqsPerClient)
+ for i := 0; i < numClients; i++ {
+ go func() {
+ for i := 0; i < reqsPerClient; i++ {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ // We expect errors since the server is
+ // hanging up on us after telling us to
+ // send more requests, so we don't
+ // actually care what the error is.
+ // But we want to close the body in cases
+ // where we won the race.
+ res.Body.Close()
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ // Make sure all the request come back, one way or another.
+ wg.Wait()
+}
+
+// TestTransportHeadResponses verifies that HEAD responses with a Content-Length
+// header and no body are handled properly.
+func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
+func testTransportHeadResponses(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Content-Length", "123")
+ w.WriteHeader(200)
+ })).ts
+ c := ts.Client()
+
+ for i := 0; i < 2; i++ {
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Errorf("error on loop %d: %v", i, err)
+ continue
+ }
+ if e, g := "123", res.Header.Get("Content-Length"); e != g {
+ t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
+ }
+ if e, g := int64(123), res.ContentLength; e != g {
+ t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
+ }
+ if all, err := io.ReadAll(res.Body); err != nil {
+ t.Errorf("loop %d: Body ReadAll: %v", i, err)
+ } else if len(all) != 0 {
+ t.Errorf("Bogus body %q", all)
+ }
+ }
+}
+
+// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
+// on responses to HEAD requests.
+func TestTransportHeadChunkedResponse(t *testing.T) {
+ run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
+}
+func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
+ w.Header().Set("x-client-ipport", r.RemoteAddr)
+ w.WriteHeader(200)
+ })).ts
+ c := ts.Client()
+
+ // Ensure that we wait for the readLoop to complete before
+ // calling Head again
+ didRead := make(chan bool)
+ SetReadLoopBeforeNextReadHook(func() { didRead <- true })
+ defer SetReadLoopBeforeNextReadHook(nil)
+
+ res1, err := c.Head(ts.URL)
+ <-didRead
+
+ if err != nil {
+ t.Fatalf("request 1 error: %v", err)
+ }
+
+ res2, err := c.Head(ts.URL)
+ <-didRead
+
+ if err != nil {
+ t.Fatalf("request 2 error: %v", err)
+ }
+ if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
+ t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
+ }
+}
+
+var roundTripTests = []struct {
+ accept string
+ expectAccept string
+ compressed bool
+}{
+ // Requests with no accept-encoding header use transparent compression
+ {"", "gzip", false},
+ // Requests with other accept-encoding should pass through unmodified
+ {"foo", "foo", false},
+ // Requests with accept-encoding == gzip should be passed through
+ {"gzip", "gzip", true},
+}
+
+// Test that the modification made to the Request by the RoundTripper is cleaned up
+func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
+func testRoundTripGzip(t *testing.T, mode testMode) {
+ const responseBody = "test response body"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ accept := req.Header.Get("Accept-Encoding")
+ if expect := req.FormValue("expect_accept"); accept != expect {
+ t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+ req.FormValue("testnum"), accept, expect)
+ }
+ if accept == "gzip" {
+ rw.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(rw)
+ gz.Write([]byte(responseBody))
+ gz.Close()
+ } else {
+ rw.Header().Set("Content-Encoding", accept)
+ rw.Write([]byte(responseBody))
+ }
+ })).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ for i, test := range roundTripTests {
+ // Test basic request (no accept-encoding)
+ req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+ if test.accept != "" {
+ req.Header.Set("Accept-Encoding", test.accept)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip: %v", i, err)
+ continue
+ }
+ var body []byte
+ if test.compressed {
+ var r *gzip.Reader
+ r, err = gzip.NewReader(res.Body)
+ if err != nil {
+ t.Errorf("%d. gzip NewReader: %v", i, err)
+ continue
+ }
+ body, err = io.ReadAll(r)
+ res.Body.Close()
+ } else {
+ body, err = io.ReadAll(res.Body)
+ }
+ if err != nil {
+ t.Errorf("%d. Error: %q", i, err)
+ continue
+ }
+ if g, e := string(body), responseBody; g != e {
+ t.Errorf("%d. body = %q; want %q", i, g, e)
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
+ }
+ }
+
+}
+
+func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
+func testTransportGzip(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56020")
+ }
+ const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ const nRandBytes = 1024 * 1024
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if req.Method == "HEAD" {
+ if g := req.Header.Get("Accept-Encoding"); g != "" {
+ t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
+ }
+ return
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
+ t.Errorf("Accept-Encoding = %q, want %q", g, e)
+ }
+ rw.Header().Set("Content-Encoding", "gzip")
+
+ var w io.Writer = rw
+ var buf bytes.Buffer
+ if req.FormValue("chunked") == "0" {
+ w = &buf
+ defer io.Copy(rw, &buf)
+ defer func() {
+ rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+ }()
+ }
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte(testString))
+ if req.FormValue("body") == "large" {
+ io.CopyN(gz, rand.Reader, nRandBytes)
+ }
+ gz.Close()
+ })).ts
+ c := ts.Client()
+
+ for _, chunked := range []string{"1", "0"} {
+ // First fetch something large, but only read some of it.
+ res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
+ if err != nil {
+ t.Fatalf("large get: %v", err)
+ }
+ buf := make([]byte, len(testString))
+ n, err := io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatalf("partial read of large response: size=%d, %v", n, err)
+ }
+ if e, g := testString, string(buf); e != g {
+ t.Errorf("partial read got %q, expected %q", g, e)
+ }
+ res.Body.Close()
+ // Read on the body, even though it's closed
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+ }
+
+ // Then something small.
+ res, err = c.Get(ts.URL + "/?chunked=" + chunked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := string(body), testString; g != e {
+ t.Fatalf("body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+
+ // Read on the body after it's been fully read:
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+ }
+ res.Body.Close()
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after Close; got %d, %v", n, err)
+ }
+ }
+
+ // And a HEAD request too, because they're always weird.
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("Head: %v", err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Head status=%d; want=200", res.StatusCode)
+ }
+}
+
+// If a request has an Expect: 100-continue header, the request blocks sending its body until the first response.
+// The request body should not be consumed prematurely.
+func TestTransportExpect100Continue(t *testing.T) {
+ run(t, testTransportExpect100Continue, []testMode{http1Mode})
+}
+func testTransportExpect100Continue(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ switch req.URL.Path {
+ case "/100":
+ // This endpoint implicitly responds 100 Continue and reads body.
+ if _, err := io.Copy(io.Discard, req.Body); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ rw.WriteHeader(StatusOK)
+ case "/200":
+ // Go 1.5 adds a Connection: close header if the client expects
+ // a 100 Continue but the entire request body is not consumed.
+ rw.WriteHeader(StatusOK)
+ case "/500":
+ rw.WriteHeader(StatusInternalServerError)
+ case "/keepalive":
+ // This hijacked endpoint responds with an error but without a Connection: close header.
+ _, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
+ bufrw.WriteString("Content-Length: 0\r\n\r\n")
+ bufrw.Flush()
+ case "/timeout":
+ // This endpoint reads the body without sending a 100 (Continue) response;
+ // the client only starts sending the body after ExpectContinueTimeout elapses.
+ conn, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
+ bufrw.Flush()
+ conn.Close()
+ }
+
+ })).ts
+
+ tests := []struct {
+ path string
+ body []byte
+ sent int
+ status int
+ }{
+ {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent.
+ {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent.
+ {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent.
+ {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent.
+ {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
+ }
+
+ c := ts.Client()
+ for i, v := range tests {
+ tr := &Transport{
+ ExpectContinueTimeout: 2 * time.Second,
+ }
+ defer tr.CloseIdleConnections()
+ c.Transport = tr
+ body := bytes.NewReader(v.body)
+ req, err := NewRequest("PUT", ts.URL+v.path, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Expect", "100-continue")
+ req.ContentLength = int64(len(v.body))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+
+ sent := len(v.body) - body.Len()
+ if v.status != resp.StatusCode {
+ t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
+ }
+ if v.sent != sent {
+ t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
+ }
+ }
+}
+
+func TestSOCKS5Proxy(t *testing.T) {
+ run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testSOCKS5Proxy(t *testing.T, mode testMode) {
+ ch := make(chan string, 1)
+ l := newLocalListener(t)
+ defer l.Close()
+ defer close(ch)
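+ // proxy implements a minimal SOCKS5 server: it performs the handshake,
+ // reports the CONNECT target on ch, then tunnels bytes in both directions.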
+ proxy := func(t *testing.T) {
+ s, err := l.Accept()
+ if err != nil {
+ t.Errorf("socks5 proxy Accept(): %v", err)
+ return
+ }
+ defer s.Close()
+ var buf [22]byte
+ if _, err := io.ReadFull(s, buf[:3]); err != nil {
+ t.Errorf("socks5 proxy initial read: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
+ t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
+ return
+ }
+ if _, err := s.Write([]byte{5, 0}); err != nil {
+ t.Errorf("socks5 proxy initial write: %v", err)
+ return
+ }
+ if _, err := io.ReadFull(s, buf[:4]); err != nil {
+ t.Errorf("socks5 proxy second read: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
+ t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
+ return
+ }
+ var ipLen int
+ switch buf[3] {
+ case 1:
+ ipLen = net.IPv4len
+ case 4:
+ ipLen = net.IPv6len
+ default:
+ t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
+ return
+ }
+ if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
+ t.Errorf("socks5 proxy address read: %v", err)
+ return
+ }
+ ip := net.IP(buf[4 : ipLen+4])
+ port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
+ copy(buf[:3], []byte{5, 0, 0})
+ if _, err := s.Write(buf[:ipLen+6]); err != nil {
+ t.Errorf("socks5 proxy connect write: %v", err)
+ return
+ }
+ ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
+
+ // Implement proxying.
+ targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
+ targetConn, err := net.Dial("tcp", targetHost)
+ if err != nil {
+ t.Errorf("net.Dial failed")
+ return
+ }
+ go io.Copy(targetConn, s)
+ io.Copy(s, targetConn) // Wait for the client to close the socket.
+ targetConn.Close()
+ }
+
+ pu, err := url.Parse("socks5://" + l.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ sentinelHeader := "X-Sentinel"
+ sentinelValue := "12345"
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set(sentinelHeader, sentinelValue)
+ })
+ for _, useTLS := range []bool{false, true} {
+ t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
+ ts := newClientServerTest(t, mode, h).ts
+ go proxy(t)
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ r, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if r.Header.Get(sentinelHeader) != sentinelValue {
+ t.Errorf("Failed to retrieve sentinel value")
+ }
+ got := <-ch
+ ts.Close()
+ tsu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "proxy for " + tsu.Host
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+func TestTransportProxy(t *testing.T) {
+ defer afterTest(t)
+ testCases := []struct{ siteMode, proxyMode testMode }{
+ {http1Mode, http1Mode},
+ {http1Mode, https1Mode},
+ {https1Mode, http1Mode},
+ {https1Mode, https1Mode},
+ }
+ for _, testCase := range testCases {
+ siteMode := testCase.siteMode
+ proxyMode := testCase.proxyMode
+ t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
+ siteCh := make(chan *Request, 1)
+ h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ siteCh <- r
+ })
+ proxyCh := make(chan *Request, 1)
+ h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ proxyCh <- r
+ // Implement an entire CONNECT proxy
+ if r.Method == "CONNECT" {
+ hijacker, ok := w.(Hijacker)
+ if !ok {
+ t.Errorf("hijack not allowed")
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ t.Errorf("hijacking failed")
+ return
+ }
+ res := &Response{
+ StatusCode: StatusOK,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ }
+
+ targetConn, err := net.Dial("tcp", r.URL.Host)
+ if err != nil {
+ t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
+ return
+ }
+
+ if err := res.Write(clientConn); err != nil {
+ t.Errorf("Writing 200 OK failed: %v", err)
+ return
+ }
+
+ go io.Copy(targetConn, clientConn)
+ go func() {
+ io.Copy(clientConn, targetConn)
+ targetConn.Close()
+ }()
+ }
+ })
+ ts := newClientServerTest(t, siteMode, h1).ts
+ proxy := newClientServerTest(t, proxyMode, h2).ts
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If neither server is HTTPS or both are, then c may be derived from either.
+ // If only one server is HTTPS, c must be derived from that server in order
+ // to ensure that it is configured to use the fake root CA from testcert.go.
+ c := proxy.Client()
+ if siteMode == https1Mode {
+ c = ts.Client()
+ }
+
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ if _, err := c.Head(ts.URL); err != nil {
+ t.Error(err)
+ }
+ got := <-proxyCh
+ c.Transport.(*Transport).CloseIdleConnections()
+ ts.Close()
+ proxy.Close()
+ if siteMode == https1Mode {
+ // First message should be a CONNECT, asking for a socket to the real server.
+ if got.Method != "CONNECT" {
+ t.Errorf("Wrong method for secure proxying: %q", got.Method)
+ }
+ gotHost := got.URL.Host
+ pu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal("Invalid site URL")
+ }
+ if wantHost := pu.Host; gotHost != wantHost {
+ t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
+ }
+
+ // The next message on the channel should be from the site's server.
+ next := <-siteCh
+ if next.Method != "HEAD" {
+ t.Errorf("Wrong method at destination: %s", next.Method)
+ }
+ if nextURL := next.URL.String(); nextURL != "/" {
+ t.Errorf("Wrong URL at destination: %s", nextURL)
+ }
+ } else {
+ if got.Method != "HEAD" {
+ t.Errorf("Wrong method for destination: %q", got.Method)
+ }
+ gotURL := got.URL.String()
+ wantURL := ts.URL + "/"
+ if gotURL != wantURL {
+ t.Errorf("Got URL %q, want %q", gotURL, wantURL)
+ }
+ }
+ })
+ }
+}
+
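+// TestOnProxyConnectResponse verifies that Transport.OnProxyConnectResponse is
+// invoked with the proxy's CONNECT response and that returning an error from
+// the hook fails the request.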
+func TestOnProxyConnectResponse(t *testing.T) {
+
+ var tcases = []struct {
+ proxyStatusCode int
+ err error
+ }{
+ {
+ StatusOK,
+ nil,
+ },
+ {
+ StatusForbidden,
+ errors.New("403"),
+ },
+ }
+ for _, tcase := range tcases {
+ h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+
+ })
+
+ h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Implement an entire CONNECT proxy
+ if r.Method == "CONNECT" {
+ if tcase.proxyStatusCode != StatusOK {
+ w.WriteHeader(tcase.proxyStatusCode)
+ return
+ }
+ hijacker, ok := w.(Hijacker)
+ if !ok {
+ t.Errorf("hijack not allowed")
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ t.Errorf("hijacking failed")
+ return
+ }
+ res := &Response{
+ StatusCode: StatusOK,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ }
+
+ targetConn, err := net.Dial("tcp", r.URL.Host)
+ if err != nil {
+ t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
+ return
+ }
+
+ if err := res.Write(clientConn); err != nil {
+ t.Errorf("Writing 200 OK failed: %v", err)
+ return
+ }
+
+ go io.Copy(targetConn, clientConn)
+ go func() {
+ io.Copy(clientConn, targetConn)
+ targetConn.Close()
+ }()
+ }
+ })
+ ts := newClientServerTest(t, https1Mode, h1).ts
+ proxy := newClientServerTest(t, https1Mode, h2).ts
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c := proxy.Client()
+
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+ if proxyURL.String() != pu.String() {
+ t.Errorf("proxy url got %s, want %s", proxyURL, pu)
+ }
+
+ if "https://"+connectReq.URL.String() != ts.URL {
+ t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
+ }
+ return tcase.err
+ }
+ if _, err := c.Head(ts.URL); err != nil {
+ if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
+ t.Errorf("got %v, want %v", err, tcase.err)
+ }
+ }
+ }
+}
+
+// Issue 28012: verify that the Transport closes its TCP connection to http proxies
+// when they're slow to reply to HTTPS CONNECT responses.
+func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ listenerDone := make(chan struct{})
+ go func() {
+ defer close(listenerDone)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Errorf("Accept: %v", err)
+ return
+ }
+ defer c.Close()
+ // Read the CONNECT request
+ br := bufio.NewReader(c)
+ cr, err := ReadRequest(br)
+ if err != nil {
+ t.Errorf("proxy server failed to read CONNECT request")
+ return
+ }
+ if cr.Method != "CONNECT" {
+ t.Errorf("unexpected method %q", cr.Method)
+ return
+ }
+
+ // Now hang and never write a response; instead, cancel the request and wait
+ // for the client to close.
+ // (Prior to Issue 28012 being fixed, we never closed.)
+ cancel()
+ var buf [1]byte
+ _, err = br.Read(buf[:])
+ if err != io.EOF {
+ t.Errorf("proxy server Read err = %v; want EOF", err)
+ }
+ return
+ }()
+
+ c := &Client{
+ Transport: &Transport{
+ Proxy: func(*Request) (*url.URL, error) {
+ return url.Parse("http://" + ln.Addr().String())
+ },
+ },
+ }
+ req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Errorf("unexpected Get success")
+ }
+
+ // Wait unconditionally for the listener goroutine to exit: this should never
+ // hang, so if it does we want a full goroutine dump — and that's exactly what
+ // the testing package will give us when the test run times out.
+ <-listenerDone
+}
+
+// Issue 16997: test transport dial preserves typed errors
+func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
+ defer afterTest(t)
+
+ var errDial = errors.New("some dial error")
+
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) {
+ return url.Parse("http://proxy.fake.tld/")
+ },
+ Dial: func(string, string) (net.Conn, error) {
+ return nil, errDial
+ },
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+ req, _ := NewRequest("GET", "http://fake.tld", nil)
+ res, err := c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("wanted a non-nil error")
+ }
+
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %T, want *url.Error", err)
+ }
+ oe, ok := uerr.Err.(*net.OpError)
+ if !ok {
+ t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
+ }
+ want := &net.OpError{
+ Op: "proxyconnect",
+ Net: "tcp",
+ Err: errDial, // original error, unwrapped.
+ }
+ if !reflect.DeepEqual(oe, want) {
+ t.Errorf("Got error %#v; want %#v", oe, want)
+ }
+}
+
+// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
+//
+// (A bug caused dialConn to instead write the per-request Proxy-Authorization
+// header through to the shared Header instance, introducing a data race.)
+func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
+ run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
+}
+func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
+ proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
+ defer proxy.Close()
+ c := proxy.Client()
+
+ tr := c.Transport.(*Transport)
+ tr.Proxy = func(*Request) (*url.URL, error) {
+ u, _ := url.Parse(proxy.URL)
+ u.User = url.UserPassword("aladdin", "opensesame")
+ return u, nil
+ }
+ h := tr.ProxyConnectHeader
+ if h == nil {
+ h = make(Header)
+ }
+ tr.ProxyConnectHeader = h.Clone()
+
+ req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Errorf("unexpected Get success")
+ }
+
+ if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
+ t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
+ }
+}
+
+// TestTransportGzipRecursive sends a gzip quine and checks that the
+// client gets the same value back. This is more cute than anything,
+// but checks that we don't recurse forever, and checks that
+// Content-Encoding is removed.
+func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
+func testTransportGzipRecursive(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write(rgz)
+ })).ts
+
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, rgz) {
+ t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
+ body, rgz)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+}
+
+// golang.org/issue/7750: request fails when server replies with
+// a short gzip body
+func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
+func testTransportGzipShort(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write([]byte{0x1f, 0x8b})
+ })).ts
+
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ _, err = io.ReadAll(res.Body)
+ if err == nil {
+ t.Fatal("Expect an error from reading a body.")
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+// waitNumGoroutine waits until the number of goroutines is no greater than nmax
+// (or it times out) and returns the final count.
+func waitNumGoroutine(nmax int) int {
+ nfinal := runtime.NumGoroutine()
+ for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
+ time.Sleep(50 * time.Millisecond)
+ runtime.GC()
+ nfinal = runtime.NumGoroutine()
+ }
+ return nfinal
+}
+
+// Tests that the goroutines behind persistent connections shut down when no longer needed.
+func TestTransportPersistConnLeak(t *testing.T) {
+ run(t, testTransportPersistConnLeak, testNotParallel)
+}
+func testTransportPersistConnLeak(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("flaky in HTTP/2")
+ }
+ // Not parallel: counts goroutines
+
+ const numReq = 25
+ gotReqCh := make(chan bool, numReq)
+ unblockCh := make(chan bool, numReq)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReqCh <- true
+ <-unblockCh
+ w.Header().Set("Content-Length", "0")
+ w.WriteHeader(204)
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ n0 := runtime.NumGoroutine()
+
+ didReqCh := make(chan bool, numReq)
+ failed := make(chan bool, numReq)
+ for i := 0; i < numReq; i++ {
+ go func() {
+ res, err := c.Get(ts.URL)
+ didReqCh <- true
+ if err != nil {
+ t.Logf("client fetch error: %v", err)
+ failed <- true
+ return
+ }
+ res.Body.Close()
+ }()
+ }
+
+ // Wait for all goroutines to be stuck in the Handler.
+ for i := 0; i < numReq; i++ {
+ select {
+ case <-gotReqCh:
+ // ok
+ case <-failed:
+ // Not great but not what we are testing:
+ // sometimes an overloaded system will fail to make all the connections.
+ }
+ }
+
+ nhigh := runtime.NumGoroutine()
+
+ // Tell all handlers to unblock and reply.
+ close(unblockCh)
+
+ // Wait for all HTTP clients to be done.
+ for i := 0; i < numReq; i++ {
+ <-didReqCh
+ }
+
+ tr.CloseIdleConnections()
+ nfinal := waitNumGoroutine(n0 + 5)
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ if int(growth) > 5 {
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ t.Error("too many new goroutines")
+ }
+}
+
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ run(t, testTransportPersistConnLeakShortBody, testNotParallel)
+}
+func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("flaky in HTTP/2")
+ }
+
+ // Not parallel: measures goroutines.
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ nfinal := waitNumGoroutine(n0 + 5)
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
+// A countedConn is a net.Conn whose finalizer decrements its dialer's count of live connections.
+type countedConn struct {
+ net.Conn
+}
+
+// A countingDialer dials connections and counts the number that remain reachable.
+type countingDialer struct {
+ dialer net.Dialer
+ mu sync.Mutex
+ total, live int64
+}
+
+func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ conn, err := d.dialer.DialContext(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ counted := new(countedConn)
+ counted.Conn = conn
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.total++
+ d.live++
+
+ runtime.SetFinalizer(counted, d.decrement)
+ return counted, nil
+}
+
+func (d *countingDialer) decrement(*countedConn) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.live--
+}
+
+func (d *countingDialer) Read() (total, live int64) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.total, d.live
+}
+
+func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
+ run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
+}
+func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Close every connection so that it cannot be kept alive.
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack failed unexpectedly: %v", err)
+ return
+ }
+ conn.Close()
+ })).ts
+
+ var d countingDialer
+ c := ts.Client()
+ c.Transport.(*Transport).DialContext = d.DialContext
+
+ body := []byte("Hello")
+ for i := 0; ; i++ {
+ total, live := d.Read()
+ if live < total {
+ break
+ }
+ if i >= 1<<12 {
+ t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
+ }
+
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("expected broken connection")
+ }
+
+ runtime.GC()
+ }
+}
+
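+// A countedContext is a context.Context tracked by a contextCounter so that
+// its reachability can be observed via a finalizer.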
+type countedContext struct {
+ context.Context
+}
+
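+// A contextCounter counts how many tracked contexts have not yet been finalized.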
+type contextCounter struct {
+ mu sync.Mutex
+ live int64
+}
+
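+// Track wraps ctx in a countedContext and registers a finalizer that
+// decrements the live count when the wrapper becomes unreachable.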
+func (cc *contextCounter) Track(ctx context.Context) context.Context {
+ counted := new(countedContext)
+ counted.Context = ctx
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.live++
+ runtime.SetFinalizer(counted, cc.decrement)
+ return counted
+}
+
+func (cc *contextCounter) decrement(*countedContext) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.live--
+}
+
+func (cc *contextCounter) Read() (live int64) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.live
+}
+
+func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
+ run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
+}
+func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56021")
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ runtime.Gosched()
+ w.WriteHeader(StatusOK)
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).MaxConnsPerHost = 1
+
+ ctx := context.Background()
+ body := []byte("Hello")
+ doPosts := func(cc *contextCounter) {
+ var wg sync.WaitGroup
+ for n := 64; n > 0; n-- {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ ctx := cc.Track(ctx)
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Error(err)
+ }
+
+ _, err = c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Errorf("Do failed with error: %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+ }
+
+ var initialCC contextCounter
+ doPosts(&initialCC)
+
+ // flushCC exists only to put pressure on the GC to finalize the initialCC
+ // contexts: the flushCC allocations should eventually displace the initialCC
+ // allocations.
+ var flushCC contextCounter
+ for i := 0; ; i++ {
+ live := initialCC.Read()
+ if live == 0 {
+ break
+ }
+ if i >= 100 {
+ t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
+ }
+ doPosts(&flushCC)
+ runtime.GC()
+ }
+}
+
+// This used to crash; https://golang.org/issue/3266
+func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
+func testTransportIdleConnCrash(t *testing.T, mode testMode) {
+ var tr *Transport
+
+ unblockCh := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockCh
+ tr.CloseIdleConnections()
+ })).ts
+ c := ts.Client()
+ tr = c.Transport.(*Transport)
+
+ didreq := make(chan bool)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ } else {
+ res.Body.Close() // returns idle conn
+ }
+ didreq <- true
+ }()
+ unblockCh <- true
+ <-didreq
+}
+
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
+func testIssue3644(t *testing.T, mode testMode) {
+ const numFoos = 5000
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ })).ts
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) { run(t, testIssue3595) }
+func testIssue3595(t *testing.T, mode testMode) {
+ const deniedMsg = "sorry, denied."
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ })).ts
+ c := ts.Client()
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From https://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
+func testChunkedNoContent(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ })).ts
+
+ c := ts.Client()
+ for _, closeBody := range []bool{true, false} {
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
+}
+func testTransportConcurrency(t *testing.T, mode testMode) {
+ // Not parallel: uses global test hooks.
+ maxProcs, numReqs := 16, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ })).ts
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ // Due to the Transport's "socket late binding" (see
+ // idleConnCh in transport.go), the numReqs HTTP requests
+ // below can finish with a dial still outstanding. To keep
+ // the leak checker happy, keep track of pending dials and
+ // wait for them to finish (and be closed or returned to the
+ // idle pool) before we close idle connections.
+ SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
+ defer SetPendingDialHooks(nil, nil)
+
+ c := ts.Client()
+ reqs := make(chan string)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
+ // https://go.dev/issue/52168: this test was observed to fail with
+ // ECONNRESET errors in Dial on various netbsd builders.
+ t.Logf("error on req %s: %v", req, err)
+ t.Logf("(see https://go.dev/issue/52168)")
+ } else {
+ t.Errorf("error on req %s: %v", req, err)
+ }
+ wg.Done()
+ continue
+ }
+ all, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ } else if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ res.Body.Close()
+ wg.Done()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
+func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ connc := make(chan net.Conn, 1)
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ select {
+ case connc <- conn:
+ default:
+ }
+ return conn, nil
+ }
+
+ res, err := c.Get(ts.URL + "/get")
+ if err != nil {
+ t.Fatalf("Error issuing GET: %v", err)
+ }
+ defer res.Body.Close()
+
+ conn := <-connc
+ conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
+ _, err = io.Copy(io.Discard, res.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ }
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
+}
+func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(io.Discard, r.Body)
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+ timeout := 100 * time.Millisecond
+
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := c.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = c.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
+func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+
+ timeout := 2 * time.Millisecond
+ retry := true
+ for retry && !t.Failed() {
+ var srvWG sync.WaitGroup
+ inHandler := make(chan bool, 1)
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ srvWG.Done()
+ })
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ <-r.Context().Done()
+ srvWG.Done()
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).ResponseHeaderTimeout = timeout
+
+ retry = false
+ srvWG.Add(3)
+ tests := []struct {
+ path string
+ wantTimeout bool
+ }{
+ {path: "/fast"},
+ {path: "/slow", wantTimeout: true},
+ {path: "/fast"},
+ }
+ for i, tt := range tests {
+ req, _ := NewRequest("GET", ts.URL+tt.path, nil)
+ req = req.WithT(t)
+ res, err := c.Do(req)
+ <-inHandler
+ if err != nil {
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("error is not a url.Error; got: %#v", err)
+ continue
+ }
+ nerr, ok := uerr.Err.(net.Error)
+ if !ok {
+ t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
+ continue
+ }
+ if !nerr.Timeout() {
+ t.Errorf("want timeout error; got: %q", nerr)
+ continue
+ }
+ if !tt.wantTimeout {
+ if !retry {
+ // The timeout may be set too short. Retry with a longer one.
+ t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
+ timeout *= 2
+ retry = true
+ }
+ }
+ if !strings.Contains(err.Error(), "timeout awaiting response headers") {
+ t.Errorf("%d. unexpected error: %v", i, err)
+ }
+ continue
+ }
+ if tt.wantTimeout {
+ t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
+ continue
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
+ }
+ }
+
+ srvWG.Wait()
+ ts.Close()
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ run(t, testTransportCancelRequest, []testMode{http1Mode})
+}
+func testTransportCancelRequest(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+
+ const msg = "Hello"
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, msg)
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := make([]byte, len(msg))
+ n, _ := io.ReadFull(res.Body, body)
+ if n != len(body) || !bytes.Equal(body, []byte(msg)) {
+ t.Errorf("Body = %q; want %q", body[:n], msg)
+ }
+ tr.CancelRequest(req)
+
+ tail, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != ExportErrRequestCanceled {
+ t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
+ } else if len(tail) > 0 {
+ t.Errorf("Spurious bytes from Body.Read: %q", tail)
+ }
+
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ n := tr.NumPendingRequestsForTesting()
+ if n > 0 {
+ if d > 0 {
+ t.Logf("pending requests = %d after %v (want 0)", n, d)
+ }
+ return false
+ }
+ return true
+ })
+}
+
+func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ donec := make(chan bool)
+ req, _ := NewRequest("GET", ts.URL, body)
+ go func() {
+ defer close(donec)
+ c.Do(req)
+ }()
+
+ unblockc <- true
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ tr.CancelRequest(req)
+ select {
+ case <-donec:
+ return true
+ default:
+ if d > 0 {
+ t.Logf("Do of canceled request has not returned after %v", d)
+ }
+ return false
+ }
+ })
+}
+
+func TestTransportCancelRequestInDo(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportCancelRequestInDo(t, mode, nil)
+ }, []testMode{http1Mode})
+}
+
+func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
+ }, []testMode{http1Mode})
+}
+
+func TestTransportCancelRequestInDial(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ var logbuf strings.Builder
+ eventLog := log.New(&logbuf, "", 0)
+
+ unblockDial := make(chan bool)
+ defer close(unblockDial)
+
+ inDial := make(chan bool)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ eventLog.Println("dial: blocking")
+ if !<-inDial {
+ return nil, errors.New("main Test goroutine exited")
+ }
+ <-unblockDial
+ return nil, errors.New("nope")
+ },
+ }
+ cl := &Client{Transport: tr}
+ gotres := make(chan bool)
+ req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ go func() {
+ _, err := cl.Do(req)
+ eventLog.Printf("Get = %v", err)
+ gotres <- true
+ }()
+
+ inDial <- true
+
+ eventLog.Printf("canceling")
+ tr.CancelRequest(req)
+ tr.CancelRequest(req) // used to panic on second call
+
+ if d, ok := t.Deadline(); ok {
+ // When the test's deadline is about to expire, log the pending events for
+ // better debugging.
+ timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
+ timer := time.AfterFunc(timeout, func() {
+ panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
+ })
+ defer timer.Stop()
+ }
+ <-gotres
+
+ got := logbuf.String()
+ want := `dial: blocking
+canceling
+Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+`
+ if got != want {
+ t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
+func testCancelRequestWithChannel(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+
+ const msg = "Hello"
+ unblockc := make(chan struct{})
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, msg)
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := make([]byte, len(msg))
+ n, _ := io.ReadFull(res.Body, body)
+ if n != len(body) || !bytes.Equal(body, []byte(msg)) {
+ t.Errorf("Body = %q; want %q", body[:n], msg)
+ }
+ close(cancel)
+
+ tail, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != ExportErrRequestCanceled {
+ t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
+ } else if len(tail) > 0 {
+ t.Errorf("Spurious bytes from Body.Read: %q", tail)
+ }
+
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
+ n := tr.NumPendingRequestsForTesting()
+ if n > 0 {
+ if d > 0 {
+ t.Logf("pending requests = %d after %v (want 0)", n, d)
+ }
+ return false
+ }
+ return true
+ })
+}
+
+func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testCancelRequestWithChannelBeforeDo(t, mode, false)
+ })
+}
+func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testCancelRequestWithChannelBeforeDo(t, mode, true)
+ })
+}
+func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ if withCtx {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ req = req.WithContext(ctx)
+ } else {
+ ch := make(chan struct{})
+ req.Cancel = ch
+ close(ch)
+ }
+
+ _, err := c.Do(req)
+ if ue, ok := err.(*url.Error); ok {
+ err = ue.Err
+ }
+ if withCtx {
+ if err != context.Canceled {
+ t.Errorf("Do error = %v; want %v", err, context.Canceled)
+ }
+ } else {
+ if err == nil || !strings.Contains(err.Error(), "canceled") {
+ t.Errorf("Do error = %v; want cancellation", err)
+ }
+ }
+}
+
+// Issue 11020. The returned error should be errRequestCanceled.
+func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+ defer afterTest(t)
+
+ serverConnCh := make(chan net.Conn, 1)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ cc, sc := net.Pipe()
+ serverConnCh <- sc
+ return cc, nil
+ },
+ }
+ defer tr.CloseIdleConnections()
+ errc := make(chan error, 1)
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ go func() {
+ _, err := tr.RoundTrip(req)
+ errc <- err
+ }()
+
+ sc := <-serverConnCh
+ verb := make([]byte, 3)
+ if _, err := io.ReadFull(sc, verb); err != nil {
+ t.Errorf("Error reading HTTP verb from server: %v", err)
+ }
+ if string(verb) != "GET" {
+ t.Errorf("server received %q; want GET", verb)
+ }
+ defer sc.Close()
+
+ tr.CancelRequest(req)
+
+ err := <-errc
+ if err == nil {
+ t.Fatalf("unexpected success from RoundTrip")
+ }
+ if err != ExportErrRequestCanceled {
+ t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
+ }
+}
+
+// golang.org/issue/3672 -- Client can't close HTTP stream
+// Calling Close on a Response.Body used to just read until EOF.
+// Now it actually closes the TCP connection.
+func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
+func testTransportCloseResponseBody(t *testing.T, mode testMode) {
+ writeErr := make(chan error, 1)
+ msg := []byte("young\n")
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ for {
+ _, err := w.Write(msg)
+ if err != nil {
+ writeErr <- err
+ return
+ }
+ w.(Flusher).Flush()
+ }
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ defer tr.CancelRequest(req)
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ const repeats = 3
+ buf := make([]byte, len(msg)*repeats)
+ want := bytes.Repeat(msg, repeats)
+
+ _, err = io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("read %q; want %q", buf, want)
+ }
+
+ if err := res.Body.Close(); err != nil {
+ t.Errorf("Close = %v", err)
+ }
+
+ if err := <-writeErr; err == nil {
+ t.Errorf("expected non-nil write error")
+ }
+}
+
+type fooProto struct{}
+
+func (fooProto) RoundTrip(req *Request) (*Response, error) {
+ res := &Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Header: make(Header),
+ Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+ }
+ return res, nil
+}
+
+func TestTransportAltProto(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ tr.RegisterProtocol("foo", fooProto{})
+ res, err := c.Get("foo://bar.com/path")
+ if err != nil {
+ t.Fatal(err)
+ }
+ bodyb, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := string(bodyb)
+ if e := "You wanted foo://bar.com/path"; body != e {
+ t.Errorf("got response %q, want %q", body, e)
+ }
+}
+
+func TestTransportNoHost(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+// Issue 13311
+func TestTransportEmptyMethod(t *testing.T) {
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Method = "" // docs say "For client requests an empty string means GET"
+ got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.Contains(string(got), "GET ") {
+ t.Fatalf("expected substring 'GET '; got: %s", got)
+ }
+}
+
+func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
+func testTransportSocketLateBinding(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ fooGate := make(chan bool, 1)
+ mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo-ipport", r.RemoteAddr)
+ w.(Flusher).Flush()
+ <-fooGate
+ })
+ mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
+ w.Header().Set("bar-ipport", r.RemoteAddr)
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ dialGate := make(chan bool, 1)
+ dialing := make(chan bool)
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ for {
+ select {
+ case ok := <-dialGate:
+ if !ok {
+ return nil, errors.New("manually closed")
+ }
+ return net.Dial(n, addr)
+ case dialing <- true:
+ }
+ }
+ }
+ defer close(dialGate)
+
+ dialGate <- true // only allow one dial
+ fooRes, err := c.Get(ts.URL + "/foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+ fooAddr := fooRes.Header.Get("foo-ipport")
+ if fooAddr == "" {
+ t.Fatal("No addr on /foo request")
+ }
+
+ fooDone := make(chan struct{})
+ go func() {
+ // We know that the foo Dial completed and reached the handler because we
+ // read its header. Wait for the bar request to block in Dial, then
+ // let the foo response finish so we can use its connection for /bar.
+
+ if mode == http2Mode {
+ // In HTTP/2 mode, the second Dial won't happen because the protocol
+ // multiplexes the streams by default. Just sleep for an arbitrary time;
+ // the test should pass regardless of how far the bar request gets by this
+ // point.
+ select {
+ case <-dialing:
+ t.Errorf("unexpected second Dial in HTTP/2 mode")
+ case <-time.After(10 * time.Millisecond):
+ }
+ } else {
+ <-dialing
+ }
+ fooGate <- true
+ io.Copy(io.Discard, fooRes.Body)
+ fooRes.Body.Close()
+ close(fooDone)
+ }()
+ defer func() {
+ <-fooDone
+ }()
+
+ barRes, err := c.Get(ts.URL + "/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ barAddr := barRes.Header.Get("bar-ipport")
+ if barAddr != fooAddr {
+ t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
+ }
+ barRes.Body.Close()
+}
+
+// Issue 2184
+func TestTransportReading100Continue(t *testing.T) {
+ defer afterTest(t)
+
+ const numReqs = 5
+ reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
+ reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
+
+ send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
+ defer w.Close()
+ defer r.Close()
+ br := bufio.NewReader(r)
+ n := 0
+ for {
+ n++
+ req, err := ReadRequest(br)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ slurp, err := io.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("Server request body slurp: %v", err)
+ return
+ }
+ id := req.Header.Get("Request-Id")
+ resCode := req.Header.Get("X-Want-Response-Code")
+ if resCode == "" {
+ resCode = "100 Continue"
+ if string(slurp) != reqBody(n) {
+ t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
+ }
+ }
+ body := fmt.Sprintf("Response number %d", n)
+ v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
+Date: Thu, 28 Feb 2013 17:55:41 GMT
+
+HTTP/1.1 200 OK
+Content-Type: text/html
+Echo-Request-Id: %s
+Content-Length: %d
+
+%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
+ w.Write(v)
+ if id == reqID(numReqs) {
+ return
+ }
+ }
+
+ }
+
+ tr := &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ sr, sw := io.Pipe() // server read/write
+ cr, cw := io.Pipe() // client read/write
+ conn := &rwTestConn{
+ Reader: cr,
+ Writer: sw,
+ closeFunc: func() error {
+ sw.Close()
+ cw.Close()
+ return nil
+ },
+ }
+ go send100Response(cw, sr)
+ return conn, nil
+ },
+ DisableKeepAlives: false,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ testResponse := func(req *Request, name string, wantCode int) {
+ t.Helper()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("%s: Do: %v", name, err)
+ }
+ if res.StatusCode != wantCode {
+ t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
+ }
+ if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
+ t.Errorf("%s: response id %q != request id %q", name, idBack, id)
+ }
+ _, err = io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("%s: Slurp error: %v", name, err)
+ }
+ }
+
+ // A few requests that each get a "100 Continue" response, making sure we're not off-by-one.
+ for i := 1; i <= numReqs; i++ {
+ req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
+ req.Header.Set("Request-Id", reqID(i))
+ testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
+ }
+}
+
+// Issue 17739: the HTTP client must ignore any unknown 1xx
+// informational responses before the actual response.
+func TestTransportIgnore1xxResponses(t *testing.T) {
+ run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
+}
+func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
+ buf.Flush()
+ conn.Close()
+ }))
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ var got strings.Builder
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
+ return nil
+ },
+ }))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ res.Write(&got)
+ want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
+ if got.String() != want {
+ t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
+ }
+}
+
+func TestTransportLimits1xxResponses(t *testing.T) {
+ run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
+}
+func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ for i := 0; i < 10; i++ {
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
+ }
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ got := fmt.Sprint(err)
+ wantSub := "too many 1xx informational responses"
+ if !strings.Contains(got, wantSub) {
+ t.Errorf("Get error = %v; want substring %q", err, wantSub)
+ }
+}
+
+// Issue 26161: the HTTP client must treat 101 responses
+// as the final response.
+func TestTransportTreat101Terminal(t *testing.T) {
+ run(t, testTransportTreat101Terminal, []testMode{http1Mode})
+}
+func testTransportTreat101Terminal(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusSwitchingProtocols {
+ t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
+
+ env string // HTTP_PROXY
+ httpsenv string // HTTPS_PROXY
+ noenv string // NO_PROXY
+ reqmeth string // REQUEST_METHOD
+
+ want string
+ wanterr error
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf strings.Builder
+ space := func() {
+ if buf.Len() > 0 {
+ buf.WriteByte(' ')
+ }
+ }
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.httpsenv != "" {
+ space()
+ fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
+ }
+ if t.noenv != "" {
+ space()
+ fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
+ }
+ if t.reqmeth != "" {
+ space()
+ fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ space()
+ fmt.Fprintf(&buf, "req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+ {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
+
+ // Don't use secure for http
+ {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
+ // Use secure for https.
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
+
+ // Issue 16405: don't use HTTP_PROXY in a CGI environment,
+ // where HTTP_PROXY can be attacker-controlled.
+ {env: "http://10.1.2.3:8080", reqmeth: "POST",
+ want: "<nil>",
+ wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
+
+ {want: "<nil>"},
+
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+}
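+
+// For reference (a sketch, not exercised by these tests): application code
+// typically opts into these environment variables by wiring the same function
+// into a Transport, which is what DefaultTransport does:
+//
+//	tr := &Transport{Proxy: ProxyFromEnvironment}
+//	client := &Client{Transport: tr}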
+
+func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
+ t.Helper()
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
+ url, err := proxyForRequest(req)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
+ return
+ }
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
+ }
+}
+
+func TestProxyFromEnvironment(t *testing.T) {
+ ResetProxyEnv()
+ defer ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("HTTP_PROXY", tt.env)
+ os.Setenv("HTTPS_PROXY", tt.httpsenv)
+ os.Setenv("NO_PROXY", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
+ }
+}
+
+func TestProxyFromEnvironmentLowerCase(t *testing.T) {
+ ResetProxyEnv()
+ defer ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("http_proxy", tt.env)
+ os.Setenv("https_proxy", tt.httpsenv)
+ os.Setenv("no_proxy", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
+ }
+}
+
+func TestIdleConnChannelLeak(t *testing.T) {
+ run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
+}
+func testIdleConnChannelLeak(t *testing.T, mode testMode) {
+ // Not parallel: uses global test hooks.
+ var mu sync.Mutex
+ var n int
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ n++
+ mu.Unlock()
+ })).ts
+
+ const nReqs = 5
+ didRead := make(chan bool, nReqs)
+ SetReadLoopBeforeNextReadHook(func() { didRead <- true })
+ defer SetReadLoopBeforeNextReadHook(nil)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+
+ // First, without keep-alives.
+ for _, disableKeep := range []bool{true, false} {
+ tr.DisableKeepAlives = disableKeep
+ for i := 0; i < nReqs; i++ {
+ _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Note: no res.Body.Close is needed here, since the
+ // response Content-Length is zero. Perhaps the test
+ // should be more explicit and use a HEAD, but tests
+ // elsewhere guarantee that zero byte responses generate
+ // a "Content-Length: 0" instead of chunking.
+ }
+
+ // At this point, each of the 5 Transport.readLoop goroutines is
+ // noting that there is no response body (see the earlier comment)
+ // and then calling putIdleConn, which decrements this count.
+ // Usually that happens quickly, which is why this test has seemed
+ // to work for ages. But it's still racy: we have to wait for them
+ // to finish first. See Issue 10427.
+ for i := 0; i < nReqs; i++ {
+ <-didRead
+ }
+
+ if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
+ t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
+ }
+ }
+}
+
+// Verify the status quo: that the Client.Post function coerces its
+// body into a ReadCloser if it's a Closer, and that the Transport
+// then closes it.
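+//
+// For example (an illustrative sketch, not exercised by this test; the file
+// name is hypothetical):
+//
+//	f, _ := os.Open("payload.bin") // *os.File is already an io.ReadCloser
+//	res, _ := client.Post(url, "application/octet-stream", f)
+//	// The Transport closes f once the request body has been sent.
+//	// Pass io.NopCloser(f) instead to keep f open for the caller.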
+func TestTransportClosesRequestBody(t *testing.T) {
+ run(t, testTransportClosesRequestBody, []testMode{http1Mode})
+}
+func testTransportClosesRequestBody(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ })).ts
+
+ c := ts.Client()
+
+ closes := 0
+
+ res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+}
+
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ testdonec := make(chan struct{})
+ defer close(testdonec)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-testdonec
+ c.Close()
+ }()
+
+ tr := &Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ln.Addr().String())
+ },
+ TLSHandshakeTimeout: 250 * time.Millisecond,
+ }
+ cl := &Client{Transport: tr}
+ _, err := cl.Get("https://dummy.tld/")
+ if err == nil {
+ t.Error("expected error")
+ return
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("expected url.Error; got %#v", err)
+ return
+ }
+ ne, ok := ue.Err.(net.Error)
+ if !ok {
+ t.Errorf("expected net.Error; got %#v", err)
+ return
+ }
+ if !ne.Timeout() {
+ t.Errorf("expected timeout error; got %v", err)
+ }
+ if !strings.Contains(err.Error(), "handshake timeout") {
+ t.Errorf("expected 'handshake timeout' in error; got %v", err)
+ }
+}
+
+// Trying to repro golang.org/issue/3514
+func TestTLSServerClosesConnection(t *testing.T) {
+ run(t, testTLSServerClosesConnection, []testMode{https1Mode})
+}
+func testTLSServerClosesConnection(t *testing.T, mode testMode) {
+ closedc := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
+ conn.Close()
+ closedc <- true
+ return
+ }
+ fmt.Fprintf(w, "hello")
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ var nSuccess = 0
+ var errs []error
+ const trials = 20
+ for i := 0; i < trials; i++ {
+ tr.CloseIdleConnections()
+ res, err := c.Get(ts.URL + "/keep-alive-then-die")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-closedc
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Got %q, want foo", slurp)
+ }
+
+ // Now try again and see if we successfully
+ // pick a new connection.
+ res, err = c.Get(ts.URL + "/")
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ slurp, err = io.ReadAll(res.Body)
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ nSuccess++
+ }
+ if nSuccess > 0 {
+ t.Logf("successes = %d of %d", nSuccess, trials)
+ } else {
+ t.Errorf("All runs failed:")
+ }
+ for _, err := range errs {
+ t.Logf(" err: %v", err)
+ }
+}
+
+// byteFromChanReader is an io.Reader that reads a single byte at a
+// time from the channel. When the channel is closed, the reader
+// returns io.EOF.
+type byteFromChanReader chan byte
+
+func (c byteFromChanReader) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ b, ok := <-c
+ if !ok {
+ return 0, io.EOF
+ }
+ p[0] = b
+ return 1, nil
+}
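+
+// A minimal usage sketch (illustrative only):
+//
+//	ch := make(byteFromChanReader, 1)
+//	ch <- 'x'  // the next Read returns the single byte 'x'
+//	close(ch)  // subsequent Reads return io.EOF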
+
+// Verifies that the Transport doesn't reuse a connection in the case
+// where the server replies before the request has been fully
+// written. We still honor that reply (see TestIssue3595), but don't
+// send future requests on the connection because it's then in a
+// questionable state.
+// golang.org/issue/7569
+func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+ run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
+}
+func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
+ defer func(d time.Duration) {
+ *MaxWriteWaitBeforeConnReuse = d
+ }(*MaxWriteWaitBeforeConnReuse)
+ *MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
+ var sconn struct {
+ sync.Mutex
+ c net.Conn
+ }
+ var getOkay bool
+ closeConn := func() {
+ sconn.Lock()
+ defer sconn.Unlock()
+ if sconn.c != nil {
+ sconn.c.Close()
+ sconn.c = nil
+ if !getOkay {
+ t.Logf("Closed server connection")
+ }
+ }
+ }
+ defer closeConn()
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ io.WriteString(w, "bar")
+ return
+ }
+ conn, _, _ := w.(Hijacker).Hijack()
+ sconn.Lock()
+ sconn.c = conn
+ sconn.Unlock()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
+ go io.Copy(io.Discard, conn)
+ })).ts
+ c := ts.Client()
+
+ const bodySize = 256 << 10
+ finalBit := make(byteFromChanReader, 1)
+ req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
+ req.ContentLength = bodySize
+ res, err := c.Do(req)
+ if err := wantBody(res, err, "foo"); err != nil {
+ t.Errorf("POST response: %v", err)
+ }
+
+ res, err = c.Get(ts.URL)
+ if err := wantBody(res, err, "bar"); err != nil {
+ t.Errorf("GET response: %v", err)
+ return
+ }
+ getOkay = true // suppress test noise
+ finalBit <- 'x' // unblock the writeloop of the first Post
+ close(finalBit)
+}
+
+// Tests that we don't leak Transport persistConn.readLoop goroutines
+// when a server hangs up immediately after promising to keep the connection alive.
+func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
+func testTransportIssue10457(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Send a response with no body, keep-alive
+ // (implicit), and then lie and immediately close the
+ // connection. This forces the Transport's readLoop to
+ // immediately Peek an io.EOF and get to the point
+ // that used to hang.
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
+ conn.Close()
+ })).ts
+ c := ts.Client()
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ // Just a sanity check that we at least get the response. The real
+ // test here is that the leak check run after the test doesn't find
+ // any leaked goroutines.
+ if got, want := res.Header.Get("Foo"), "Bar"; got != want {
+ t.Errorf("Foo header = %q; want %q", got, want)
+ }
+}
+
+type closerFunc func() error
+
+func (f closerFunc) Close() error { return f() }
+
+type writerFuncConn struct {
+ net.Conn
+ write func(p []byte) (n int, err error)
+}
+
+func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
+
+// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
+// server is in the process of closing, we may end up successfully writing out
+// our request (or a portion of our request) only to find a connection error
+// when we try to read from (or finish writing to) the socket.
+//
+// NOTE: we resend a request only if:
+// - we reused a keep-alive connection
+// - we haven't yet received any header data
+// - either we wrote no bytes to the server, or the request is idempotent
+//
+// This automatically prevents an infinite resend loop because we'll run out of
+// the cached keep-alive connections eventually.
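+//
+// As a rough sketch only (hypothetical names, not the Transport's actual
+// code), the conditions above boil down to:
+//
+//	func shouldResend(reusedConn, gotHeader, idempotent bool, bytesWritten int64) bool {
+//		return reusedConn && !gotHeader && (bytesWritten == 0 || idempotent)
+//	}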
+func TestRetryRequestsOnError(t *testing.T) {
+ run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
+}
+func testRetryRequestsOnError(t *testing.T, mode testMode) {
+ newRequest := func(method, urlStr string, body io.Reader) *Request {
+ req, err := NewRequest(method, urlStr, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return req
+ }
+
+ testCases := []struct {
+ name string
+ failureN int
+ failureErr error
+ // Note that we can't just re-use the Request object across calls to c.Do
+ // because we need to rewind Body between calls. (GetBody is only used to
+ // rewind Body on failure and redirects, not simply because the body has
+ // been consumed.)
+ req func() *Request
+ reqString string
+ }{
+ {
+ name: "IdempotentNoBodySomeWritten",
+ // Believe that we've written some bytes to the server, so we know we're
+ // not just in the "retry when no bytes sent" case.
+ failureN: 1,
+ // Use the specific error that shouldRetryRequest looks for with idempotent requests.
+ failureErr: ExportErrServerClosedIdle,
+ req: func() *Request {
+ return newRequest("GET", "http://fake.golang", nil)
+ },
+ reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
+ },
+ {
+ name: "IdempotentGetBodySomeWritten",
+ // Believe that we've written some bytes to the server, so we know we're
+ // not just in the "retry when no bytes sent" case".
+ failureN: 1,
+ // Use the specific error that shouldRetryRequest looks for with idempotent requests.
+ failureErr: ExportErrServerClosedIdle,
+ req: func() *Request {
+ return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
+ },
+ reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
+ },
+ {
+ name: "NothingWrittenNoBody",
+ // It's key that we return 0 here -- that's what enables Transport to know
+ // that nothing was written, even though this is a non-idempotent request.
+ failureN: 0,
+ failureErr: errors.New("second write fails"),
+ req: func() *Request {
+ return newRequest("DELETE", "http://fake.golang", nil)
+ },
+ reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
+ },
+ {
+ name: "NothingWrittenGetBody",
+ // It's key that we return 0 here -- that's what enables Transport to know
+ // that nothing was written, even though this is a non-idempotent request.
+ failureN: 0,
+ failureErr: errors.New("second write fails"),
+ // Note that NewRequest will set up GetBody for strings.Reader, which is
+ // required for the retry to occur
+ req: func() *Request {
+ return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
+ },
+ reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var (
+ mu sync.Mutex
+ logbuf strings.Builder
+ )
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&logbuf, format, args...)
+ logbuf.WriteByte('\n')
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ logf("Handler")
+ w.Header().Set("X-Status", "ok")
+ })).ts
+
+ var writeNumAtomic int32
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
+ logf("Dial")
+ c, err := net.Dial(network, ts.Listener.Addr().String())
+ if err != nil {
+ logf("Dial error: %v", err)
+ return nil, err
+ }
+ return &writerFuncConn{
+ Conn: c,
+ write: func(p []byte) (n int, err error) {
+ if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
+ logf("intentional write failure")
+ return tc.failureN, tc.failureErr
+ }
+ logf("Write(%q)", p)
+ return c.Write(p)
+ },
+ }, nil
+ }
+
+ SetRoundTripRetried(func() {
+ logf("Retried.")
+ })
+ defer SetRoundTripRetried(nil)
+
+ for i := 0; i < 3; i++ {
+ t0 := time.Now()
+ req := tc.req()
+ res, err := c.Do(req)
+ if err != nil {
+ if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
+ mu.Lock()
+ got := logbuf.String()
+ mu.Unlock()
+ t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
+ }
+ t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
+ }
+ res.Body.Close()
+ if res.Request != req {
+ t.Errorf("Response.Request != original request; want identical Request")
+ }
+ }
+
+ mu.Lock()
+ got := logbuf.String()
+ mu.Unlock()
+ want := fmt.Sprintf(`Dial
+Write("%s")
+Handler
+intentional write failure
+Retried.
+Dial
+Write("%s")
+Handler
+Write("%s")
+Handler
+`, tc.reqString, tc.reqString, tc.reqString)
+ if got != want {
+ t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
+ }
+ })
+ }
+}
+
+// Issue 6981
+func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
+func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
+ readBody := make(chan error, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := io.ReadAll(r.Body)
+ readBody <- err
+ })).ts
+ c := ts.Client()
+ fakeErr := errors.New("fake error")
+ didClose := make(chan bool, 1)
+ req, _ := NewRequest("POST", ts.URL, struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
+ closerFunc(func() error {
+ select {
+ case didClose <- true:
+ default:
+ }
+ return nil
+ }),
+ })
+ res, err := c.Do(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
+ t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
+ }
+ if err := <-readBody; err == nil {
+ t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
+ }
+ select {
+ case <-didClose:
+ default:
+ t.Errorf("didn't see Body.Close")
+ }
+}
+
+func TestTransportDialTLS(t *testing.T) {
+ run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
+}
+func testTransportDialTLS(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq, didDial bool
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ didDial = true
+ mu.Unlock()
+ c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.Handshake()
+ }
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if !didDial {
+ t.Error("didn't use dial hook")
+ }
+}
+
+func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
+func testTransportDialContext(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq bool
+ var receivedContext context.Context
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ receivedContext = ctx
+ mu.Unlock()
+ return net.Dial(netw, addr)
+ }
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ res, err := c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if receivedContext != ctx {
+ t.Error("didn't receive correct context")
+ }
+}
+
+func TestTransportDialTLSContext(t *testing.T) {
+ run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
+}
+func testTransportDialTLSContext(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq bool
+ var receivedContext context.Context
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ receivedContext = ctx
+ mu.Unlock()
+ c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.HandshakeContext(ctx)
+ }
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ res, err := c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if receivedContext != ctx {
+ t.Error("didn't receive correct context")
+ }
+}
+
+// Test for issue 8755
+// Ensure that if a proxy returns an error, it is exposed by RoundTrip
+func TestRoundTripReturnsProxyError(t *testing.T) {
+ badProxy := func(*Request) (*url.URL, error) {
+ return nil, errors.New("errorMessage")
+ }
+
+ tr := &Transport{Proxy: badProxy}
+
+ req, _ := NewRequest("GET", "http://example.com", nil)
+
+ _, err := tr.RoundTrip(req)
+
+ if err == nil {
+ t.Error("Expected proxy error to be returned by RoundTrip")
+ }
+}
+
+// Tests that an idle conn offered right after a call to CloseIdleConnections
+// is rejected, and is accepted again once the Transport is waiting for conns.
+func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
+ tr := &Transport{}
+ wantIdle := func(when string, n int) bool {
+ got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
+ if got == n {
+ return true
+ }
+ t.Errorf("%s: idle conns = %d; want %d", when, got, n)
+ return false
+ }
+ wantIdle("start", 0)
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("put failed")
+ }
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("second put failed")
+ }
+ wantIdle("after put", 2)
+ tr.CloseIdleConnections()
+ if !tr.IsIdleForTesting() {
+ t.Error("should be idle after CloseIdleConnections")
+ }
+ wantIdle("after close idle", 0)
+ if tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("put didn't fail")
+ }
+ wantIdle("after second put", 0)
+
+ tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
+ if tr.IsIdleForTesting() {
+ t.Error("shouldn't be idle after QueueForIdleConnForTesting")
+ }
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("after re-activation")
+ }
+ wantIdle("after final put", 1)
+}
+
+// Test for issue 34282
+// Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn
+func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
+ tr := &Transport{}
+ wantIdle := func(when string, n int) bool {
+ got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
+ if got == n {
+ return true
+ }
+ t.Errorf("%s: idle conns = %d; want %d", when, got, n)
+ return false
+ }
+ wantIdle("start", 0)
+ alt := funcRoundTripper(func() {})
+ if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
+ t.Fatal("put failed")
+ }
+ wantIdle("after put", 1)
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(httptrace.GotConnInfo) {
+ // tr.getConn should leave it for the HTTP/2 alt to call GotConn.
+ t.Error("GotConn called")
+ },
+ })
+ req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
+ _, err := tr.RoundTrip(req)
+ if err != errFakeRoundTrip {
+ t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
+ }
+ wantIdle("after round trip", 1)
+}
+
+func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
+ run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
+}
+func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ timeout := 1 * time.Millisecond
+ retry := true
+ for retry {
+ trFunc := func(tr *Transport) {
+ tr.MaxConnsPerHost = 1
+ tr.MaxIdleConnsPerHost = 1
+ tr.IdleConnTimeout = timeout
+ }
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
+
+ retry = false
+ tooShort := func(err error) bool {
+ if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
+ return false
+ }
+ if !retry {
+ t.Helper()
+ t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
+ timeout *= 2
+ retry = true
+ cst.close()
+ }
+ return true
+ }
+
+ if _, err := cst.c.Get(cst.ts.URL); err != nil {
+ if tooShort(err) {
+ continue
+ }
+ t.Fatalf("got error: %s", err)
+ }
+
+ time.Sleep(10 * timeout)
+ if _, err := cst.c.Get(cst.ts.URL); err != nil {
+ if tooShort(err) {
+ continue
+ }
+ t.Fatalf("got error: %s", err)
+ }
+ }
+}
+
+// This tests that a client requesting a content range won't also
+// implicitly ask for gzip support. A caller that wants gzip as well
+// must request it explicitly.
+// golang.org/issue/8923
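+//
+// A sketch of the explicit opt-in (illustrative only; once Accept-Encoding is
+// set by the caller, the Transport no longer decompresses the response):
+//
+//	req.Header.Set("Range", "bytes=7-11")
+//	req.Header.Set("Accept-Encoding", "gzip")
+//	// After the response arrives, check Content-Encoding and wrap
+//	// res.Body with gzip.NewReader if needed.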
+func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
+func testTransportRangeAndGzip(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
+ t.Error("Transport advertised gzip support in the Accept header")
+ }
+ if r.Header.Get("Range") == "" {
+ t.Error("no Range in request")
+ }
+ })).ts
+ c := ts.Client()
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("Range", "bytes=7-11")
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+// Test for issue 10474
+func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
+func testTransportResponseCancelRace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // important that this response has a body.
+ var b [1024]byte
+ w.Write(b[:])
+ })).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // If we do an early close, Transport just throws the connection away and
+ // doesn't reuse it. In order to trigger the bug, it has to reuse the
+ // connection, so read the body to completion first.
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+
+ req2, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tr.CancelRequest(req)
+ res, err = tr.RoundTrip(req2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+// Test for issue 19248: Content-Encoding's value is case insensitive.
+func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
+ run(t, testTransportContentEncodingCaseInsensitive)
+}
+func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
+ for _, ce := range []string{"gzip", "GZIP"} {
+ ce := ce
+ t.Run(ce, func(t *testing.T) {
+ const encodedString = "Hello Gopher"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", ce)
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte(encodedString))
+ gz.Close()
+ })).ts
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ body, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if string(body) != encodedString {
+ t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
+ }
+ })
+ }
+}
+
+func TestTransportDialCancelRace(t *testing.T) {
+ run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
+}
+func testTransportDialCancelRace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ SetEnterRoundTripHook(func() {
+ tr.CancelRequest(req)
+ })
+ defer SetEnterRoundTripHook(nil)
+ res, err := tr.RoundTrip(req)
+ if err != ExportErrRequestCanceled {
+ t.Errorf("expected canceled request error; got %v", err)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+}
+
+// https://go.dev/issue/49621
+func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
+ run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
+}
+func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
+ func(tr *Transport) {
+ tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
+ // Connection immediately returns errors.
+ return &funcConn{
+ read: func([]byte) (int, error) {
+ return 0, errors.New("error")
+ },
+ write: func([]byte) (int, error) {
+ return 0, errors.New("error")
+ },
+ }, nil
+ }
+ },
+ ).ts
+ // Set a short delay in RoundTrip to give the persistConn time to notice
+ // the connection is broken. We want to exercise the path where writeLoop exits
+ // before it reads the request to send. If this delay is too short, we may instead
+ // exercise the path where writeLoop accepts the request and then fails to write it.
+ // That's fine, so long as we get the desired path often enough.
+ SetEnterRoundTripHook(func() {
+ time.Sleep(1 * time.Millisecond)
+ })
+ defer SetEnterRoundTripHook(nil)
+ var closes int
+ _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err == nil {
+ t.Fatalf("expected request to fail, but it did not")
+ }
+ if closes != 1 {
+ t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
+ }
+}
+
+// logWritesConn is a net.Conn that logs each Write call to writes
+// and then proxies to w.
+// It proxies Read calls to a reader it receives from rch.
+type logWritesConn struct {
+ net.Conn // nil. crash on use.
+
+ w io.Writer
+
+ rch <-chan io.Reader
+ r io.Reader // nil until received by rch
+
+ mu sync.Mutex
+ writes []string
+}
+
+func (c *logWritesConn) Write(p []byte) (n int, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.writes = append(c.writes, string(p))
+ return c.w.Write(p)
+}
+
+func (c *logWritesConn) Read(p []byte) (n int, err error) {
+ if c.r == nil {
+ c.r = <-c.rch
+ }
+ return c.r.Read(p)
+}
+
+func (c *logWritesConn) Close() error { return nil }
+
+// Issue 6574
+func TestTransportFlushesBodyChunks(t *testing.T) {
+ defer afterTest(t)
+ resBody := make(chan io.Reader, 1)
+ connr, connw := io.Pipe() // connection pipe pair
+ lw := &logWritesConn{
+ rch: resBody,
+ w: connw,
+ }
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ return lw, nil
+ },
+ }
+ bodyr, bodyw := io.Pipe() // body pipe pair
+ go func() {
+ defer bodyw.Close()
+ for i := 0; i < 3; i++ {
+ fmt.Fprintf(bodyw, "num%d\n", i)
+ }
+ }()
+ resc := make(chan *Response)
+ go func() {
+ req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
+ req.Header.Set("User-Agent", "x") // known value for test
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("RoundTrip: %v", err)
+ close(resc)
+ return
+ }
+ resc <- res
+ }()
+ // Fully consume the request before checking the Write log vs. want.
+ req, err := ReadRequest(bufio.NewReader(connr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(io.Discard, req.Body)
+
+ // Unblock the transport's roundTrip goroutine.
+ resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
+ res, ok := <-resc
+ if !ok {
+ return
+ }
+ defer res.Body.Close()
+
+ want := []string{
+ "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
+ "5\r\nnum0\n\r\n",
+ "5\r\nnum1\n\r\n",
+ "5\r\nnum2\n\r\n",
+ "0\r\n\r\n",
+ }
+ if !reflect.DeepEqual(lw.writes, want) {
+ t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
+ }
+}
+
+// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
+func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
+func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
+ gotReq := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(gotReq)
+ }))
+
+ pr, pw := io.Pipe()
+ req, err := NewRequest("POST", cst.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotRes := make(chan struct{})
+ go func() {
+ defer close(gotRes)
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+ }()
+
+ <-gotReq
+ pw.Close()
+ <-gotRes
+}
+
+type wgReadCloser struct {
+ io.Reader
+ wg *sync.WaitGroup
+ closed bool
+}
+
+func (c *wgReadCloser) Close() error {
+ if c.closed {
+ return net.ErrClosed
+ }
+ c.closed = true
+ c.wg.Done()
+ return nil
+}
+
+// Issue 11745.
+func TestTransportPrefersResponseOverWriteError(t *testing.T) {
+ run(t, testTransportPrefersResponseOverWriteError)
+}
+func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ const contentLengthLimit = 1024 * 1024 // 1MB
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.ContentLength >= contentLengthLimit {
+ w.WriteHeader(StatusBadRequest)
+ r.Body.Close()
+ return
+ }
+ w.WriteHeader(StatusOK)
+ })).ts
+ c := ts.Client()
+
+ fail := 0
+ count := 100
+
+ bigBody := strings.Repeat("a", contentLengthLimit*2)
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ getBody := func() (io.ReadCloser, error) {
+ wg.Add(1)
+ body := &wgReadCloser{
+ Reader: strings.NewReader(bigBody),
+ wg: &wg,
+ }
+ return body, nil
+ }
+
+ for i := 0; i < count; i++ {
+ reqBody, _ := getBody()
+ req, err := NewRequest("PUT", ts.URL, reqBody)
+ if err != nil {
+ reqBody.Close()
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(bigBody))
+ req.GetBody = getBody
+
+ resp, err := c.Do(req)
+ if err != nil {
+ fail++
+ t.Logf("%d = %#v", i, err)
+ if ue, ok := err.(*url.Error); ok {
+ t.Logf("urlErr = %#v", ue.Err)
+ if ne, ok := ue.Err.(*net.OpError); ok {
+ t.Logf("netOpError = %#v", ne.Err)
+ }
+ }
+ } else {
+ resp.Body.Close()
+ if resp.StatusCode != 400 {
+ t.Errorf("Expected status code 400, got %v", resp.Status)
+ }
+ }
+ }
+ if fail > 0 {
+ t.Errorf("Failed %v out of %v\n", fail, count)
+ }
+}
+
+func TestTransportAutomaticHTTP2(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{}, true)
+}
+
+func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ ForceAttemptHTTP2: true,
+ TLSClientConfig: new(tls.Config),
+ }, true)
+}
+
+// golang.org/issue/14391: also check DefaultTransport
+func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
+ testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
+}
+
+func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ TLSClientConfig: new(tls.Config),
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ ExpectContinueTimeout: 1 * time.Second,
+ }, true)
+}
+
+func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
+ var d net.Dialer
+ testTransportAutoHTTP(t, &Transport{
+ Dial: d.Dial,
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
+ var d net.Dialer
+ testTransportAutoHTTP(t, &Transport{
+ DialContext: d.DialContext,
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ DialTLS: func(network, addr string) (net.Conn, error) {
+ panic("unused")
+ },
+ }, false)
+}
+
+func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
+ CondSkipHTTP2(t)
+ _, err := tr.RoundTrip(new(Request))
+ if err == nil {
+ t.Error("expected error from RoundTrip")
+ }
+ if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
+ t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
+ }
+}
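+
+// Illustrative sketch (not part of the original tests): the cases above show
+// that a custom Dial/DialContext/DialTLS or TLSClientConfig disables automatic
+// HTTP/2, while ForceAttemptHTTP2 opts back in despite a custom dialer.
+func transportWithForcedHTTP2() *Transport {
+	var d net.Dialer
+	return &Transport{
+		DialContext:       d.DialContext, // would normally disable automatic HTTP/2
+		ForceAttemptHTTP2: true,          // re-enable it anyway
+	}
+}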
+
+// Issue 13633: there was a race where we returned bodyless responses
+// to callers before recycling the persistent connection, which meant
+// a client doing two subsequent requests could end up on different
+// connections. It's somewhat harmless, but enough tests assume it
+// doesn't happen (in order to test other things) that it's worth
+// fixing. Plus it's nice to be consistent and not have
+// timing-dependent behavior.
+func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
+ run(t, testTransportReuseConnEmptyResponseBody)
+}
+func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ // Empty response body.
+ }))
+ n := 100
+ if testing.Short() {
+ n = 10
+ }
+ var firstAddr string
+ for i := 0; i < n; i++ {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+			t.Fatal(err)
+ }
+ addr := res.Header.Get("X-Addr")
+ if i == 0 {
+ firstAddr = addr
+ } else if addr != firstAddr {
+ t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
+ }
+ res.Body.Close()
+ }
+}
+
+// Issue 13839
+func TestNoCrashReturningTransportAltConn(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ var wg sync.WaitGroup
+ SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
+ defer SetPendingDialHooks(nil, nil)
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ <-testDone
+ sc.Close()
+ }()
+
+ addr := ln.Addr().String()
+
+ req, _ := NewRequest("GET", "https://fake.tld/", nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ doReturned := make(chan bool, 1)
+ madeRoundTripper := make(chan bool, 1)
+
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ madeRoundTripper <- true
+ return funcRoundTripper(func() {
+ t.Error("foo RoundTripper should not be called")
+ })
+ },
+ },
+ Dial: func(_, _ string) (net.Conn, error) {
+ panic("shouldn't be called")
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ close(cancel)
+ <-doReturned
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+
+ _, err = c.Do(req)
+ if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
+ t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
+ }
+
+ doReturned <- true
+ <-madeRoundTripper
+ wg.Wait()
+}
+
+func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportReuseConnection_Gzip(t, mode, true)
+ })
+}
+
+func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportReuseConnection_Gzip(t, mode, false)
+ })
+}
+
+// Make sure we re-use underlying TCP connection for gzipped responses too.
+func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
+ addr := make(chan string, 2)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ addr <- r.RemoteAddr
+ w.Header().Set("Content-Encoding", "gzip")
+ if chunked {
+ w.(Flusher).Flush()
+ }
+ w.Write(rgz) // arbitrary gzip response
+ })).ts
+ c := ts.Client()
+
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
+ GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
+ PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
+ ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
+ ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
+ }
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ for i := 0; i < 2; i++ {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(ctx)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, len(rgz))
+ if n, err := io.ReadFull(res.Body, buf); err != nil {
+ t.Errorf("%d. ReadFull = %v, %v", i, n, err)
+ }
+ // Note: no res.Body.Close call. It should work without it,
+ // since the flate.Reader's internal buffering will hit EOF
+ // and that should be sufficient.
+ }
+ a1, a2 := <-addr, <-addr
+ if a1 != a2 {
+ t.Fatalf("didn't reuse connection")
+ }
+}
+
+func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
+func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/long" {
+ w.Header().Set("Long", strings.Repeat("a", 1<<20))
+ }
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
+
+ if res, err := c.Get(ts.URL); err != nil {
+ t.Fatal(err)
+ } else {
+ res.Body.Close()
+ }
+
+ res, err := c.Get(ts.URL + "/long")
+ if err == nil {
+ defer res.Body.Close()
+ var n int64
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ n += int64(len(k)) + int64(len(v))
+ }
+ }
+ t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
+ }
+ if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
+ t.Errorf("got error: %v; want %q", err, want)
+ }
+}
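+
+// Illustrative sketch (not part of the original tests): how a caller might cap
+// response header size on its own HTTP/1 Transport; the 512 KiB limit mirrors
+// the test above and is otherwise arbitrary.
+func clientWithHeaderLimit() *Client {
+	return &Client{
+		Transport: &Transport{
+			MaxResponseHeaderBytes: 512 << 10, // fail responses whose headers exceed 512 KiB
+		},
+	}
+}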
+
+func TestTransportEventTrace(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportEventTrace(t, mode, false)
+ }, testNotParallel)
+}
+
+// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
+func TestTransportEventTrace_NoHooks(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportEventTrace(t, mode, true)
+ }, testNotParallel)
+}
+
+func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
+ const resBody = "some body"
+ gotWroteReqEvent := make(chan struct{}, 500)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ // Do nothing for the second request.
+ return
+ }
+ if _, err := io.ReadAll(r.Body); err != nil {
+ t.Error(err)
+ }
+ if !noHooks {
+ <-gotWroteReqEvent
+ }
+ io.WriteString(w, resBody)
+ }), func(tr *Transport) {
+ if tr.TLSClientConfig != nil {
+ tr.TLSClientConfig.InsecureSkipVerify = true
+ }
+ })
+ defer cst.close()
+
+ cst.tr.ExpectContinueTimeout = 1 * time.Second
+
+ var mu sync.Mutex // guards buf
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ addrStr := cst.ts.Listener.Addr().String()
+ ip, port, err := net.SplitHostPort(addrStr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Install a fake DNS server.
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
+ if host != "dns-is-faked.golang" {
+ t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
+ return nil, nil
+ }
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ body := "some body"
+ req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
+ req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
+ GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
+ GotFirstResponseByte: func() { logf("first response byte") },
+ PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) },
+ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
+ DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
+ ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
+ ConnectDone: func(network, addr string, err error) {
+ if err != nil {
+ t.Errorf("ConnectDone: %v", err)
+ }
+ logf("ConnectDone: connected to %s %s = %v", network, addr, err)
+ },
+ WroteHeaderField: func(key string, value []string) {
+ logf("WroteHeaderField: %s: %v", key, value)
+ },
+ WroteHeaders: func() {
+ logf("WroteHeaders")
+ },
+ Wait100Continue: func() { logf("Wait100Continue") },
+ Got100Continue: func() { logf("Got100Continue") },
+ WroteRequest: func(e httptrace.WroteRequestInfo) {
+ logf("WroteRequest: %+v", e)
+ gotWroteReqEvent <- struct{}{}
+ },
+ }
+ if mode == http2Mode {
+ trace.TLSHandshakeStart = func() { logf("tls handshake start") }
+ trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
+ logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
+ }
+ }
+ if noHooks {
+ // zero out all func pointers, trying to get some path to crash
+ *trace = httptrace.ClientTrace{}
+ }
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ logf("got roundtrip.response")
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ logf("consumed body")
+ if string(slurp) != resBody || res.StatusCode != 200 {
+ t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
+ }
+ res.Body.Close()
+
+ if noHooks {
+		// Done at this point. Just testing that a full HTTP
+		// request can happen with a trace pointing to a zero
+		// ClientTrace, full of nil func pointers.
+ return
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+ wantOnceOrMore := func(sub string) {
+ if strings.Count(got, sub) == 0 {
+ t.Errorf("expected substring %q at least once in output.", sub)
+ }
+ }
+ wantOnce("Getting conn for dns-is-faked.golang:" + port)
+ wantOnce("DNS start: {Host:dns-is-faked.golang}")
+ wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
+ wantOnce("got conn: {")
+ wantOnceOrMore("Connecting to tcp " + addrStr)
+ wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
+ wantOnce("Reused:false WasIdle:false IdleTime:0s")
+ wantOnce("first response byte")
+ if mode == http2Mode {
+ wantOnce("tls handshake start")
+ wantOnce("tls handshake done")
+ } else {
+ wantOnce("PutIdleConn = <nil>")
+ wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
+ // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
+ // WroteHeaderField hook is not yet implemented in h2.)
+ wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
+ wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
+ wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
+ wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
+ }
+ wantOnce("WroteHeaders")
+ wantOnce("Wait100Continue")
+ wantOnce("Got100Continue")
+ wantOnce("WroteRequest: {Err:<nil>}")
+ if strings.Contains(got, " to udp ") {
+ t.Errorf("should not see UDP (DNS) connections")
+ }
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+
+ // And do a second request:
+ req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+ res, err = cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ res.Body.Close()
+
+ mu.Lock()
+ got = buf.String()
+ mu.Unlock()
+
+ sub := "Getting conn for dns-is-faked.golang:"
+ if gotn, want := strings.Count(got, sub), 2; gotn != want {
+ t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
+ }
+}
+
+func TestTransportEventTraceTLSVerify(t *testing.T) {
+ run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
+}
+func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
+ var mu sync.Mutex
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Error("Unexpected request")
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
+ logf("%s", p)
+ return len(p), nil
+ }), "", 0)
+ }).ts
+
+ certpool := x509.NewCertPool()
+ certpool.AddCert(ts.Certificate())
+
+ c := &Client{Transport: &Transport{
+ TLSClientConfig: &tls.Config{
+ ServerName: "dns-is-faked.golang",
+ RootCAs: certpool,
+ },
+ }}
+
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
+ },
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+ _, err := c.Do(req)
+ if err == nil {
+ t.Error("Expected request to fail TLS verification")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+
+ wantOnce("TLSHandshakeStart")
+ wantOnce("TLSHandshakeDone")
+ wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
+
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+}
+
+var (
+ isDNSHijackedOnce sync.Once
+ isDNSHijacked bool
+)
+
+func skipIfDNSHijacked(t *testing.T) {
+ // Skip this test if the user is using a shady/ISP
+ // DNS server hijacking queries.
+ // See issues 16732, 16716.
+ isDNSHijackedOnce.Do(func() {
+ addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
+ isDNSHijacked = len(addrs) != 0
+ })
+ if isDNSHijacked {
+ t.Skip("skipping; test requires non-hijacking DNS server")
+ }
+}
+
+func TestTransportEventTraceRealDNS(t *testing.T) {
+ skipIfDNSHijacked(t)
+ defer afterTest(t)
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ var mu sync.Mutex // guards buf
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
+ trace := &httptrace.ClientTrace{
+ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
+ DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
+ ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
+ ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+
+ resp, err := c.Do(req)
+ if err == nil {
+ resp.Body.Close()
+ t.Fatal("expected error during DNS lookup")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantSub := func(sub string) {
+ if !strings.Contains(got, sub) {
+ t.Errorf("expected substring %q in output.", sub)
+ }
+ }
+ wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
+ wantSub("DNSDone: {Addrs:[] Err:")
+ if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
+ t.Errorf("should not see Connect events")
+ }
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+}
+
+// Issue 14353: port can only contain digits.
+func TestTransportRejectsAlphaPort(t *testing.T) {
+ res, err := Get("http://dummy.tld:123foo/bar")
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("unexpected success")
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %#v; want *url.Error", err)
+ }
+ got := ue.Err.Error()
+ want := `invalid port ":123foo" after host`
+ if got != want {
+ t.Errorf("got error %q; want %q", got, want)
+ }
+}
+
+// Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1
+// connection. The http2 test is done in TestTransportEventTrace_h2
+func TestTLSHandshakeTrace(t *testing.T) {
+ run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
+}
+func testTLSHandshakeTrace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+
+ var mu sync.Mutex
+ var start, done bool
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() {
+ mu.Lock()
+ defer mu.Unlock()
+ start = true
+ },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ mu.Lock()
+ defer mu.Unlock()
+ done = true
+ if err != nil {
+ t.Fatal("Expected error to be nil but was:", err)
+ }
+ },
+ }
+
+ c := ts.Client()
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal("Unable to construct test request:", err)
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ r, err := c.Do(req)
+ if err != nil {
+ t.Fatal("Unexpected error making request:", err)
+ }
+ r.Body.Close()
+ mu.Lock()
+ defer mu.Unlock()
+ if !start {
+ t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
+ }
+ if !done {
+ t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
+ }
+}
+
+func TestTransportMaxIdleConns(t *testing.T) {
+ run(t, testTransportMaxIdleConns, []testMode{http1Mode})
+}
+func testTransportMaxIdleConns(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // No body for convenience.
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxIdleConns = 4
+
+ ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ hitHost := func(n int) {
+ req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
+ req = req.WithContext(ctx)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ }
+ for i := 0; i < 4; i++ {
+ hitHost(i)
+ }
+ want := []string{
+ "|http|host-0.dns-is-faked.golang:" + port,
+ "|http|host-1.dns-is-faked.golang:" + port,
+ "|http|host-2.dns-is-faked.golang:" + port,
+ "|http|host-3.dns-is-faked.golang:" + port,
+ }
+ if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
+ }
+
+ // Now hitting the 5th host should kick out the first host:
+ hitHost(4)
+ want = []string{
+ "|http|host-1.dns-is-faked.golang:" + port,
+ "|http|host-2.dns-is-faked.golang:" + port,
+ "|http|host-3.dns-is-faked.golang:" + port,
+ "|http|host-4.dns-is-faked.golang:" + port,
+ }
+ if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
+ }
+}
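+
+// Illustrative sketch (not part of the original tests): the idle-pool knobs
+// exercised above and below, set on a standalone Transport. The values are
+// examples only.
+func transportWithIdlePoolLimits() *Transport {
+	return &Transport{
+		MaxIdleConns:        4,                // total idle conns across all hosts (as in the test above)
+		MaxIdleConnsPerHost: 2,                // per-host idle limit
+		IdleConnTimeout:     90 * time.Second, // drop idle conns after this long
+	}
+}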
+
+func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
+func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ timeout := 1 * time.Millisecond
+timeoutLoop:
+ for {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // No body for convenience.
+ }))
+ tr := cst.tr
+ tr.IdleConnTimeout = timeout
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ idleConns := func() []string {
+ if mode == http2Mode {
+ return tr.IdleConnStrsForTesting_h2()
+ } else {
+ return tr.IdleConnStrsForTesting()
+ }
+ }
+
+ var conn string
+ doReq := func(n int) (timeoutOk bool) {
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ PutIdleConn: func(err error) {
+ if err != nil {
+ t.Errorf("failed to keep idle conn: %v", err)
+ }
+ },
+ }))
+			res, err := c.Do(req)
+			if err != nil {
+				if strings.Contains(err.Error(), "use of closed network connection") {
+					t.Logf("req %v: connection closed prematurely", n)
+					return false
+				}
+				t.Fatalf("req %v: %v", n, err)
+			}
+ res.Body.Close()
+ conns := idleConns()
+ if len(conns) != 1 {
+ if len(conns) == 0 {
+ t.Logf("req %v: no idle conns", n)
+ return false
+ }
+ t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
+ }
+ if conn == "" {
+ conn = conns[0]
+ }
+ if conn != conns[0] {
+ t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
+ return false
+ }
+ return true
+ }
+ for i := 0; i < 3; i++ {
+ if !doReq(i) {
+ t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
+ timeout *= 2
+ cst.close()
+ continue timeoutLoop
+ }
+ time.Sleep(timeout / 2)
+ }
+
+ waitCondition(t, timeout/2, func(d time.Duration) bool {
+ if got := idleConns(); len(got) != 0 {
+ if d >= timeout*3/2 {
+ t.Logf("after %v, idle conns = %q", d, got)
+ }
+ return false
+ }
+ return true
+ })
+ break
+ }
+}
+
+// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
+// HTTP/2 connection was established but its caller no longer
+// wanted it. (Assuming the connection cache was enabled, which it is
+// by default)
+//
+// This test reproduced the crash by setting the IdleConnTimeout low
+// (to make the test reasonable) and then making a request which is
+// canceled by the DialTLS hook, which then also waits to return the
+// real connection until after the RoundTrip saw the error. Then we
+// know the successful tls.Dial from DialTLS will need to go into the
+// idle pool. Then we give it a bit of time to explode.
+func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
+func testIdleConnH2Crash(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // nothing
+ }))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ sawDoErr := make(chan bool, 1)
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ cst.tr.IdleConnTimeout = 5 * time.Millisecond
+ cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
+ c, err := tls.Dial(network, addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"h2"},
+ })
+ if err != nil {
+ t.Error(err)
+ return nil, err
+ }
+ if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
+ t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
+ c.Close()
+ return nil, errors.New("bogus")
+ }
+
+ cancel()
+
+ select {
+ case <-sawDoErr:
+ case <-testDone:
+ }
+ return c, nil
+ }
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(ctx)
+ res, err := cst.c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("unexpected success")
+ }
+ sawDoErr <- true
+
+ // Wait for the explosion.
+ time.Sleep(cst.tr.IdleConnTimeout * 10)
+}
+
+type funcConn struct {
+ net.Conn
+ read func([]byte) (int, error)
+ write func([]byte) (int, error)
+}
+
+func (c funcConn) Read(p []byte) (int, error) { return c.read(p) }
+func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
+func (c funcConn) Close() error { return nil }
+
+// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
+// back to the caller.
+func TestTransportReturnsPeekError(t *testing.T) {
+ errValue := errors.New("specific error value")
+
+ wrote := make(chan struct{})
+ var wroteOnce sync.Once
+
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ c := funcConn{
+ read: func([]byte) (int, error) {
+ <-wrote
+ return 0, errValue
+ },
+ write: func(p []byte) (int, error) {
+ wroteOnce.Do(func() { close(wrote) })
+ return len(p), nil
+ },
+ }
+ return c, nil
+ },
+ }
+ _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
+ if err != errValue {
+ t.Errorf("error = %#v; want %v", err, errValue)
+ }
+}
+
+// Issue 13835: international domain names should work
+func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
+func testTransportIDNA(t *testing.T, mode testMode) {
+ const uniDomain = "гофер.го"
+ const punyDomain = "xn--c1ae0ajs.xn--c1aw"
+
+ var port string
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ want := punyDomain + ":" + port
+ if r.Host != want {
+ t.Errorf("Host header = %q; want %q", r.Host, want)
+ }
+ if mode == http2Mode {
+ if r.TLS == nil {
+ t.Errorf("r.TLS == nil")
+ } else if r.TLS.ServerName != punyDomain {
+ t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
+ }
+ }
+ w.Header().Set("Hit-Handler", "1")
+ }), func(tr *Transport) {
+ if tr.TLSClientConfig != nil {
+ tr.TLSClientConfig.InsecureSkipVerify = true
+ }
+ })
+
+ ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Install a fake DNS server.
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
+ if host != punyDomain {
+ t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
+ return nil, nil
+ }
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ want := net.JoinHostPort(punyDomain, port)
+ if hostPort != want {
+ t.Errorf("getting conn for %q; want %q", hostPort, want)
+ }
+ },
+ DNSStart: func(e httptrace.DNSStartInfo) {
+ if e.Host != punyDomain {
+ t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.Header.Get("Hit-Handler") != "1" {
+ out, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
+ }
+}
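+
+// Illustrative sketch (not part of the original tests): outside the Transport,
+// the same Unicode-to-ASCII conversion can be done with golang.org/x/net/idna
+// (assumed here as an external dependency), e.g.:
+//
+//	ascii, err := idna.ToASCII(uniDomain) // expected to yield punyDomain per the constants above
+//
+// The test above relies on the Transport doing an equivalent conversion before
+// DNS lookup and dialing.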
+
+// Issue 13290: send User-Agent in proxy CONNECT
+func TestTransportProxyConnectHeader(t *testing.T) {
+ run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
+}
+func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
+ reqc := make(chan *Request, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+
+ r := <-reqc
+ if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+}
+
+func TestTransportProxyGetConnectHeader(t *testing.T) {
+ run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
+}
+func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
+ reqc := make(chan *Request, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ // These should be ignored:
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+ c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+ return Header{
+ "User-Agent": {"foo2"},
+ "Other": {"bar2"},
+ }, nil
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+
+ r := <-reqc
+ if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar2"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+}
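+
+// Illustrative sketch (not part of the original tests): GetProxyConnectHeader
+// can compute per-CONNECT headers dynamically, e.g. proxy credentials; the
+// token below is a placeholder.
+func transportWithProxyAuth(proxyURL *url.URL) *Transport {
+	return &Transport{
+		Proxy: ProxyURL(proxyURL),
+		GetProxyConnectHeader: func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+			h := Header{}
+			h.Set("Proxy-Authorization", "Bearer placeholder-token") // hypothetical credential
+			return h, nil
+		},
+	}
+}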
+
+var errFakeRoundTrip = errors.New("fake roundtrip")
+
+type funcRoundTripper func()
+
+func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
+ fn()
+ return nil, errFakeRoundTrip
+}
+
+func wantBody(res *Response, err error, want string) error {
+ if err != nil {
+ return err
+ }
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("error reading body: %v", err)
+ }
+ if string(slurp) != want {
+ return fmt.Errorf("body = %q; want %q", slurp, want)
+ }
+ if err := res.Body.Close(); err != nil {
+ return fmt.Errorf("body Close = %v", err)
+ }
+ return nil
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+type countCloseReader struct {
+ n *int
+ io.Reader
+}
+
+func (cr countCloseReader) Close() error {
+ (*cr.n)++
+ return nil
+}
+
+// rgz is a gzip quine that uncompresses to itself.
+var rgz = []byte{
+ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
+ 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
+ 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
+ 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
+ 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
+ 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
+ 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
+ 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
+ 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
+ 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
+ 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
+ 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
+ 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
+ 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
+ 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
+ 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
+ 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
+ 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00,
+}
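+
+// Illustrative sketch (not part of the original tests): the quine property
+// claimed above can be checked by decompressing rgz and comparing the output
+// to rgz itself.
+func verifyRgzQuine() error {
+	zr, err := gzip.NewReader(bytes.NewReader(rgz))
+	if err != nil {
+		return err
+	}
+	defer zr.Close()
+	got, err := io.ReadAll(zr)
+	if err != nil {
+		return err
+	}
+	if !bytes.Equal(got, rgz) {
+		return errors.New("rgz did not decompress to itself")
+	}
+	return nil
+}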
+
+// Ensure that a response with a missing status text doesn't make the
+// client panic.
+// See Issue https://golang.org/issues/21701
+func TestMissingStatusNoPanic(t *testing.T) {
+ t.Parallel()
+
+ const want = "unknown status code"
+
+ ln := newLocalListener(t)
+ addr := ln.Addr().String()
+ done := make(chan bool)
+ fullAddrURL := fmt.Sprintf("http://%s", addr)
+ raw := "HTTP/1.1 400\r\n" +
+ "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
+ "Content-Type: text/html; charset=utf-8\r\n" +
+ "Content-Length: 10\r\n" +
+ "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
+ "Vary: Accept-Encoding\r\n\r\n" +
+ "Aloha Olaa"
+
+ go func() {
+ defer close(done)
+
+ conn, _ := ln.Accept()
+ if conn != nil {
+ io.WriteString(conn, raw)
+ io.ReadAll(conn)
+ conn.Close()
+ }
+ }()
+
+ proxyURL, err := url.Parse(fullAddrURL)
+ if err != nil {
+ t.Fatalf("proxyURL: %v", err)
+ }
+
+ tr := &Transport{Proxy: ProxyURL(proxyURL)}
+
+ req, _ := NewRequest("GET", "https://golang.org/", nil)
+ res, err, panicked := doFetchCheckPanic(tr, req)
+ if panicked {
+ t.Error("panicked, expecting an error")
+ }
+ if res != nil && res.Body != nil {
+ io.Copy(io.Discard, res.Body)
+ res.Body.Close()
+ }
+
+ if err == nil || !strings.Contains(err.Error(), want) {
+ t.Errorf("got=%v want=%q", err, want)
+ }
+
+ ln.Close()
+ <-done
+}
+
+func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
+ defer func() {
+ if r := recover(); r != nil {
+ panicked = true
+ }
+ }()
+ res, err = tr.RoundTrip(req)
+ return
+}
+
+// Issue 22330: do not allow the response body to be read when the status code
+// forbids a response body.
+func TestNoBodyOnChunked304Response(t *testing.T) {
+ run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
+}
+func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+
+ // Our test server above is sending back bogus data after the
+ // response (the "0\r\n\r\n" part), which causes the Transport
+ // code to log spam. Disable keep-alives so we never even try
+ // to reuse the connection.
+ cst.tr.DisableKeepAlives = true
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if res.Body != NoBody {
+ t.Errorf("Unexpected body on 304 response")
+ }
+}
+
+type funcWriter func([]byte) (int, error)
+
+func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
+
+type doneContext struct {
+ context.Context
+ err error
+}
+
+func (doneContext) Done() <-chan struct{} {
+ c := make(chan struct{})
+ close(c)
+ return c
+}
+
+func (d doneContext) Err() error { return d.err }
+
+// Issue 25852: Transport should check whether Context is done early.
+func TestTransportCheckContextDoneEarly(t *testing.T) {
+ tr := &Transport{}
+ req, _ := NewRequest("GET", "http://fake.example/", nil)
+ wantErr := errors.New("some error")
+ req = req.WithContext(doneContext{context.Background(), wantErr})
+ _, err := tr.RoundTrip(req)
+ if err != wantErr {
+ t.Errorf("error = %v; want %v", err, wantErr)
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that times out before the server replies with
+// any response headers.
+func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
+ run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
+}
+func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
+ timeout := 1 * time.Millisecond
+ for {
+ inHandler := make(chan bool)
+ cancelHandler := make(chan struct{})
+ handlerDone := make(chan bool)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-r.Context().Done()
+
+ select {
+ case <-cancelHandler:
+ return
+ case inHandler <- true:
+ }
+ defer func() { handlerDone <- true }()
+
+ // Read from the conn until EOF to verify that it was correctly closed.
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ n, err := conn.Read([]byte{0})
+ if n != 0 || err != io.EOF {
+ t.Errorf("unexpected Read result: %v, %v", n, err)
+ }
+ conn.Close()
+ }))
+
+ cst.c.Timeout = timeout
+
+ _, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ close(cancelHandler)
+ t.Fatal("unexpected Get success")
+ }
+
+ tooSlow := time.NewTimer(timeout * 10)
+ select {
+ case <-tooSlow.C:
+ // If we didn't get into the Handler, that probably means the builder was
+ // just slow and the Get failed in that time but never made it to the
+ // server. That's fine; we'll try again with a longer timeout.
+ t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
+ close(cancelHandler)
+ cst.close()
+ timeout *= 2
+ continue
+ case <-inHandler:
+ tooSlow.Stop()
+ <-handlerDone
+ }
+ break
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that has the server send response headers
+// first, and time out during the write of the response body.
+func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
+ run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
+}
+func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
+ inHandler := make(chan bool)
+ cancelHandler := make(chan struct{})
+ handlerDone := make(chan bool)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "100")
+ w.(Flusher).Flush()
+
+ select {
+ case <-cancelHandler:
+ return
+ case inHandler <- true:
+ }
+ defer func() { handlerDone <- true }()
+
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ conn.Write([]byte("foo"))
+
+ n, err := conn.Read([]byte{0})
+ // The error should be io.EOF or "read tcp
+ // 127.0.0.1:35827->127.0.0.1:40290: read: connection
+ // reset by peer" depending on timing. Really we just
+ // care that it returns at all. But if it returns with
+ // data, that's weird.
+ if n != 0 || err == nil {
+ t.Errorf("unexpected Read result: %v, %v", n, err)
+ }
+ conn.Close()
+ }))
+
+ // Set Timeout to something very long but non-zero to exercise
+ // the codepaths that check for it. But rather than wait for it to fire
+ // (which would make the test slow), we send on the req.Cancel channel instead,
+ // which happens to exercise the same code paths.
+ cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancelReq := make(chan struct{})
+ req.Cancel = cancelReq
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ close(cancelHandler)
+ t.Fatalf("Get error: %v", err)
+ }
+
+ // Cancel the request while the handler is still blocked on sending to the
+ // inHandler channel. Then read it until it fails, to verify that the
+ // connection is broken before the handler itself closes it.
+ close(cancelReq)
+ got, err := io.ReadAll(res.Body)
+ if err == nil {
+ t.Errorf("unexpected success; read %q, nil", got)
+ }
+
+ // Now unblock the handler and wait for it to complete.
+ <-inHandler
+ <-handlerDone
+}
+
+func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
+ run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
+}
+func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
+ done := make(chan struct{})
+ defer close(done)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
+ bs := bufio.NewScanner(conn)
+ bs.Scan()
+ fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
+ <-done
+ }))
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Header.Set("Upgrade", "foo")
+ req.Header.Set("Connection", "upgrade")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
+ }
+ defer rwc.Close()
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("expected readable input")
+ }
+ if got, want := bs.Text(), "Some buffered data"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+ io.WriteString(rwc, "echo\n")
+ if !bs.Scan() {
+ t.Fatalf("expected another line")
+ }
+ if got, want := bs.Text(), "ECHO"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+}
+
+func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
+func testTransportCONNECTBidi(t *testing.T, mode testMode) {
+ const target = "backend:443"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("unexpected method %q", r.Method)
+ w.WriteHeader(500)
+ return
+ }
+ if r.RequestURI != target {
+ t.Errorf("unexpected CONNECT target %q", r.RequestURI)
+ w.WriteHeader(500)
+ return
+ }
+ nc, brw, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer nc.Close()
+ nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+		// Switch to a little protocol that capitalizes its input lines:
+ for {
+ line, err := brw.ReadString('\n')
+ if err != nil {
+ if err != io.EOF {
+ t.Error(err)
+ }
+ return
+ }
+ io.WriteString(brw, strings.ToUpper(line))
+ brw.Flush()
+ }
+ }))
+ pr, pw := io.Pipe()
+ defer pw.Close()
+ req, err := NewRequest("CONNECT", cst.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.URL.Opaque = target
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", res.StatusCode)
+ }
+ br := bufio.NewReader(res.Body)
+ for _, str := range []string{"foo", "bar", "baz"} {
+ fmt.Fprintf(pw, "%s\n", str)
+ got, err := br.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = strings.TrimSpace(got)
+ want := strings.ToUpper(str)
+ if got != want {
+ t.Fatalf("got %q; want %q", got, want)
+ }
+ }
+}
+
+func TestTransportRequestReplayable(t *testing.T) {
+ someBody := io.NopCloser(strings.NewReader(""))
+ tests := []struct {
+ name string
+ req *Request
+ want bool
+ }{
+ {
+ name: "GET",
+ req: &Request{Method: "GET"},
+ want: true,
+ },
+ {
+ name: "GET_http.NoBody",
+ req: &Request{Method: "GET", Body: NoBody},
+ want: true,
+ },
+ {
+ name: "GET_body",
+ req: &Request{Method: "GET", Body: someBody},
+ want: false,
+ },
+ {
+ name: "POST",
+ req: &Request{Method: "POST"},
+ want: false,
+ },
+ {
+ name: "POST_idempotency-key",
+ req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
+ want: true,
+ },
+ {
+ name: "POST_x-idempotency-key",
+ req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
+ want: true,
+ },
+ {
+ name: "POST_body",
+ req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.req.ExportIsReplayable()
+ if got != tt.want {
+ t.Errorf("replyable = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
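+
+// Illustrative sketch (not part of the original tests): a caller can make a
+// POST replayable by giving it a rewindable body (NewRequest with a
+// *bytes.Reader sets GetBody) plus an Idempotency-Key header, matching the
+// cases in the table above. The key value is a placeholder.
+func newReplayablePost(u string, payload []byte) (*Request, error) {
+	req, err := NewRequest("POST", u, bytes.NewReader(payload))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Idempotency-Key", "placeholder-key-123") // hypothetical unique key
+	return req, nil
+}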
+
+// testMockTCPConn is a mock TCP connection used to test that
+// ReadFrom is called when sending the request body.
+type testMockTCPConn struct {
+ *net.TCPConn
+
+ ReadFromCalled bool
+}
+
+func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
+ c.ReadFromCalled = true
+ return c.TCPConn.ReadFrom(r)
+}
+
+func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
+func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := os.CreateTemp("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ readerFunc func() (io.Reader, func(), error)
+ contentLength int64
+ expectedReadFrom bool
+ }{
+ {
+ name: "file, length",
+ readerFunc: newFileFunc,
+ contentLength: nBytes,
+ expectedReadFrom: true,
+ },
+ {
+ name: "file, no length",
+ readerFunc: newFileFunc,
+ },
+ {
+ name: "file, negative length",
+ readerFunc: newFileFunc,
+ contentLength: -1,
+ },
+ {
+ name: "buffer",
+ contentLength: nBytes,
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, no length",
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, length -1",
+ contentLength: -1,
+ readerFunc: newBufferFunc,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ r, cleanup, err := tc.readerFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ tConn := &testMockTCPConn{}
+ trFunc := func(tr *Transport) {
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ tcpConn, ok := conn.(*net.TCPConn)
+ if !ok {
+ return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
+ }
+
+ tConn.TCPConn = tcpConn
+ return tConn, nil
+ }
+ }
+
+ cst := newClientServerTest(
+ t,
+ mode,
+ HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ r.Body.Close()
+ w.WriteHeader(200)
+ }),
+ trFunc,
+ )
+
+ req, err := NewRequest("PUT", cst.ts.URL, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = tc.contentLength
+ req.Header.Set("Content-Type", "application/octet-stream")
+ resp, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", resp.StatusCode)
+ }
+
+ expectedReadFrom := tc.expectedReadFrom
+ if mode != http1Mode {
+ expectedReadFrom = false
+ }
+ if !tConn.ReadFromCalled && expectedReadFrom {
+ t.Fatalf("did not call ReadFrom")
+ }
+
+ if tConn.ReadFromCalled && !expectedReadFrom {
+ t.Fatalf("ReadFrom was unexpectedly invoked")
+ }
+ })
+ }
+}
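+
+// Illustrative sketch (not part of the original tests): the caller-side shape
+// that lets the HTTP/1 Transport take the ReadFrom path exercised above: an
+// *os.File request body with a known ContentLength. The URL and path are
+// placeholders.
+func putFile(c *Client, u, path string) error {
+	f, err := os.Open(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+	fi, err := f.Stat()
+	if err != nil {
+		return err
+	}
+	req, err := NewRequest("PUT", u, f)
+	if err != nil {
+		return err
+	}
+	req.ContentLength = fi.Size() // a known length allows the connection's ReadFrom fast path
+	res, err := c.Do(req)
+	if err != nil {
+		return err
+	}
+	return res.Body.Close()
+}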
+
+func TestTransportClone(t *testing.T) {
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) { panic("") },
+ OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+ return nil
+ },
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+ Dial: func(network, addr string) (net.Conn, error) { panic("") },
+ DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
+ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+ TLSClientConfig: new(tls.Config),
+ TLSHandshakeTimeout: time.Second,
+ DisableKeepAlives: true,
+ DisableCompression: true,
+ MaxIdleConns: 1,
+ MaxIdleConnsPerHost: 1,
+ MaxConnsPerHost: 1,
+ IdleConnTimeout: time.Second,
+ ResponseHeaderTimeout: time.Second,
+ ExpectContinueTimeout: time.Second,
+ ProxyConnectHeader: Header{},
+ GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
+ MaxResponseHeaderBytes: 1,
+ ForceAttemptHTTP2: true,
+ TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
+ },
+ ReadBufferSize: 1,
+ WriteBufferSize: 1,
+ }
+ tr2 := tr.Clone()
+ rv := reflect.ValueOf(tr2).Elem()
+ rt := rv.Type()
+ for i := 0; i < rt.NumField(); i++ {
+ sf := rt.Field(i)
+ if !token.IsExported(sf.Name) {
+ continue
+ }
+ if rv.Field(i).IsZero() {
+ t.Errorf("cloned field t2.%s is zero", sf.Name)
+ }
+ }
+
+ if _, ok := tr2.TLSNextProto["foo"]; !ok {
+ t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
+ }
+
+ // But test that a nil TLSNextProto is kept nil:
+ tr = new(Transport)
+ tr2 = tr.Clone()
+ if tr2.TLSNextProto != nil {
+ t.Errorf("Transport.TLSNextProto unexpected non-nil")
+ }
+}
+
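Outside the test, Clone is the supported way to derive a tweaked Transport from an existing one without mutating the original. A small sketch of typical usage (the timeout value is chosen arbitrarily):

package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	base := http.DefaultTransport.(*http.Transport)

	// Clone copies every exported field, so the copy can be adjusted
	// without touching the shared DefaultTransport.
	tweaked := base.Clone()
	tweaked.ResponseHeaderTimeout = 5 * time.Second

	fmt.Println(tweaked.ResponseHeaderTimeout) // 5s
	fmt.Println(base.ResponseHeaderTimeout)    // 0s (unchanged)
}
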
+func TestIs408(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"HTTP/1.0 408", true},
+ {"HTTP/1.1 408", true},
+ {"HTTP/1.8 408", true},
+ {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
+ {"HTTP/1.1 408 ", true},
+ {"HTTP/1.1 40", false},
+ {"http/1.0 408", false},
+ {"HTTP/1-1 408", false},
+ }
+ for _, tt := range tests {
+ if got := Export_is408Message([]byte(tt.in)); got != tt.want {
+ t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestTransportIgnores408(t *testing.T) {
+ run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
+}
+func testTransportIgnores408(t *testing.T, mode testMode) {
+ // Not parallel. Relies on mutating the log package's global Output.
+ defer log.SetOutput(log.Writer())
+
+ var logout strings.Builder
+ log.SetOutput(&logout)
+
+ const target = "backend:443"
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ nc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer nc.Close()
+ nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
+ nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
+ }))
+ req, err := NewRequest("GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err != nil {

+ t.Fatal(err)
+ }
+ if string(slurp) != "ok" {
+ t.Fatalf("got %q; want ok", slurp)
+ }
+
+ waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
+ if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
+ if d > 0 {
+ t.Logf("%v idle conns still present after %v", n, d)
+ }
+ return false
+ }
+ return true
+ })
+ if got := logout.String(); got != "" {
+ t.Fatalf("expected no log output; got: %s", got)
+ }
+}
+
+func TestInvalidHeaderResponse(t *testing.T) {
+ run(t, testInvalidHeaderResponse, []testMode{http1Mode})
+}
+func testInvalidHeaderResponse(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
+ "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
+ "Content-Type: text/html; charset=utf-8\r\n" +
+ "Content-Length: 0\r\n" +
+ "Foo : bar\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if v := res.Header.Get("Foo"); v != "" {
+ t.Errorf(`unexpected "Foo" header: %q`, v)
+ }
+ if v := res.Header.Get("Foo "); v != "bar" {
+ t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
+ }
+}
+
+type bodyCloser bool
+
+func (bc *bodyCloser) Close() error {
+ *bc = true
+ return nil
+}
+func (bc *bodyCloser) Read(b []byte) (n int, err error) {
+ return 0, io.EOF
+}
+
+// Issue 35015: ensure that Transport closes the body on any error
+// with an invalid request, as promised by Client.Do docs.
+func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
+ run(t, testTransportClosesBodyOnInvalidRequests)
+}
+func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Errorf("Should not have been invoked")
+ })).ts
+
+ u, _ := url.Parse(cst.URL)
+
+ tests := []struct {
+ name string
+ req *Request
+ wantErr string
+ }{
+ {
+ name: "invalid method",
+ req: &Request{
+ Method: " ",
+ URL: u,
+ },
+ wantErr: `invalid method " "`,
+ },
+ {
+ name: "nil URL",
+ req: &Request{
+ Method: "GET",
+ },
+ wantErr: `nil Request.URL`,
+ },
+ {
+ name: "invalid header key",
+ req: &Request{
+ Method: "GET",
+ Header: Header{"💡": {"emoji"}},
+ URL: u,
+ },
+ wantErr: `invalid header field name "💡"`,
+ },
+ {
+ name: "invalid header value",
+ req: &Request{
+ Method: "POST",
+ Header: Header{"key": {"\x19"}},
+ URL: u,
+ },
+ wantErr: `invalid header field value for "key"`,
+ },
+ {
+ name: "non HTTP(s) scheme",
+ req: &Request{
+ Method: "POST",
+ URL: &url.URL{Scheme: "faux"},
+ },
+ wantErr: `unsupported protocol scheme "faux"`,
+ },
+ {
+ name: "no Host in URL",
+ req: &Request{
+ Method: "POST",
+ URL: &url.URL{Scheme: "http"},
+ },
+ wantErr: `no Host in request URL`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var bc bodyCloser
+ req := tt.req
+ req.Body = &bc
+ _, err := cst.Client().Do(tt.req)
+ if err == nil {
+ t.Fatal("Expected an error")
+ }
+ if !bc {
+ t.Fatal("Expected body to have been closed")
+ }
+ if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
+ t.Fatalf("Error mismatch: %q does not end with %q", g, w)
+ }
+ })
+ }
+}
+
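The guarantee pinned down above can also be observed from plain client code: validation fails before any connection is dialed, yet the body still gets closed. A sketch under those assumptions (closeTracker and the example.com URL are illustrative only):

package main

import (
	"fmt"
	"io"
	"net/http"
	"net/url"
)

// closeTracker is a stand-in request body that records Close calls.
type closeTracker struct{ closed bool }

func (c *closeTracker) Read(p []byte) (int, error) { return 0, io.EOF }
func (c *closeTracker) Close() error               { c.closed = true; return nil }

func main() {
	u, _ := url.Parse("http://example.com/")
	body := &closeTracker{}

	// The blank method is invalid, so the transport rejects the
	// request before dialing and still closes the body (issue 35015).
	req := &http.Request{Method: " ", URL: u, Body: body}
	_, err := http.DefaultClient.Do(req)
	fmt.Println(err != nil, body.closed) // true true
}
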
+// breakableConn is a net.Conn wrapper with a Write method
+// that will fail when its brokenState is true.
+type breakableConn struct {
+ net.Conn
+ *brokenState
+}
+
+type brokenState struct {
+ sync.Mutex
+ broken bool
+}
+
+func (w *breakableConn) Write(b []byte) (n int, err error) {
+ w.Lock()
+ defer w.Unlock()
+ if w.broken {
+ return 0, errors.New("some write error")
+ }
+ return w.Conn.Write(b)
+}
+
+// Issue 34978: don't cache a broken HTTP/2 connection
+func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
+ run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
+}
+func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
+
+ var brokenState brokenState
+
+ const numReqs = 5
+ var numDials, gotConns uint32 // atomic
+
+ cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
+ atomic.AddUint32(&numDials, 1)
+ c, err := net.Dial(netw, addr)
+ if err != nil {
+ t.Errorf("unexpected Dial error: %v", err)
+ return nil, err
+ }
+ return &breakableConn{c, &brokenState}, err
+ }
+
+ for i := 1; i <= numReqs; i++ {
+ brokenState.Lock()
+ brokenState.broken = false
+ brokenState.Unlock()
+
+ // doBreak controls whether we break the TCP connection after the TLS
+ // handshake (before the HTTP/2 handshake). We test a few failures
+ // in a row followed by a final success.
+ doBreak := i != numReqs
+
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(info httptrace.GotConnInfo) {
+ t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
+ atomic.AddUint32(&gotConns, 1)
+ },
+ TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
+ brokenState.Lock()
+ defer brokenState.Unlock()
+ if doBreak {
+ brokenState.broken = true
+ }
+ },
+ })
+ req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = cst.c.Do(req)
+ if doBreak != (err != nil) {
+ t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
+ }
+ }
+ if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
+ t.Errorf("GotConn calls = %v; want %v", got, want)
+ }
+ if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
+ t.Errorf("Dials = %v; want %v", got, want)
+ }
+}
+
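The httptrace hooks used above are ordinary public API, so an application can watch connection reuse the same way. A minimal sketch, assuming network access and using example.com purely as a placeholder:

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httptrace"
)

func main() {
	trace := &httptrace.ClientTrace{
		GotConn: func(info httptrace.GotConnInfo) {
			fmt.Printf("got conn %v reused=%v wasIdle=%v\n",
				info.Conn.RemoteAddr(), info.Reused, info.WasIdle)
		},
	}

	req, err := http.NewRequest("GET", "https://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}
	// Attach the trace to the request's context; the transport fires
	// GotConn once a connection is obtained (new or reused).
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	resp.Body.Close()
}
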
+// Issue 34941
+// When the client has too many concurrent requests on a single connection,
+// http.http2noCachedConnError is reported on multiple requests. There should
+// only be one decrement regardless of the number of failures.
+func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
+ run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
+}
+func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
+ CondSkipHTTP2(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ ts := newClientServerTest(t, mode, h).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+
+ errCh := make(chan error, 300)
+ doReq := func() {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ errCh <- fmt.Errorf("request failed: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ errCh <- fmt.Errorf("read body failed: %v", err)
+ }
+ }
+
+ var wg sync.WaitGroup
+ for i := 0; i < 300; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+ close(errCh)
+
+ for err := range errCh {
+ t.Errorf("error occurred: %v", err)
+ }
+}
+
+// Issue 36820
+// Test that we use the older backward compatible cancellation protocol
+// when a RoundTripper is registered via RegisterProtocol.
+func TestAltProtoCancellation(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{
+ Transport: tr,
+ Timeout: time.Millisecond,
+ }
+ tr.RegisterProtocol("cancel", cancelProto{})
+ _, err := c.Get("cancel://bar.com/path")
+ if err == nil {
+ t.Error("request unexpectedly succeeded")
+ } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
+ t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
+ }
+}
+
+var errCancelProto = errors.New("canceled as expected")
+
+type cancelProto struct{}
+
+func (cancelProto) RoundTrip(req *Request) (*Response, error) {
+ <-req.Cancel
+ return nil, errCancelProto
+}
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
+
+// Issue 32441: body is not reset after ErrSkipAltProtocol
+func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
+func testIssue32441(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
+ t.Error("body length is zero")
+ }
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
+ // Draining body to trigger failure condition on actual request to server.
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
+ t.Error("body length is zero during round trip")
+ }
+ return nil, ErrSkipAltProtocol
+ }))
+ if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
+ t.Error(err)
+ }
+}
+
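The RegisterProtocol plus ErrSkipAltProtocol pattern exercised above is how an alternate RoundTripper can inspect a request and then defer to the built-in transport. A hedged sketch (loggingRT and the example.com URL are made up for illustration, and network access is assumed):

package main

import (
	"fmt"
	"net/http"
)

// loggingRT observes requests for the "http" scheme and then lets the
// built-in transport handle them.
type loggingRT struct{}

func (loggingRT) RoundTrip(r *http.Request) (*http.Response, error) {
	fmt.Println("alt protocol saw:", r.URL)
	return nil, http.ErrSkipAltProtocol // fall back to the default path
}

func main() {
	tr := &http.Transport{}
	tr.RegisterProtocol("http", loggingRT{})
	client := &http.Client{Transport: tr}

	resp, err := client.Get("http://example.com/")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
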
+// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
+// that contain a sign (e.g. "+3"), per RFC 2616, Section 14.13.
+func TestTransportRejectsSignInContentLength(t *testing.T) {
+ run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
+}
+func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "+3")
+ w.Write([]byte("abc"))
+ })).ts
+
+ c := cst.Client()
+ res, err := c.Get(cst.URL)
+ if err == nil || res != nil {
+ t.Fatal("Expected a non-nil error and a nil http.Response")
+ }
+ if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
+ t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
+ }
+}
+
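The reason a signed value has to be rejected explicitly is that general-purpose integer parsing accepts it; only a digits-only (unsigned) parse matches the Content-Length grammar. A quick comparison in plain strconv terms, a rough analogue of what the transport's check enforces:

package main

import (
	"fmt"
	"strconv"
)

func main() {
	// A permissive signed parse happily accepts "+3"...
	n, err := strconv.Atoi("+3")
	fmt.Println(n, err) // 3 <nil>

	// ...while a strict unsigned parse, digits only, rejects it.
	u, err := strconv.ParseUint("+3", 10, 64)
	fmt.Println(u, err) // 0 strconv.ParseUint: parsing "+3": invalid syntax
}
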
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var ok bool
+ if r.r, ok = <-r.c; !ok {
+ return 0, errors.New("delegate closed")
+ }
+ }
+ return r.r.Read(p)
+}
+
+func testTransportRace(req *Request) {
+ save := req.Body
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{pw, dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ quitReadCh := make(chan struct{})
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ defer close(quitReadCh)
+
+ req, err := ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ select {
+ case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
+ case quitReadCh <- struct{}{}:
+ // Ensure delegate is closed so Read doesn't block forever.
+ close(dr.c)
+ }
+ }()
+
+ t.RoundTrip(req)
+
+ // Ensure the reader returns before we reset req.Body to prevent
+ // a data race on req.Body.
+ pw.Close()
+ <-quitReadCh
+
+ req.Body = save
+}
+
+// Issue 37669
+// Test that a cancellation doesn't result in a data race due to the writeLoop
+// goroutine being left running, if the caller mutates the processed Request
+// upon completion.
+func TestErrorWriteLoopRace(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+ t.Parallel()
+ for i := 0; i < 1000; i++ {
+ delay := time.Duration(mrand.Intn(5)) * time.Millisecond
+ ctx, cancel := context.WithTimeout(context.Background(), delay)
+ defer cancel()
+
+ r := bytes.NewBuffer(make([]byte, 10000))
+ req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testTransportRace(req)
+ }
+}
+
+// Issue 41600
+// Test that a new request which uses the connection of an active request
+// cannot cause it to be canceled as well.
+func TestCancelRequestWhenSharingConnection(t *testing.T) {
+ run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
+}
+func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
+ reqc := make(chan chan struct{}, 2)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+ ch := make(chan struct{}, 1)
+ reqc <- ch
+ <-ch
+ w.Header().Add("Content-Length", "0")
+ })).ts
+
+ client := ts.Client()
+ transport := client.Transport.(*Transport)
+ transport.MaxIdleConns = 1
+ transport.MaxConnsPerHost = 1
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ putidlec := make(chan chan struct{}, 1)
+ reqerrc := make(chan error, 1)
+ go func() {
+ defer wg.Done()
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ PutIdleConn: func(error) {
+ // Signal that the idle conn has been returned to the pool,
+ // and wait for the order to proceed.
+ ch := make(chan struct{})
+ putidlec <- ch
+ close(putidlec) // panic if PutIdleConn runs twice for some reason
+ <-ch
+ },
+ })
+ req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ reqerrc <- err
+ if err == nil {
+ res.Body.Close()
+ }
+ }()
+
+ // Wait for the first request to receive a response and return the
+ // connection to the idle pool.
+ r1c := <-reqc
+ close(r1c)
+ var idlec chan struct{}
+ select {
+ case err := <-reqerrc:
+ if err != nil {
+ t.Fatalf("request 1: got err %v, want nil", err)
+ }
+ idlec = <-putidlec
+ case idlec = <-putidlec:
+ }
+
+ wg.Add(1)
+ cancelctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ defer wg.Done()
+ req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ if err == nil {
+ res.Body.Close()
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("request 2: got err %v, want Canceled", err)
+ }
+
+ // Unblock the first request.
+ close(idlec)
+ }()
+
+ // Wait for the second request to arrive at the server, and then cancel
+ // the request context.
+ r2c := <-reqc
+ cancel()
+
+ <-idlec
+
+ close(r2c)
+ wg.Wait()
+}
+
+func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
+func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go io.Copy(io.Discard, req.Body)
+ panic(ErrAbortHandler)
+ })).ts
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := ts.Client().Transport.RoundTrip(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
+func testRequestSanitization(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ // Remove this after updating x/net.
+ t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if h, ok := req.Header["X-Evil"]; ok {
+ t.Errorf("request has X-Evil header: %q", h)
+ }
+ })).ts
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Host = "go.dev\r\nX-Evil:evil"
+ resp, _ := ts.Client().Do(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+}
diff --git a/src/net/http/triv.go b/src/net/http/triv.go
new file mode 100644
index 0000000..f614922
--- /dev/null
+++ b/src/net/http/triv.go
@@ -0,0 +1,140 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build ignore
+
+package main
+
+import (
+ "expvar"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "os/exec"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// hello world, the web server
+var helloRequests = expvar.NewInt("hello-requests")
+
+func HelloServer(w http.ResponseWriter, req *http.Request) {
+ helloRequests.Add(1)
+ io.WriteString(w, "hello, world!\n")
+}
+
+// Simple counter server. POSTing to it will set the value.
+type Counter struct {
+ mu sync.Mutex // protects n
+ n int
+}
+
+// This makes Counter satisfy the expvar.Var interface, so we can export
+// it directly.
+func (ctr *Counter) String() string {
+ ctr.mu.Lock()
+ defer ctr.mu.Unlock()
+ return strconv.Itoa(ctr.n)
+}
+
+func (ctr *Counter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ ctr.mu.Lock()
+ defer ctr.mu.Unlock()
+ switch req.Method {
+ case "GET":
+ ctr.n++
+ case "POST":
+ var buf strings.Builder
+ io.Copy(&buf, req.Body)
+ body := buf.String()
+ if n, err := strconv.Atoi(body); err != nil {
+ fmt.Fprintf(w, "bad POST: %v\nbody: [%v]\n", err, body)
+ } else {
+ ctr.n = n
+ fmt.Fprint(w, "counter reset\n")
+ }
+ }
+ fmt.Fprintf(w, "counter = %d\n", ctr.n)
+}
+
+// simple flag server
+var booleanflag = flag.Bool("boolean", true, "another flag for testing")
+
+func FlagServer(w http.ResponseWriter, req *http.Request) {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ fmt.Fprint(w, "Flags:\n")
+ flag.VisitAll(func(f *flag.Flag) {
+ if f.Value.String() != f.DefValue {
+ fmt.Fprintf(w, "%s = %s [default = %s]\n", f.Name, f.Value.String(), f.DefValue)
+ } else {
+ fmt.Fprintf(w, "%s = %s\n", f.Name, f.Value.String())
+ }
+ })
+}
+
+// simple argument server
+func ArgServer(w http.ResponseWriter, req *http.Request) {
+ for _, s := range os.Args {
+ fmt.Fprint(w, s, " ")
+ }
+}
+
+// a channel (just for the fun of it)
+type Chan chan int
+
+func ChanCreate() Chan {
+ c := make(Chan)
+ go func(c Chan) {
+ for x := 0; ; x++ {
+ c <- x
+ }
+ }(c)
+ return c
+}
+
+func (ch Chan) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, fmt.Sprintf("channel send #%d\n", <-ch))
+}
+
+// exec a program, redirecting output.
+func DateServer(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
+
+ date, err := exec.Command("/bin/date").Output()
+ if err != nil {
+ http.Error(rw, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ rw.Write(date)
+}
+
+func Logger(w http.ResponseWriter, req *http.Request) {
+ log.Print(req.URL)
+ http.Error(w, "oops", http.StatusNotFound)
+}
+
+var webroot = flag.String("root", "", "web root directory")
+
+func main() {
+ flag.Parse()
+
+ // The counter is published as a variable directly.
+ ctr := new(Counter)
+ expvar.Publish("counter", ctr)
+ http.Handle("/counter", ctr)
+ http.Handle("/", http.HandlerFunc(Logger))
+ if *webroot != "" {
+ http.Handle("/go/", http.StripPrefix("/go/", http.FileServer(http.Dir(*webroot))))
+ }
+ http.Handle("/chan", ChanCreate())
+ http.HandleFunc("/flags", FlagServer)
+ http.HandleFunc("/args", ArgServer)
+ http.HandleFunc("/go/hello", HelloServer)
+ http.HandleFunc("/date", DateServer)
+ log.Fatal(http.ListenAndServe("localhost:12345", nil))
+}
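A small companion client can poke at the handlers triv.go registers. This sketch assumes the server above is running on localhost:12345; the /debug/vars endpoint works because importing expvar registers it on the default mux.

package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	for _, path := range []string{"/counter", "/flags", "/debug/vars"} {
		resp, err := http.Get("http://localhost:12345" + path)
		if err != nil {
			log.Fatal(err)
		}
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("GET %s\n%s\n", path, body)
	}
}
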
diff --git a/src/net/interface.go b/src/net/interface.go
new file mode 100644
index 0000000..e1c9a2e
--- /dev/null
+++ b/src/net/interface.go
@@ -0,0 +1,259 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "internal/itoa"
+ "sync"
+ "time"
+)
+
+// BUG(mikio): On JS, methods and functions related to
+// Interface are not implemented.
+
+// BUG(mikio): On AIX, DragonFly BSD, NetBSD, OpenBSD, Plan 9 and
+// Solaris, the MulticastAddrs method of Interface is not implemented.
+
+var (
+ errInvalidInterface = errors.New("invalid network interface")
+ errInvalidInterfaceIndex = errors.New("invalid network interface index")
+ errInvalidInterfaceName = errors.New("invalid network interface name")
+ errNoSuchInterface = errors.New("no such network interface")
+ errNoSuchMulticastInterface = errors.New("no such multicast network interface")
+)
+
+// Interface represents a mapping between network interface name
+// and index. It also represents network interface facility
+// information.
+type Interface struct {
+ Index int // positive integer that starts at one, zero is never used
+ MTU int // maximum transmission unit
+ Name string // e.g., "en0", "lo0", "eth0.100"
+ HardwareAddr HardwareAddr // IEEE MAC-48, EUI-48 and EUI-64 form
+ Flags Flags // e.g., FlagUp, FlagLoopback, FlagMulticast
+}
+
+type Flags uint
+
+const (
+ FlagUp Flags = 1 << iota // interface is administratively up
+ FlagBroadcast // interface supports broadcast access capability
+ FlagLoopback // interface is a loopback interface
+ FlagPointToPoint // interface belongs to a point-to-point link
+ FlagMulticast // interface supports multicast access capability
+ FlagRunning // interface is in running state
+)
+
+var flagNames = []string{
+ "up",
+ "broadcast",
+ "loopback",
+ "pointtopoint",
+ "multicast",
+ "running",
+}
+
+func (f Flags) String() string {
+ s := ""
+ for i, name := range flagNames {
+ if f&(1<<uint(i)) != 0 {
+ if s != "" {
+ s += "|"
+ }
+ s += name
+ }
+ }
+ if s == "" {
+ s = "0"
+ }
+ return s
+}
+
+// Addrs returns a list of unicast interface addresses for a specific
+// interface.
+func (ifi *Interface) Addrs() ([]Addr, error) {
+ if ifi == nil {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
+ }
+ ifat, err := interfaceAddrTable(ifi)
+ if err != nil {
+ err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ return ifat, err
+}
+
+// MulticastAddrs returns a list of multicast, joined group addresses
+// for a specific interface.
+func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
+ if ifi == nil {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
+ }
+ ifat, err := interfaceMulticastAddrTable(ifi)
+ if err != nil {
+ err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ return ifat, err
+}
+
+// Interfaces returns a list of the system's network interfaces.
+func Interfaces() ([]Interface, error) {
+ ift, err := interfaceTable(0)
+ if err != nil {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ if len(ift) != 0 {
+ zoneCache.update(ift, false)
+ }
+ return ift, nil
+}
+
+// InterfaceAddrs returns a list of the system's unicast interface
+// addresses.
+//
+// The returned list does not identify the associated interface; use
+// Interfaces and Interface.Addrs for more detail.
+func InterfaceAddrs() ([]Addr, error) {
+ ifat, err := interfaceAddrTable(nil)
+ if err != nil {
+ err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ return ifat, err
+}
+
+// InterfaceByIndex returns the interface specified by index.
+//
+// On Solaris, it returns one of the logical network interfaces
+// sharing the logical data link; for more precision use
+// InterfaceByName.
+func InterfaceByIndex(index int) (*Interface, error) {
+ if index <= 0 {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex}
+ }
+ ift, err := interfaceTable(index)
+ if err != nil {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ ifi, err := interfaceByIndex(ift, index)
+ if err != nil {
+ err = &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ return ifi, err
+}
+
+func interfaceByIndex(ift []Interface, index int) (*Interface, error) {
+ for _, ifi := range ift {
+ if index == ifi.Index {
+ return &ifi, nil
+ }
+ }
+ return nil, errNoSuchInterface
+}
+
+// InterfaceByName returns the interface specified by name.
+func InterfaceByName(name string) (*Interface, error) {
+ if name == "" {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}
+ }
+ ift, err := interfaceTable(0)
+ if err != nil {
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err}
+ }
+ if len(ift) != 0 {
+ zoneCache.update(ift, false)
+ }
+ for _, ifi := range ift {
+ if name == ifi.Name {
+ return &ifi, nil
+ }
+ }
+ return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface}
+}
+
+// An ipv6ZoneCache represents a cache holding partial network
+// interface information. It is used for reducing the cost of IPv6
+// addressing scope zone resolution.
+//
+// Multiple names sharing an index are managed on a first-come,
+// first-served basis for consistency.
+type ipv6ZoneCache struct {
+ sync.RWMutex // guard the following
+ lastFetched time.Time // last time routing information was fetched
+ toIndex map[string]int // interface name to its index
+ toName map[int]string // interface index to its name
+}
+
+var zoneCache = ipv6ZoneCache{
+ toIndex: make(map[string]int),
+ toName: make(map[int]string),
+}
+
+// update refreshes the network interface information if the cache was last
+// updated more than 1 minute ago, or if force is set. It reports whether the
+// cache was updated.
+func (zc *ipv6ZoneCache) update(ift []Interface, force bool) (updated bool) {
+ zc.Lock()
+ defer zc.Unlock()
+ now := time.Now()
+ if !force && zc.lastFetched.After(now.Add(-60*time.Second)) {
+ return false
+ }
+ zc.lastFetched = now
+ if len(ift) == 0 {
+ var err error
+ if ift, err = interfaceTable(0); err != nil {
+ return false
+ }
+ }
+ zc.toIndex = make(map[string]int, len(ift))
+ zc.toName = make(map[int]string, len(ift))
+ for _, ifi := range ift {
+ zc.toIndex[ifi.Name] = ifi.Index
+ if _, ok := zc.toName[ifi.Index]; !ok {
+ zc.toName[ifi.Index] = ifi.Name
+ }
+ }
+ return true
+}
+
+func (zc *ipv6ZoneCache) name(index int) string {
+ if index == 0 {
+ return ""
+ }
+ updated := zoneCache.update(nil, false)
+ zoneCache.RLock()
+ name, ok := zoneCache.toName[index]
+ zoneCache.RUnlock()
+ if !ok && !updated {
+ zoneCache.update(nil, true)
+ zoneCache.RLock()
+ name, ok = zoneCache.toName[index]
+ zoneCache.RUnlock()
+ }
+ if !ok { // last resort
+ name = itoa.Uitoa(uint(index))
+ }
+ return name
+}
+
+func (zc *ipv6ZoneCache) index(name string) int {
+ if name == "" {
+ return 0
+ }
+ updated := zoneCache.update(nil, false)
+ zoneCache.RLock()
+ index, ok := zoneCache.toIndex[name]
+ zoneCache.RUnlock()
+ if !ok && !updated {
+ zoneCache.update(nil, true)
+ zoneCache.RLock()
+ index, ok = zoneCache.toIndex[name]
+ zoneCache.RUnlock()
+ }
+ if !ok { // last resort
+ index, _, _ = dtoi(name)
+ }
+ return index
+}
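The exported surface defined above (Interfaces, Interface.Addrs, Flags) composes into the usual enumeration loop; a short sketch of how a caller might list every interface with its addresses:

package main

import (
	"fmt"
	"log"
	"net"
)

func main() {
	ifaces, err := net.Interfaces()
	if err != nil {
		log.Fatal(err)
	}
	for _, ifi := range ifaces {
		fmt.Printf("%d %s mtu=%d flags=%s hw=%s\n",
			ifi.Index, ifi.Name, ifi.MTU, ifi.Flags, ifi.HardwareAddr)

		addrs, err := ifi.Addrs()
		if err != nil {
			log.Fatal(err)
		}
		for _, a := range addrs {
			fmt.Printf("  %s %s\n", a.Network(), a)
		}
	}
}
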
diff --git a/src/net/interface_aix.go b/src/net/interface_aix.go
new file mode 100644
index 0000000..f2e967b
--- /dev/null
+++ b/src/net/interface_aix.go
@@ -0,0 +1,189 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "internal/syscall/unix"
+ "syscall"
+ "unsafe"
+)
+
+type rawSockaddrDatalink struct {
+ Len uint8
+ Family uint8
+ Index uint16
+ Type uint8
+ Nlen uint8
+ Alen uint8
+ Slen uint8
+ Data [120]byte
+}
+
+type ifreq struct {
+ Name [16]uint8
+ Ifru [16]byte
+}
+
+const _KINFO_RT_IFLIST = (0x1 << 8) | 3 | (1 << 30)
+
+const _RTAX_NETMASK = 2
+const _RTAX_IFA = 5
+const _RTAX_MAX = 8
+
+func getIfList() ([]byte, error) {
+ needed, err := syscall.Getkerninfo(_KINFO_RT_IFLIST, 0, 0, 0)
+ if err != nil {
+ return nil, err
+ }
+ tab := make([]byte, needed)
+ _, err = syscall.Getkerninfo(_KINFO_RT_IFLIST, uintptr(unsafe.Pointer(&tab[0])), uintptr(unsafe.Pointer(&needed)), 0)
+ if err != nil {
+ return nil, err
+ }
+ return tab[:needed], nil
+}
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ tab, err := getIfList()
+ if err != nil {
+ return nil, err
+ }
+
+ sock, err := sysSocket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer poll.CloseFunc(sock)
+
+ var ift []Interface
+ for len(tab) > 0 {
+ ifm := (*syscall.IfMsgHdr)(unsafe.Pointer(&tab[0]))
+ if ifm.Msglen == 0 {
+ break
+ }
+ if ifm.Type == syscall.RTM_IFINFO {
+ if ifindex == 0 || ifindex == int(ifm.Index) {
+ sdl := (*rawSockaddrDatalink)(unsafe.Pointer(&tab[syscall.SizeofIfMsghdr]))
+
+ ifi := &Interface{Index: int(ifm.Index), Flags: linkFlags(ifm.Flags)}
+ ifi.Name = string(sdl.Data[:sdl.Nlen])
+ ifi.HardwareAddr = sdl.Data[sdl.Nlen : sdl.Nlen+sdl.Alen]
+
+ // Retrieve MTU
+ ifr := &ifreq{}
+ copy(ifr.Name[:], ifi.Name)
+ err = unix.Ioctl(sock, syscall.SIOCGIFMTU, unsafe.Pointer(ifr))
+ if err != nil {
+ return nil, err
+ }
+ ifi.MTU = int(ifr.Ifru[0])<<24 | int(ifr.Ifru[1])<<16 | int(ifr.Ifru[2])<<8 | int(ifr.Ifru[3])
+
+ ift = append(ift, *ifi)
+ if ifindex == int(ifm.Index) {
+ break
+ }
+ }
+ }
+ tab = tab[ifm.Msglen:]
+ }
+
+ return ift, nil
+}
+
+func linkFlags(rawFlags int32) Flags {
+ var f Flags
+ if rawFlags&syscall.IFF_UP != 0 {
+ f |= FlagUp
+ }
+ if rawFlags&syscall.IFF_RUNNING != 0 {
+ f |= FlagRunning
+ }
+ if rawFlags&syscall.IFF_BROADCAST != 0 {
+ f |= FlagBroadcast
+ }
+ if rawFlags&syscall.IFF_LOOPBACK != 0 {
+ f |= FlagLoopback
+ }
+ if rawFlags&syscall.IFF_POINTOPOINT != 0 {
+ f |= FlagPointToPoint
+ }
+ if rawFlags&syscall.IFF_MULTICAST != 0 {
+ f |= FlagMulticast
+ }
+ return f
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ tab, err := getIfList()
+ if err != nil {
+ return nil, err
+ }
+
+ var ifat []Addr
+ for len(tab) > 0 {
+ ifm := (*syscall.IfMsgHdr)(unsafe.Pointer(&tab[0]))
+ if ifm.Msglen == 0 {
+ break
+ }
+ if ifm.Type == syscall.RTM_NEWADDR {
+ if ifi == nil || ifi.Index == int(ifm.Index) {
+ mask := ifm.Addrs
+ off := uint(syscall.SizeofIfMsghdr)
+
+ var iprsa, nmrsa *syscall.RawSockaddr
+ for i := uint(0); i < _RTAX_MAX; i++ {
+ if mask&(1<<i) == 0 {
+ continue
+ }
+ rsa := (*syscall.RawSockaddr)(unsafe.Pointer(&tab[off]))
+ if i == _RTAX_NETMASK {
+ nmrsa = rsa
+ }
+ if i == _RTAX_IFA {
+ iprsa = rsa
+ }
+ off += (uint(rsa.Len) + 3) &^ 3
+ }
+ if iprsa != nil && nmrsa != nil {
+ var mask IPMask
+ var ip IP
+
+ switch iprsa.Family {
+ case syscall.AF_INET:
+ ipsa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(iprsa))
+ nmsa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(nmrsa))
+ ip = IPv4(ipsa.Addr[0], ipsa.Addr[1], ipsa.Addr[2], ipsa.Addr[3])
+ mask = IPv4Mask(nmsa.Addr[0], nmsa.Addr[1], nmsa.Addr[2], nmsa.Addr[3])
+ case syscall.AF_INET6:
+ ipsa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(iprsa))
+ nmsa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(nmrsa))
+ ip = make(IP, IPv6len)
+ copy(ip, ipsa.Addr[:])
+ mask = make(IPMask, IPv6len)
+ copy(mask, nmsa.Addr[:])
+ }
+ ifa := &IPNet{IP: ip, Mask: mask}
+ ifat = append(ifat, ifa)
+ }
+ }
+ }
+ tab = tab[ifm.Msglen:]
+ }
+
+ return ifat, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
diff --git a/src/net/interface_bsd.go b/src/net/interface_bsd.go
new file mode 100644
index 0000000..9b2b42a
--- /dev/null
+++ b/src/net/interface_bsd.go
@@ -0,0 +1,121 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package net
+
+import (
+ "syscall"
+
+ "golang.org/x/net/route"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ msgs, err := interfaceMessages(ifindex)
+ if err != nil {
+ return nil, err
+ }
+ n := len(msgs)
+ if ifindex != 0 {
+ n = 1
+ }
+ ift := make([]Interface, n)
+ n = 0
+ for _, m := range msgs {
+ switch m := m.(type) {
+ case *route.InterfaceMessage:
+ if ifindex != 0 && ifindex != m.Index {
+ continue
+ }
+ ift[n].Index = m.Index
+ ift[n].Name = m.Name
+ ift[n].Flags = linkFlags(m.Flags)
+ if sa, ok := m.Addrs[syscall.RTAX_IFP].(*route.LinkAddr); ok && len(sa.Addr) > 0 {
+ ift[n].HardwareAddr = make([]byte, len(sa.Addr))
+ copy(ift[n].HardwareAddr, sa.Addr)
+ }
+ for _, sys := range m.Sys() {
+ if imx, ok := sys.(*route.InterfaceMetrics); ok {
+ ift[n].MTU = imx.MTU
+ break
+ }
+ }
+ n++
+ if ifindex == m.Index {
+ return ift[:n], nil
+ }
+ }
+ }
+ return ift[:n], nil
+}
+
+func linkFlags(rawFlags int) Flags {
+ var f Flags
+ if rawFlags&syscall.IFF_UP != 0 {
+ f |= FlagUp
+ }
+ if rawFlags&syscall.IFF_RUNNING != 0 {
+ f |= FlagRunning
+ }
+ if rawFlags&syscall.IFF_BROADCAST != 0 {
+ f |= FlagBroadcast
+ }
+ if rawFlags&syscall.IFF_LOOPBACK != 0 {
+ f |= FlagLoopback
+ }
+ if rawFlags&syscall.IFF_POINTOPOINT != 0 {
+ f |= FlagPointToPoint
+ }
+ if rawFlags&syscall.IFF_MULTICAST != 0 {
+ f |= FlagMulticast
+ }
+ return f
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ index := 0
+ if ifi != nil {
+ index = ifi.Index
+ }
+ msgs, err := interfaceMessages(index)
+ if err != nil {
+ return nil, err
+ }
+ ifat := make([]Addr, 0, len(msgs))
+ for _, m := range msgs {
+ switch m := m.(type) {
+ case *route.InterfaceAddrMessage:
+ if index != 0 && index != m.Index {
+ continue
+ }
+ var mask IPMask
+ switch sa := m.Addrs[syscall.RTAX_NETMASK].(type) {
+ case *route.Inet4Addr:
+ mask = IPv4Mask(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3])
+ case *route.Inet6Addr:
+ mask = make(IPMask, IPv6len)
+ copy(mask, sa.IP[:])
+ }
+ var ip IP
+ switch sa := m.Addrs[syscall.RTAX_IFA].(type) {
+ case *route.Inet4Addr:
+ ip = IPv4(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3])
+ case *route.Inet6Addr:
+ ip = make(IP, IPv6len)
+ copy(ip, sa.IP[:])
+ }
+ if ip != nil && mask != nil { // NetBSD may contain route.LinkAddr
+ ifat = append(ifat, &IPNet{IP: ip, Mask: mask})
+ }
+ }
+ }
+ return ifat, nil
+}
diff --git a/src/net/interface_bsd_test.go b/src/net/interface_bsd_test.go
new file mode 100644
index 0000000..ce59962
--- /dev/null
+++ b/src/net/interface_bsd_test.go
@@ -0,0 +1,60 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package net
+
+import (
+ "errors"
+ "fmt"
+ "os/exec"
+ "runtime"
+)
+
+func (ti *testInterface) setBroadcast(vid int) error {
+ if runtime.GOOS == "openbsd" {
+ ti.name = fmt.Sprintf("vether%d", vid)
+ } else {
+ ti.name = fmt.Sprintf("vlan%d", vid)
+ }
+ xname, err := exec.LookPath("ifconfig")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "create"},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "destroy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setPointToPoint(suffix int) error {
+ ti.name = fmt.Sprintf("gif%d", suffix)
+ xname, err := exec.LookPath("ifconfig")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "create"},
+ })
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "inet", ti.local, ti.remote},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ifconfig", ti.name, "destroy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setLinkLocal(suffix int) error {
+ return errors.New("not yet implemented for BSD")
+}
diff --git a/src/net/interface_bsdvar.go b/src/net/interface_bsdvar.go
new file mode 100644
index 0000000..e9bea3d
--- /dev/null
+++ b/src/net/interface_bsdvar.go
@@ -0,0 +1,28 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly || netbsd || openbsd
+
+package net
+
+import (
+ "syscall"
+
+ "golang.org/x/net/route"
+)
+
+func interfaceMessages(ifindex int) ([]route.Message, error) {
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, syscall.NET_RT_IFLIST, ifindex)
+ if err != nil {
+ return nil, err
+ }
+ return route.ParseRIB(syscall.NET_RT_IFLIST, rib)
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ // TODO(mikio): Implement this like other platforms.
+ return nil, nil
+}
diff --git a/src/net/interface_darwin.go b/src/net/interface_darwin.go
new file mode 100644
index 0000000..bb4fd73
--- /dev/null
+++ b/src/net/interface_darwin.go
@@ -0,0 +1,53 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "syscall"
+
+ "golang.org/x/net/route"
+)
+
+func interfaceMessages(ifindex int) ([]route.Message, error) {
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, syscall.NET_RT_IFLIST, ifindex)
+ if err != nil {
+ return nil, err
+ }
+ return route.ParseRIB(syscall.NET_RT_IFLIST, rib)
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, syscall.NET_RT_IFLIST2, ifi.Index)
+ if err != nil {
+ return nil, err
+ }
+ msgs, err := route.ParseRIB(syscall.NET_RT_IFLIST2, rib)
+ if err != nil {
+ return nil, err
+ }
+ ifmat := make([]Addr, 0, len(msgs))
+ for _, m := range msgs {
+ switch m := m.(type) {
+ case *route.InterfaceMulticastAddrMessage:
+ if ifi.Index != m.Index {
+ continue
+ }
+ var ip IP
+ switch sa := m.Addrs[syscall.RTAX_IFA].(type) {
+ case *route.Inet4Addr:
+ ip = IPv4(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3])
+ case *route.Inet6Addr:
+ ip = make(IP, IPv6len)
+ copy(ip, sa.IP[:])
+ }
+ if ip != nil {
+ ifmat = append(ifmat, &IPAddr{IP: ip})
+ }
+ }
+ }
+ return ifmat, nil
+}
diff --git a/src/net/interface_freebsd.go b/src/net/interface_freebsd.go
new file mode 100644
index 0000000..8536bd3
--- /dev/null
+++ b/src/net/interface_freebsd.go
@@ -0,0 +1,53 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "syscall"
+
+ "golang.org/x/net/route"
+)
+
+func interfaceMessages(ifindex int) ([]route.Message, error) {
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeInterface, ifindex)
+ if err != nil {
+ return nil, err
+ }
+ return route.ParseRIB(route.RIBTypeInterface, rib)
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ rib, err := route.FetchRIB(syscall.AF_UNSPEC, syscall.NET_RT_IFMALIST, ifi.Index)
+ if err != nil {
+ return nil, err
+ }
+ msgs, err := route.ParseRIB(syscall.NET_RT_IFMALIST, rib)
+ if err != nil {
+ return nil, err
+ }
+ ifmat := make([]Addr, 0, len(msgs))
+ for _, m := range msgs {
+ switch m := m.(type) {
+ case *route.InterfaceMulticastAddrMessage:
+ if ifi.Index != m.Index {
+ continue
+ }
+ var ip IP
+ switch sa := m.Addrs[syscall.RTAX_IFA].(type) {
+ case *route.Inet4Addr:
+ ip = IPv4(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3])
+ case *route.Inet6Addr:
+ ip = make(IP, IPv6len)
+ copy(ip, sa.IP[:])
+ }
+ if ip != nil {
+ ifmat = append(ifmat, &IPAddr{IP: ip})
+ }
+ }
+ }
+ return ifmat, nil
+}
diff --git a/src/net/interface_linux.go b/src/net/interface_linux.go
new file mode 100644
index 0000000..9112ecc
--- /dev/null
+++ b/src/net/interface_linux.go
@@ -0,0 +1,272 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+ "unsafe"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ tab, err := syscall.NetlinkRIB(syscall.RTM_GETLINK, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, os.NewSyscallError("netlinkrib", err)
+ }
+ msgs, err := syscall.ParseNetlinkMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("parsenetlinkmessage", err)
+ }
+ var ift []Interface
+loop:
+ for _, m := range msgs {
+ switch m.Header.Type {
+ case syscall.NLMSG_DONE:
+ break loop
+ case syscall.RTM_NEWLINK:
+ ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0]))
+ if ifindex == 0 || ifindex == int(ifim.Index) {
+ attrs, err := syscall.ParseNetlinkRouteAttr(&m)
+ if err != nil {
+ return nil, os.NewSyscallError("parsenetlinkrouteattr", err)
+ }
+ ift = append(ift, *newLink(ifim, attrs))
+ if ifindex == int(ifim.Index) {
+ break loop
+ }
+ }
+ }
+ }
+ return ift, nil
+}
+
+const (
+ // See linux/if_arp.h.
+ // Note that Linux doesn't support IPv4 over IPv6 tunneling.
+ sysARPHardwareIPv4IPv4 = 768 // IPv4 over IPv4 tunneling
+ sysARPHardwareIPv6IPv6 = 769 // IPv6 over IPv6 tunneling
+ sysARPHardwareIPv6IPv4 = 776 // IPv6 over IPv4 tunneling
+ sysARPHardwareGREIPv4 = 778 // any over GRE over IPv4 tunneling
+ sysARPHardwareGREIPv6 = 823 // any over GRE over IPv6 tunneling
+)
+
+func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface {
+ ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)}
+ for _, a := range attrs {
+ switch a.Attr.Type {
+ case syscall.IFLA_ADDRESS:
+ // We never return any /32 or /128 IP address
+ // prefix on any IP tunnel interface as the
+ // hardware address.
+ switch len(a.Value) {
+ case IPv4len:
+ switch ifim.Type {
+ case sysARPHardwareIPv4IPv4, sysARPHardwareGREIPv4, sysARPHardwareIPv6IPv4:
+ continue
+ }
+ case IPv6len:
+ switch ifim.Type {
+ case sysARPHardwareIPv6IPv6, sysARPHardwareGREIPv6:
+ continue
+ }
+ }
+ var nonzero bool
+ for _, b := range a.Value {
+ if b != 0 {
+ nonzero = true
+ break
+ }
+ }
+ if nonzero {
+ ifi.HardwareAddr = a.Value[:]
+ }
+ case syscall.IFLA_IFNAME:
+ ifi.Name = string(a.Value[:len(a.Value)-1])
+ case syscall.IFLA_MTU:
+ ifi.MTU = int(*(*uint32)(unsafe.Pointer(&a.Value[:4][0])))
+ }
+ }
+ return ifi
+}
+
+func linkFlags(rawFlags uint32) Flags {
+ var f Flags
+ if rawFlags&syscall.IFF_UP != 0 {
+ f |= FlagUp
+ }
+ if rawFlags&syscall.IFF_RUNNING != 0 {
+ f |= FlagRunning
+ }
+ if rawFlags&syscall.IFF_BROADCAST != 0 {
+ f |= FlagBroadcast
+ }
+ if rawFlags&syscall.IFF_LOOPBACK != 0 {
+ f |= FlagLoopback
+ }
+ if rawFlags&syscall.IFF_POINTOPOINT != 0 {
+ f |= FlagPointToPoint
+ }
+ if rawFlags&syscall.IFF_MULTICAST != 0 {
+ f |= FlagMulticast
+ }
+ return f
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, os.NewSyscallError("netlinkrib", err)
+ }
+ msgs, err := syscall.ParseNetlinkMessage(tab)
+ if err != nil {
+ return nil, os.NewSyscallError("parsenetlinkmessage", err)
+ }
+ var ift []Interface
+ if ifi == nil {
+ var err error
+ ift, err = interfaceTable(0)
+ if err != nil {
+ return nil, err
+ }
+ }
+ ifat, err := addrTable(ift, ifi, msgs)
+ if err != nil {
+ return nil, err
+ }
+ return ifat, nil
+}
+
+func addrTable(ift []Interface, ifi *Interface, msgs []syscall.NetlinkMessage) ([]Addr, error) {
+ var ifat []Addr
+loop:
+ for _, m := range msgs {
+ switch m.Header.Type {
+ case syscall.NLMSG_DONE:
+ break loop
+ case syscall.RTM_NEWADDR:
+ ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0]))
+ if len(ift) != 0 || ifi.Index == int(ifam.Index) {
+ if len(ift) != 0 {
+ var err error
+ ifi, err = interfaceByIndex(ift, int(ifam.Index))
+ if err != nil {
+ return nil, err
+ }
+ }
+ attrs, err := syscall.ParseNetlinkRouteAttr(&m)
+ if err != nil {
+ return nil, os.NewSyscallError("parsenetlinkrouteattr", err)
+ }
+ ifa := newAddr(ifam, attrs)
+ if ifa != nil {
+ ifat = append(ifat, ifa)
+ }
+ }
+ }
+ }
+ return ifat, nil
+}
+
+func newAddr(ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr {
+ var ipPointToPoint bool
+ // We need to determine whether the IP interface stack uses
+ // IP point-to-point numbered or unnumbered addressing.
+ for _, a := range attrs {
+ if a.Attr.Type == syscall.IFA_LOCAL {
+ ipPointToPoint = true
+ break
+ }
+ }
+ for _, a := range attrs {
+ if ipPointToPoint && a.Attr.Type == syscall.IFA_ADDRESS {
+ continue
+ }
+ switch ifam.Family {
+ case syscall.AF_INET:
+ return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)}
+ case syscall.AF_INET6:
+ ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)}
+ copy(ifa.IP, a.Value[:])
+ return ifa
+ }
+ }
+ return nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ ifmat4 := parseProcNetIGMP("/proc/net/igmp", ifi)
+ ifmat6 := parseProcNetIGMP6("/proc/net/igmp6", ifi)
+ return append(ifmat4, ifmat6...), nil
+}
+
+func parseProcNetIGMP(path string, ifi *Interface) []Addr {
+ fd, err := open(path)
+ if err != nil {
+ return nil
+ }
+ defer fd.close()
+ var (
+ ifmat []Addr
+ name string
+ )
+ fd.readLine() // skip first line
+ b := make([]byte, IPv4len)
+ for l, ok := fd.readLine(); ok; l, ok = fd.readLine() {
+ f := splitAtBytes(l, " :\r\t\n")
+ if len(f) < 4 {
+ continue
+ }
+ switch {
+ case l[0] != ' ' && l[0] != '\t': // new interface line
+ name = f[1]
+ case len(f[0]) == 8:
+ if ifi == nil || name == ifi.Name {
+ // The Linux kernel puts the IP
+ // address in /proc/net/igmp in native
+ // endianness.
+ for i := 0; i+1 < len(f[0]); i += 2 {
+ b[i/2], _ = xtoi2(f[0][i:i+2], 0)
+ }
+ i := *(*uint32)(unsafe.Pointer(&b[:4][0]))
+ ifma := &IPAddr{IP: IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))}
+ ifmat = append(ifmat, ifma)
+ }
+ }
+ }
+ return ifmat
+}
+
+func parseProcNetIGMP6(path string, ifi *Interface) []Addr {
+ fd, err := open(path)
+ if err != nil {
+ return nil
+ }
+ defer fd.close()
+ var ifmat []Addr
+ b := make([]byte, IPv6len)
+ for l, ok := fd.readLine(); ok; l, ok = fd.readLine() {
+ f := splitAtBytes(l, " \r\t\n")
+ if len(f) < 6 {
+ continue
+ }
+ if ifi == nil || f[1] == ifi.Name {
+ for i := 0; i+1 < len(f[2]); i += 2 {
+ b[i/2], _ = xtoi2(f[2][i:i+2], 0)
+ }
+ ifma := &IPAddr{IP: IP{b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]}}
+ ifmat = append(ifmat, ifma)
+ }
+ }
+ return ifmat
+}
diff --git a/src/net/interface_linux_test.go b/src/net/interface_linux_test.go
new file mode 100644
index 0000000..0699fec
--- /dev/null
+++ b/src/net/interface_linux_test.go
@@ -0,0 +1,133 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "fmt"
+ "os/exec"
+ "testing"
+)
+
+func (ti *testInterface) setBroadcast(suffix int) error {
+ ti.name = fmt.Sprintf("gotest%d", suffix)
+ xname, err := exec.LookPath("ip")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "add", ti.name, "type", "dummy"},
+ })
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "add", ti.local, "peer", ti.remote, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "del", ti.local, "peer", ti.remote, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "delete", ti.name, "type", "dummy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setLinkLocal(suffix int) error {
+ ti.name = fmt.Sprintf("gotest%d", suffix)
+ xname, err := exec.LookPath("ip")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "add", ti.name, "type", "dummy"},
+ })
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "add", ti.local, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "del", ti.local, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "link", "delete", ti.name, "type", "dummy"},
+ })
+ return nil
+}
+
+func (ti *testInterface) setPointToPoint(suffix int) error {
+ ti.name = fmt.Sprintf("gotest%d", suffix)
+ xname, err := exec.LookPath("ip")
+ if err != nil {
+ return err
+ }
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "tunnel", "add", ti.name, "mode", "gre", "local", ti.local, "remote", ti.remote},
+ })
+ ti.setupCmds = append(ti.setupCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "add", ti.local, "peer", ti.remote, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "address", "del", ti.local, "peer", ti.remote, "dev", ti.name},
+ })
+ ti.teardownCmds = append(ti.teardownCmds, &exec.Cmd{
+ Path: xname,
+ Args: []string{"ip", "tunnel", "del", ti.name, "mode", "gre", "local", ti.local, "remote", ti.remote},
+ })
+ return nil
+}
+
+const (
+ numOfTestIPv4MCAddrs = 14
+ numOfTestIPv6MCAddrs = 18
+)
+
+var (
+ igmpInterfaceTable = []Interface{
+ {Name: "lo"},
+ {Name: "eth0"}, {Name: "eth1"}, {Name: "eth2"},
+ {Name: "eth0.100"}, {Name: "eth0.101"}, {Name: "eth0.102"}, {Name: "eth0.103"},
+ {Name: "device1tap2"},
+ }
+ igmp6InterfaceTable = []Interface{
+ {Name: "lo"},
+ {Name: "eth0"}, {Name: "eth1"}, {Name: "eth2"},
+ {Name: "eth0.100"}, {Name: "eth0.101"}, {Name: "eth0.102"}, {Name: "eth0.103"},
+ {Name: "device1tap2"},
+ {Name: "pan0"},
+ }
+)
+
+func TestParseProcNet(t *testing.T) {
+ defer func() {
+ if p := recover(); p != nil {
+ t.Fatalf("panicked: %v", p)
+ }
+ }()
+
+ var ifmat4 []Addr
+ for _, ifi := range igmpInterfaceTable {
+ ifmat := parseProcNetIGMP("testdata/igmp", &ifi)
+ ifmat4 = append(ifmat4, ifmat...)
+ }
+ if len(ifmat4) != numOfTestIPv4MCAddrs {
+ t.Fatalf("got %d; want %d", len(ifmat4), numOfTestIPv4MCAddrs)
+ }
+
+ var ifmat6 []Addr
+ for _, ifi := range igmp6InterfaceTable {
+ ifmat := parseProcNetIGMP6("testdata/igmp6", &ifi)
+ ifmat6 = append(ifmat6, ifmat...)
+ }
+ if len(ifmat6) != numOfTestIPv6MCAddrs {
+ t.Fatalf("got %d; want %d", len(ifmat6), numOfTestIPv6MCAddrs)
+ }
+}
diff --git a/src/net/interface_plan9.go b/src/net/interface_plan9.go
new file mode 100644
index 0000000..92b2eed
--- /dev/null
+++ b/src/net/interface_plan9.go
@@ -0,0 +1,200 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "internal/itoa"
+ "os"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ if ifindex == 0 {
+ n, err := interfaceCount()
+ if err != nil {
+ return nil, err
+ }
+ ifcs := make([]Interface, n)
+ for i := range ifcs {
+ ifc, err := readInterface(i)
+ if err != nil {
+ return nil, err
+ }
+ ifcs[i] = *ifc
+ }
+ return ifcs, nil
+ }
+
+ ifc, err := readInterface(ifindex - 1)
+ if err != nil {
+ return nil, err
+ }
+ return []Interface{*ifc}, nil
+}
+
+func readInterface(i int) (*Interface, error) {
+ ifc := &Interface{
+ Index: i + 1, // Offset the index by one to suit the contract
+ Name: netdir + "/ipifc/" + itoa.Itoa(i), // Name is the full path to the interface path in plan9
+ }
+
+ ifcstat := ifc.Name + "/status"
+ ifcstatf, err := open(ifcstat)
+ if err != nil {
+ return nil, err
+ }
+ defer ifcstatf.close()
+
+ line, ok := ifcstatf.readLine()
+ if !ok {
+ return nil, errors.New("invalid interface status file: " + ifcstat)
+ }
+
+ fields := getFields(line)
+ if len(fields) < 4 {
+ return nil, errors.New("invalid interface status file: " + ifcstat)
+ }
+
+ device := fields[1]
+ mtustr := fields[3]
+
+ mtu, _, ok := dtoi(mtustr)
+ if !ok {
+ return nil, errors.New("invalid status file of interface: " + ifcstat)
+ }
+ ifc.MTU = mtu
+
+ // Not a loopback device ("/dev/null") or packet interface (e.g. "pkt2")
+ if stringsHasPrefix(device, netdir+"/") {
+ deviceaddrf, err := open(device + "/addr")
+ if err != nil {
+ return nil, err
+ }
+ defer deviceaddrf.close()
+
+ line, ok = deviceaddrf.readLine()
+ if !ok {
+ return nil, errors.New("invalid address file for interface: " + device + "/addr")
+ }
+
+ if len(line) > 0 && len(line)%2 == 0 {
+ ifc.HardwareAddr = make([]byte, len(line)/2)
+ var ok bool
+ for i := range ifc.HardwareAddr {
+ j := (i + 1) * 2
+ ifc.HardwareAddr[i], ok = xtoi2(line[i*2:j], 0)
+ if !ok {
+ ifc.HardwareAddr = ifc.HardwareAddr[:i]
+ break
+ }
+ }
+ }
+
+ ifc.Flags = FlagUp | FlagRunning | FlagBroadcast | FlagMulticast
+ } else {
+ ifc.Flags = FlagUp | FlagRunning | FlagMulticast | FlagLoopback
+ }
+
+ return ifc, nil
+}
+
+func interfaceCount() (int, error) {
+ d, err := os.Open(netdir + "/ipifc")
+ if err != nil {
+ return -1, err
+ }
+ defer d.Close()
+
+ names, err := d.Readdirnames(0)
+ if err != nil {
+ return -1, err
+ }
+
+ // Assumes that the numbered entries in ipifc are
+ // exactly the sequentially numbered per-interface
+ // directories.
+ c := 0
+ for _, name := range names {
+ if _, _, ok := dtoi(name); !ok {
+ continue
+ }
+ c++
+ }
+
+ return c, nil
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ var ifcs []Interface
+ if ifi == nil {
+ var err error
+ ifcs, err = interfaceTable(0)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ ifcs = []Interface{*ifi}
+ }
+
+ var addrs []Addr
+ for _, ifc := range ifcs {
+ status := ifc.Name + "/status"
+ statusf, err := open(status)
+ if err != nil {
+ return nil, err
+ }
+ defer statusf.close()
+
+ // Read but ignore first line as it only contains the table header.
+ // See https://9p.io/magic/man2html/3/ip
+ if _, ok := statusf.readLine(); !ok {
+ return nil, errors.New("cannot read header line for interface: " + status)
+ }
+
+ for line, ok := statusf.readLine(); ok; line, ok = statusf.readLine() {
+ fields := getFields(line)
+ if len(fields) < 1 {
+ return nil, errors.New("cannot parse IP address for interface: " + status)
+ }
+ addr := fields[0]
+ ip := ParseIP(addr)
+ if ip == nil {
+ return nil, errors.New("cannot parse IP address for interface: " + status)
+ }
+
+ // The mask is a CIDR prefix length relative to the IPv6 address;
+ // Plan 9's internal representation is always IPv6.
+ maskfld := fields[1]
+ maskfld = maskfld[1:]
+ pfxlen, _, ok := dtoi(maskfld)
+ if !ok {
+ return nil, errors.New("cannot parse network mask for interface: " + status)
+ }
+ var mask IPMask
+ if ip.To4() != nil { // IPv4 or IPv6 IPv4-mapped address
+ mask = CIDRMask(pfxlen-8*len(v4InV6Prefix), 8*IPv4len)
+ }
+ if ip.To16() != nil && ip.To4() == nil { // IPv6 address
+ mask = CIDRMask(pfxlen, 8*IPv6len)
+ }
+
+ addrs = append(addrs, &IPNet{IP: ip, Mask: mask})
+ }
+ }
+
+ return addrs, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
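
readInterface above decodes the link-layer address from a Plan 9 addr file, which stores the MAC as a bare hex string, two characters per byte. A rough standalone sketch of the same decoding using the standard library (the input string is a made-up example; this is an illustration, not part of the patch):

package main

import (
	"encoding/hex"
	"fmt"
	"net"
)

func main() {
	line := "0ea7deadbeef" // hypothetical contents of an addr file
	b, err := hex.DecodeString(line)
	if err != nil {
		fmt.Println("not a hardware address:", err)
		return
	}
	fmt.Println(net.HardwareAddr(b)) // 0e:a7:de:ad:be:ef
}
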
diff --git a/src/net/interface_solaris.go b/src/net/interface_solaris.go
new file mode 100644
index 0000000..32f503f
--- /dev/null
+++ b/src/net/interface_solaris.go
@@ -0,0 +1,92 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "syscall"
+
+ "golang.org/x/net/lif"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ lls, err := lif.Links(syscall.AF_UNSPEC, "")
+ if err != nil {
+ return nil, err
+ }
+ var ift []Interface
+ for _, ll := range lls {
+ if ifindex != 0 && ifindex != ll.Index {
+ continue
+ }
+ ifi := Interface{Index: ll.Index, MTU: ll.MTU, Name: ll.Name, Flags: linkFlags(ll.Flags)}
+ if len(ll.Addr) > 0 {
+ ifi.HardwareAddr = HardwareAddr(ll.Addr)
+ }
+ ift = append(ift, ifi)
+ }
+ return ift, nil
+}
+
+func linkFlags(rawFlags int) Flags {
+ var f Flags
+ if rawFlags&syscall.IFF_UP != 0 {
+ f |= FlagUp
+ }
+ if rawFlags&syscall.IFF_RUNNING != 0 {
+ f |= FlagRunning
+ }
+ if rawFlags&syscall.IFF_BROADCAST != 0 {
+ f |= FlagBroadcast
+ }
+ if rawFlags&syscall.IFF_LOOPBACK != 0 {
+ f |= FlagLoopback
+ }
+ if rawFlags&syscall.IFF_POINTOPOINT != 0 {
+ f |= FlagPointToPoint
+ }
+ if rawFlags&syscall.IFF_MULTICAST != 0 {
+ f |= FlagMulticast
+ }
+ return f
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ var name string
+ if ifi != nil {
+ name = ifi.Name
+ }
+ as, err := lif.Addrs(syscall.AF_UNSPEC, name)
+ if err != nil {
+ return nil, err
+ }
+ var ifat []Addr
+ for _, a := range as {
+ var ip IP
+ var mask IPMask
+ switch a := a.(type) {
+ case *lif.Inet4Addr:
+ ip = IPv4(a.IP[0], a.IP[1], a.IP[2], a.IP[3])
+ mask = CIDRMask(a.PrefixLen, 8*IPv4len)
+ case *lif.Inet6Addr:
+ ip = make(IP, IPv6len)
+ copy(ip, a.IP[:])
+ mask = CIDRMask(a.PrefixLen, 8*IPv6len)
+ }
+ ifat = append(ifat, &IPNet{IP: ip, Mask: mask})
+ }
+ return ifat, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
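
interfaceAddrTable above pairs each lif address with a CIDRMask built from its prefix length to form an *IPNet. A small sketch of the same construction with hard-coded example values (illustration only, not part of the patch):

package main

import (
	"fmt"
	"net"
)

func main() {
	// IPv4: /24 out of 32 bits; IPv6: /64 out of 128 bits.
	n4 := &net.IPNet{IP: net.IPv4(192, 0, 2, 1), Mask: net.CIDRMask(24, 8*net.IPv4len)}
	n6 := &net.IPNet{IP: net.ParseIP("2001:db8::1"), Mask: net.CIDRMask(64, 8*net.IPv6len)}
	fmt.Println(n4, n6) // 192.0.2.1/24 2001:db8::1/64
}
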
diff --git a/src/net/interface_stub.go b/src/net/interface_stub.go
new file mode 100644
index 0000000..829dbc6
--- /dev/null
+++ b/src/net/interface_stub.go
@@ -0,0 +1,27 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1
+
+package net
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ return nil, nil
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
diff --git a/src/net/interface_test.go b/src/net/interface_test.go
new file mode 100644
index 0000000..5590b06
--- /dev/null
+++ b/src/net/interface_test.go
@@ -0,0 +1,382 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "fmt"
+ "reflect"
+ "runtime"
+ "testing"
+)
+
+// loopbackInterface returns an available logical network interface
+// for loopback tests. It returns nil if no suitable interface is
+// found.
+func loopbackInterface() *Interface {
+ ift, err := Interfaces()
+ if err != nil {
+ return nil
+ }
+ for _, ifi := range ift {
+ if ifi.Flags&FlagLoopback != 0 && ifi.Flags&FlagUp != 0 {
+ return &ifi
+ }
+ }
+ return nil
+}
+
+// ipv6LinkLocalUnicastAddr returns an IPv6 link-local unicast address
+// on the given network interface for tests. It returns "" if no
+// suitable address is found.
+func ipv6LinkLocalUnicastAddr(ifi *Interface) string {
+ if ifi == nil {
+ return ""
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return ""
+ }
+ for _, ifa := range ifat {
+ if ifa, ok := ifa.(*IPNet); ok {
+ if ifa.IP.To4() == nil && ifa.IP.IsLinkLocalUnicast() {
+ return ifa.IP.String()
+ }
+ }
+ }
+ return ""
+}
+
+func TestInterfaces(t *testing.T) {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, ifi := range ift {
+ ifxi, err := InterfaceByIndex(ifi.Index)
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch runtime.GOOS {
+ case "solaris", "illumos":
+ if ifxi.Index != ifi.Index {
+ t.Errorf("got %v; want %v", ifxi, ifi)
+ }
+ default:
+ if !reflect.DeepEqual(ifxi, &ifi) {
+ t.Errorf("got %v; want %v", ifxi, ifi)
+ }
+ }
+ ifxn, err := InterfaceByName(ifi.Name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(ifxn, &ifi) {
+ t.Errorf("got %v; want %v", ifxn, ifi)
+ }
+ t.Logf("%s: flags=%v index=%d mtu=%d hwaddr=%v", ifi.Name, ifi.Flags, ifi.Index, ifi.MTU, ifi.HardwareAddr)
+ }
+}
+
+func TestInterfaceAddrs(t *testing.T) {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ifStats := interfaceStats(ift)
+ ifat, err := InterfaceAddrs()
+ if err != nil {
+ t.Fatal(err)
+ }
+ uniStats, err := validateInterfaceUnicastAddrs(ifat)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := checkUnicastStats(ifStats, uniStats); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestInterfaceUnicastAddrs(t *testing.T) {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ifStats := interfaceStats(ift)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var uniStats routeStats
+ for _, ifi := range ift {
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ t.Fatal(ifi, err)
+ }
+ stats, err := validateInterfaceUnicastAddrs(ifat)
+ if err != nil {
+ t.Fatal(ifi, err)
+ }
+ uniStats.ipv4 += stats.ipv4
+ uniStats.ipv6 += stats.ipv6
+ }
+ if err := checkUnicastStats(ifStats, &uniStats); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestInterfaceMulticastAddrs(t *testing.T) {
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ifStats := interfaceStats(ift)
+ ifat, err := InterfaceAddrs()
+ if err != nil {
+ t.Fatal(err)
+ }
+ uniStats, err := validateInterfaceUnicastAddrs(ifat)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var multiStats routeStats
+ for _, ifi := range ift {
+ ifmat, err := ifi.MulticastAddrs()
+ if err != nil {
+ t.Fatal(ifi, err)
+ }
+ stats, err := validateInterfaceMulticastAddrs(ifmat)
+ if err != nil {
+ t.Fatal(ifi, err)
+ }
+ multiStats.ipv4 += stats.ipv4
+ multiStats.ipv6 += stats.ipv6
+ }
+ if err := checkMulticastStats(ifStats, uniStats, &multiStats); err != nil {
+ t.Fatal(err)
+ }
+}
+
+type ifStats struct {
+ loop int // # of active loopback interfaces
+ other int // # of active other interfaces
+}
+
+func interfaceStats(ift []Interface) *ifStats {
+ var stats ifStats
+ for _, ifi := range ift {
+ if ifi.Flags&FlagUp != 0 {
+ if ifi.Flags&FlagLoopback != 0 {
+ stats.loop++
+ } else {
+ stats.other++
+ }
+ }
+ }
+ return &stats
+}
+
+type routeStats struct {
+ ipv4, ipv6 int // # of active connected unicast, anycast or multicast routes
+}
+
+func validateInterfaceUnicastAddrs(ifat []Addr) (*routeStats, error) {
+ // Note: BSD variants allow assigning any IPv4/IPv6 address
+ // prefix to IP interface. For example,
+ // - 0.0.0.0/0 through 255.255.255.255/32
+ // - ::/0 through ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/128
+ // In other words, there is no tightly-coupled combination of
+ // interface address prefixes and connected routes.
+ stats := new(routeStats)
+ for _, ifa := range ifat {
+ switch ifa := ifa.(type) {
+ case *IPNet:
+ if ifa == nil || ifa.IP == nil || ifa.IP.IsMulticast() || ifa.Mask == nil {
+ return nil, fmt.Errorf("unexpected value: %#v", ifa)
+ }
+ if len(ifa.IP) != IPv6len {
+ return nil, fmt.Errorf("should be internal representation either IPv6 or IPv4-mapped IPv6 address: %#v", ifa)
+ }
+ prefixLen, maxPrefixLen := ifa.Mask.Size()
+ if ifa.IP.To4() != nil {
+ if 0 >= prefixLen || prefixLen > 8*IPv4len || maxPrefixLen != 8*IPv4len {
+ return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
+ }
+ if ifa.IP.IsLoopback() && prefixLen < 8 { // see RFC 1122
+ return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
+ }
+ stats.ipv4++
+ }
+ if ifa.IP.To16() != nil && ifa.IP.To4() == nil {
+ if 0 >= prefixLen || prefixLen > 8*IPv6len || maxPrefixLen != 8*IPv6len {
+ return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
+ }
+ if ifa.IP.IsLoopback() && prefixLen != 8*IPv6len { // see RFC 4291
+ return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
+ }
+ stats.ipv6++
+ }
+ case *IPAddr:
+ if ifa == nil || ifa.IP == nil || ifa.IP.IsMulticast() {
+ return nil, fmt.Errorf("unexpected value: %#v", ifa)
+ }
+ if len(ifa.IP) != IPv6len {
+ return nil, fmt.Errorf("should be internal representation either IPv6 or IPv4-mapped IPv6 address: %#v", ifa)
+ }
+ if ifa.IP.To4() != nil {
+ stats.ipv4++
+ }
+ if ifa.IP.To16() != nil && ifa.IP.To4() == nil {
+ stats.ipv6++
+ }
+ default:
+ return nil, fmt.Errorf("unexpected type: %T", ifa)
+ }
+ }
+ return stats, nil
+}
+
+func validateInterfaceMulticastAddrs(ifat []Addr) (*routeStats, error) {
+ stats := new(routeStats)
+ for _, ifa := range ifat {
+ switch ifa := ifa.(type) {
+ case *IPAddr:
+ if ifa == nil || ifa.IP == nil || ifa.IP.IsUnspecified() || !ifa.IP.IsMulticast() {
+ return nil, fmt.Errorf("unexpected value: %#v", ifa)
+ }
+ if len(ifa.IP) != IPv6len {
+ return nil, fmt.Errorf("should be internal representation either IPv6 or IPv4-mapped IPv6 address: %#v", ifa)
+ }
+ if ifa.IP.To4() != nil {
+ stats.ipv4++
+ }
+ if ifa.IP.To16() != nil && ifa.IP.To4() == nil {
+ stats.ipv6++
+ }
+ default:
+ return nil, fmt.Errorf("unexpected type: %T", ifa)
+ }
+ }
+ return stats, nil
+}
+
+func checkUnicastStats(ifStats *ifStats, uniStats *routeStats) error {
+ // Test the existence of connected unicast routes for IPv4.
+ if supportsIPv4() && ifStats.loop+ifStats.other > 0 && uniStats.ipv4 == 0 {
+ return fmt.Errorf("num IPv4 unicast routes = 0; want >0; summary: %+v, %+v", ifStats, uniStats)
+ }
+ // Test the existence of connected unicast routes for IPv6.
+ // We can assume the existence of ::1/128 when at least one
+ // loopback interface is installed.
+ if supportsIPv6() && ifStats.loop > 0 && uniStats.ipv6 == 0 {
+ return fmt.Errorf("num IPv6 unicast routes = 0; want >0; summary: %+v, %+v", ifStats, uniStats)
+ }
+ return nil
+}
+
+func checkMulticastStats(ifStats *ifStats, uniStats, multiStats *routeStats) error {
+ switch runtime.GOOS {
+ case "aix", "dragonfly", "netbsd", "openbsd", "plan9", "solaris", "illumos":
+ default:
+ // Test the existence of connected multicast route
+ // clones for IPv4. Unlike IPv6, IPv4 multicast
+ // capability is not a mandatory feature, and so IPv4
+ // multicast validation is ignored and we only check
+ // IPv6 below.
+ //
+ // Test the existence of connected multicast route
+ // clones for IPv6. Some platforms never use the loopback
+ // interface as the nexthop for multicast routing.
+ // We can assume the existence of connected multicast
+ // route clones when at least two connected unicast
+ // routes, ::1/128 and another, are installed.
+ if supportsIPv6() && ifStats.loop > 0 && uniStats.ipv6 > 1 && multiStats.ipv6 == 0 {
+ return fmt.Errorf("num IPv6 multicast route clones = 0; want >0; summary: %+v, %+v, %+v", ifStats, uniStats, multiStats)
+ }
+ }
+ return nil
+}
+
+func BenchmarkInterfaces(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ if _, err := Interfaces(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkInterfaceByIndex(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceByIndex(ifi.Index); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkInterfaceByName(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceByName(ifi.Name); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkInterfaceAddrs(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ if _, err := InterfaceAddrs(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkInterfacesAndAddrs(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := ifi.Addrs(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) {
+ b.ReportAllocs()
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ for i := 0; i < b.N; i++ {
+ if _, err := ifi.MulticastAddrs(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
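
For reference, a minimal program that exercises the same public API the tests above validate: enumerate interfaces and print each one's unicast addresses (illustration only, not part of the patch):

package main

import (
	"fmt"
	"log"
	"net"
)

func main() {
	ift, err := net.Interfaces()
	if err != nil {
		log.Fatal(err)
	}
	for _, ifi := range ift {
		addrs, err := ifi.Addrs()
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s: flags=%v mtu=%d addrs=%v\n", ifi.Name, ifi.Flags, ifi.MTU, addrs)
	}
}
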
diff --git a/src/net/interface_unix_test.go b/src/net/interface_unix_test.go
new file mode 100644
index 0000000..b0a9bcf
--- /dev/null
+++ b/src/net/interface_unix_test.go
@@ -0,0 +1,215 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd
+
+package net
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+type testInterface struct {
+ name string
+ local string
+ remote string
+ setupCmds []*exec.Cmd
+ teardownCmds []*exec.Cmd
+}
+
+func (ti *testInterface) setup() error {
+ for _, cmd := range ti.setupCmds {
+ if out, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("args=%v out=%q err=%v", cmd.Args, string(out), err)
+ }
+ }
+ return nil
+}
+
+func (ti *testInterface) teardown() error {
+ for _, cmd := range ti.teardownCmds {
+ if out, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("args=%v out=%q err=%v ", cmd.Args, string(out), err)
+ }
+ }
+ return nil
+}
+
+func TestPointToPointInterface(t *testing.T) {
+ if testing.Short() {
+ t.Skip("avoid external network")
+ }
+ if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if os.Getuid() != 0 {
+ t.Skip("must be root")
+ }
+
+ // We suppose that using IPv4 link-local addresses doesn't
+ // harm anyone.
+ local, remote := "169.254.0.1", "169.254.0.254"
+ ip := ParseIP(remote)
+ for i := 0; i < 3; i++ {
+ ti := &testInterface{local: local, remote: remote}
+ if err := ti.setPointToPoint(5963 + i); err != nil {
+ t.Skipf("test requires external command: %v", err)
+ }
+ if err := ti.setup(); err != nil {
+ if e := err.Error(); strings.Contains(e, "No such device") && strings.Contains(e, "gre0") {
+ t.Skip("skipping test; no gre0 device. likely running in container?")
+ }
+ t.Fatal(err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift, err := Interfaces()
+ if err != nil {
+ ti.teardown()
+ t.Fatal(err)
+ }
+ for _, ifi := range ift {
+ if ti.name != ifi.Name {
+ continue
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ ti.teardown()
+ t.Fatal(err)
+ }
+ for _, ifa := range ifat {
+ if ip.Equal(ifa.(*IPNet).IP) {
+ ti.teardown()
+ t.Fatalf("got %v", ifa)
+ }
+ }
+ }
+ if err := ti.teardown(); err != nil {
+ t.Fatal(err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ }
+}
+
+func TestInterfaceArrivalAndDeparture(t *testing.T) {
+ if testing.Short() {
+ t.Skip("avoid external network")
+ }
+ if os.Getuid() != 0 {
+ t.Skip("must be root")
+ }
+
+ // We suppose that using IPv4 link-local addresses and the
+ // dot1Q ID for Token Ring and FDDI doesn't harm anyone.
+ local, remote := "169.254.0.1", "169.254.0.254"
+ ip := ParseIP(remote)
+ for _, vid := range []int{1002, 1003, 1004, 1005} {
+ ift1, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ti := &testInterface{local: local, remote: remote}
+ if err := ti.setBroadcast(vid); err != nil {
+ t.Skipf("test requires external command: %v", err)
+ }
+ if err := ti.setup(); err != nil {
+ t.Fatal(err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift2, err := Interfaces()
+ if err != nil {
+ ti.teardown()
+ t.Fatal(err)
+ }
+ if len(ift2) <= len(ift1) {
+ for _, ifi := range ift1 {
+ t.Logf("before: %v", ifi)
+ }
+ for _, ifi := range ift2 {
+ t.Logf("after: %v", ifi)
+ }
+ ti.teardown()
+ t.Fatalf("got %v; want gt %v", len(ift2), len(ift1))
+ }
+ for _, ifi := range ift2 {
+ if ti.name != ifi.Name {
+ continue
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ ti.teardown()
+ t.Fatal(err)
+ }
+ for _, ifa := range ifat {
+ if ip.Equal(ifa.(*IPNet).IP) {
+ ti.teardown()
+ t.Fatalf("got %v", ifa)
+ }
+ }
+ }
+ if err := ti.teardown(); err != nil {
+ t.Fatal(err)
+ } else {
+ time.Sleep(3 * time.Millisecond)
+ }
+ ift3, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(ift3) >= len(ift2) {
+ for _, ifi := range ift2 {
+ t.Logf("before: %v", ifi)
+ }
+ for _, ifi := range ift3 {
+ t.Logf("after: %v", ifi)
+ }
+ t.Fatalf("got %v; want lt %v", len(ift3), len(ift2))
+ }
+ }
+}
+
+func TestInterfaceArrivalAndDepartureZoneCache(t *testing.T) {
+ if testing.Short() {
+ t.Skip("avoid external network")
+ }
+ if os.Getuid() != 0 {
+ t.Skip("must be root")
+ }
+
+ // Ensure zoneCache is filled:
+ _, _ = Listen("tcp", "[fe80::1%nonexistent]:0")
+
+ ti := &testInterface{local: "fe80::1"}
+ if err := ti.setLinkLocal(0); err != nil {
+ t.Skipf("test requires external command: %v", err)
+ }
+ if err := ti.setup(); err != nil {
+ if e := err.Error(); strings.Contains(e, "Permission denied") {
+ t.Skipf("permission denied, skipping test: %v", e)
+ }
+ t.Fatal(err)
+ }
+ defer ti.teardown()
+
+ time.Sleep(3 * time.Millisecond)
+
+ // If Listen fails (on Linux with “bind: invalid argument”), zoneCache was
+ // not updated when encountering a nonexistent interface:
+ ln, err := Listen("tcp", "[fe80::1%"+ti.name+"]:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln.Close()
+ if err := ti.teardown(); err != nil {
+ t.Fatal(err)
+ }
+}
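
The testInterface type above drives interface creation and removal through lists of *exec.Cmd, folding each command's combined output into the error on failure. A standalone sketch of that pattern (the interface name and addresses are made-up examples; running it for real needs root and the iproute2 tools, and it is not part of the patch):

package main

import (
	"fmt"
	"os/exec"
)

// run executes each command and reports its output on failure, mirroring
// testInterface.setup and teardown.
func run(cmds []*exec.Cmd) error {
	for _, cmd := range cmds {
		if out, err := cmd.CombinedOutput(); err != nil {
			return fmt.Errorf("args=%v out=%q err=%v", cmd.Args, string(out), err)
		}
	}
	return nil
}

func main() {
	setup := []*exec.Cmd{
		exec.Command("ip", "link", "add", "gotest0", "type", "dummy"),
		exec.Command("ip", "address", "add", "169.254.0.1/16", "dev", "gotest0"),
	}
	teardown := []*exec.Cmd{
		exec.Command("ip", "link", "del", "gotest0"),
	}
	if err := run(setup); err != nil {
		fmt.Println("setup:", err)
	}
	if err := run(teardown); err != nil {
		fmt.Println("teardown:", err)
	}
}
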
diff --git a/src/net/interface_windows.go b/src/net/interface_windows.go
new file mode 100644
index 0000000..22a1312
--- /dev/null
+++ b/src/net/interface_windows.go
@@ -0,0 +1,178 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/windows"
+ "os"
+ "syscall"
+ "unsafe"
+)
+
+// adapterAddresses returns a list of IP adapter and address
+// structures. Each structure contains an IP adapter and a flattened
+// list of its IP addresses, including unicast, anycast and multicast
+// addresses.
+func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
+ var b []byte
+ l := uint32(15000) // recommended initial size
+ for {
+ b = make([]byte, l)
+ err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
+ if err == nil {
+ if l == 0 {
+ return nil, nil
+ }
+ break
+ }
+ if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
+ return nil, os.NewSyscallError("getadaptersaddresses", err)
+ }
+ if l <= uint32(len(b)) {
+ return nil, os.NewSyscallError("getadaptersaddresses", err)
+ }
+ }
+ var aas []*windows.IpAdapterAddresses
+ for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
+ aas = append(aas, aa)
+ }
+ return aas, nil
+}
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ aas, err := adapterAddresses()
+ if err != nil {
+ return nil, err
+ }
+ var ift []Interface
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
+ }
+ if ifindex == 0 || ifindex == int(index) {
+ ifi := Interface{
+ Index: int(index),
+ Name: windows.UTF16PtrToString(aa.FriendlyName),
+ }
+ if aa.OperStatus == windows.IfOperStatusUp {
+ ifi.Flags |= FlagUp
+ ifi.Flags |= FlagRunning
+ }
+ // For now we need to infer link-layer service
+ // capabilities from media types.
+ // TODO: use MIB_IF_ROW2.AccessType now that we no longer support
+ // Windows XP.
+ switch aa.IfType {
+ case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394:
+ ifi.Flags |= FlagBroadcast | FlagMulticast
+ case windows.IF_TYPE_PPP, windows.IF_TYPE_TUNNEL:
+ ifi.Flags |= FlagPointToPoint | FlagMulticast
+ case windows.IF_TYPE_SOFTWARE_LOOPBACK:
+ ifi.Flags |= FlagLoopback | FlagMulticast
+ case windows.IF_TYPE_ATM:
+ ifi.Flags |= FlagBroadcast | FlagPointToPoint | FlagMulticast // assume all services available; LANE, point-to-point and point-to-multipoint
+ }
+ if aa.Mtu == 0xffffffff {
+ ifi.MTU = -1
+ } else {
+ ifi.MTU = int(aa.Mtu)
+ }
+ if aa.PhysicalAddressLength > 0 {
+ ifi.HardwareAddr = make(HardwareAddr, aa.PhysicalAddressLength)
+ copy(ifi.HardwareAddr, aa.PhysicalAddress[:])
+ }
+ ift = append(ift, ifi)
+ if ifindex == ifi.Index {
+ break
+ }
+ }
+ }
+ return ift, nil
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ aas, err := adapterAddresses()
+ if err != nil {
+ return nil, err
+ }
+ var ifat []Addr
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
+ }
+ if ifi == nil || ifi.Index == int(index) {
+ for puni := aa.FirstUnicastAddress; puni != nil; puni = puni.Next {
+ sa, err := puni.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ifat = append(ifat, &IPNet{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]), Mask: CIDRMask(int(puni.OnLinkPrefixLength), 8*IPv4len)})
+ case *syscall.SockaddrInet6:
+ ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(puni.OnLinkPrefixLength), 8*IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
+ }
+ }
+ for pany := aa.FirstAnycastAddress; pany != nil; pany = pany.Next {
+ sa, err := pany.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ifat = append(ifat, &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])})
+ case *syscall.SockaddrInet6:
+ ifa := &IPAddr{IP: make(IP, IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
+ }
+ }
+ }
+ }
+ return ifat, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ aas, err := adapterAddresses()
+ if err != nil {
+ return nil, err
+ }
+ var ifat []Addr
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
+ }
+ if ifi == nil || ifi.Index == int(index) {
+ for pmul := aa.FirstMulticastAddress; pmul != nil; pmul = pmul.Next {
+ sa, err := pmul.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ifat = append(ifat, &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])})
+ case *syscall.SockaddrInet6:
+ ifa := &IPAddr{IP: make(IP, IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
+ }
+ }
+ }
+ }
+ return ifat, nil
+}
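
adapterAddresses above uses the usual Windows grow-and-retry loop: call GetAdaptersAddresses with a guessed buffer size, and if it reports ERROR_BUFFER_OVERFLOW, reallocate to the size it asked for and try again, bailing out if the requested size stops growing. A portable sketch of that loop with the Windows call replaced by a stub (illustration only, not part of the patch):

package main

import (
	"errors"
	"fmt"
)

var errBufferTooSmall = errors.New("buffer too small") // stands in for ERROR_BUFFER_OVERFLOW

// fill is a stand-in for the real API: it reports the size it needs via *n
// and fails until the caller's buffer is large enough.
func fill(b []byte, n *uint32) error {
	const need = 40000
	*n = need
	if len(b) < need {
		return errBufferTooSmall
	}
	return nil
}

func main() {
	l := uint32(15000) // recommended initial size
	var b []byte
	for {
		b = make([]byte, l)
		err := fill(b, &l)
		if err == nil {
			break
		}
		if !errors.Is(err, errBufferTooSmall) {
			fmt.Println("fatal:", err)
			return
		}
		if l <= uint32(len(b)) {
			fmt.Println("requested size did not grow; giving up")
			return
		}
	}
	fmt.Println("buffer size:", len(b)) // 40000
}
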
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
new file mode 100644
index 0000000..0197feb
--- /dev/null
+++ b/src/net/internal/socktest/main_test.go
@@ -0,0 +1,56 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1
+
+package socktest_test
+
+import (
+ "net/internal/socktest"
+ "os"
+ "sync"
+ "syscall"
+ "testing"
+)
+
+var sw socktest.Switch
+
+func TestMain(m *testing.M) {
+ installTestHooks()
+
+ st := m.Run()
+
+ for s := range sw.Sockets() {
+ closeFunc(s)
+ }
+ uninstallTestHooks()
+ os.Exit(st)
+}
+
+func TestSwitch(t *testing.T) {
+ const N = 10
+ var wg sync.WaitGroup
+ wg.Add(N)
+ for i := 0; i < N; i++ {
+ go func() {
+ defer wg.Done()
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestSocket(t *testing.T) {
+ for _, f := range []socktest.Filter{
+ func(st *socktest.Status) (socktest.AfterFilter, error) { return nil, nil },
+ nil,
+ } {
+ sw.Set(socktest.FilterSocket, f)
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ }
+ }
+}
diff --git a/src/net/internal/socktest/main_unix_test.go b/src/net/internal/socktest/main_unix_test.go
new file mode 100644
index 0000000..19ffb28
--- /dev/null
+++ b/src/net/internal/socktest/main_unix_test.go
@@ -0,0 +1,24 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1 && !windows
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (int, error)
+ closeFunc func(int) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Close
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Close
+}
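
installTestHooks and uninstallTestHooks above swap package-level function variables between the real syscalls and the Switch, which is how the socket calls become observable. A standalone sketch of that hook pattern with a fake socket function (illustration only, not part of the patch):

package main

import "fmt"

// The code under test calls through a package-level function variable; tests
// swap it between the real implementation and an instrumented one.
var socketFunc func(family, sotype, proto int) (int, error) = realSocket

func realSocket(family, sotype, proto int) (int, error) {
	return 3, nil // stand-in for syscall.Socket so the sketch runs anywhere
}

func countingSocket(calls *int) func(int, int, int) (int, error) {
	return func(family, sotype, proto int) (int, error) {
		*calls++
		return realSocket(family, sotype, proto)
	}
}

func main() {
	var calls int
	socketFunc = countingSocket(&calls) // install the hook
	defer func() { socketFunc = realSocket }()

	socketFunc(2, 1, 6)
	socketFunc(2, 1, 6)
	fmt.Println("socket calls observed:", calls) // 2
}
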
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
index 0000000..df1cb97
--- /dev/null
+++ b/src/net/internal/socktest/main_windows_test.go
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
diff --git a/src/net/internal/socktest/switch.go b/src/net/internal/socktest/switch.go
new file mode 100644
index 0000000..3c37b6f
--- /dev/null
+++ b/src/net/internal/socktest/switch.go
@@ -0,0 +1,169 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package socktest provides utilities for socket testing.
+package socktest
+
+import (
+ "fmt"
+ "sync"
+)
+
+// A Switch represents a switching point in the call path of socket
+// system calls.
+type Switch struct {
+ once sync.Once
+
+ fmu sync.RWMutex
+ fltab map[FilterType]Filter
+
+ smu sync.RWMutex
+ sotab Sockets
+ stats stats
+}
+
+func (sw *Switch) init() {
+ sw.fltab = make(map[FilterType]Filter)
+ sw.sotab = make(Sockets)
+ sw.stats = make(stats)
+}
+
+// Stats returns a list of per-cookie socket statistics.
+func (sw *Switch) Stats() []Stat {
+ var st []Stat
+ sw.smu.RLock()
+ for _, s := range sw.stats {
+ ns := *s
+ st = append(st, ns)
+ }
+ sw.smu.RUnlock()
+ return st
+}
+
+// Sockets returns mappings of socket descriptor to socket status.
+func (sw *Switch) Sockets() Sockets {
+ sw.smu.RLock()
+ tab := make(Sockets, len(sw.sotab))
+ for i, s := range sw.sotab {
+ tab[i] = s
+ }
+ sw.smu.RUnlock()
+ return tab
+}
+
+// A Cookie represents a 3-tuple of a socket: address family, socket
+// type and protocol number.
+type Cookie uint64
+
+// Family returns an address family.
+func (c Cookie) Family() int { return int(c >> 48) }
+
+// Type returns a socket type.
+func (c Cookie) Type() int { return int(c << 16 >> 32) }
+
+// Protocol returns a protocol number.
+func (c Cookie) Protocol() int { return int(c & 0xff) }
+
+func cookie(family, sotype, proto int) Cookie {
+ return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xff
+}
+
+// A Status represents the status of a socket.
+type Status struct {
+ Cookie Cookie
+ Err error // error status of socket system call
+ SocketErr error // error status of socket by SO_ERROR
+}
+
+func (so Status) String() string {
+ return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
+}
+
+// A Stat represents per-cookie socket statistics.
+type Stat struct {
+ Family int // address family
+ Type int // socket type
+ Protocol int // protocol number
+
+ Opened uint64 // number of sockets opened
+ Connected uint64 // number of sockets connected
+ Listened uint64 // number of sockets listened
+ Accepted uint64 // number of sockets accepted
+ Closed uint64 // number of sockets closed
+
+ OpenFailed uint64 // number of sockets open failed
+ ConnectFailed uint64 // number of sockets connect failed
+ ListenFailed uint64 // number of sockets listen failed
+ AcceptFailed uint64 // number of sockets accept failed
+ CloseFailed uint64 // number of sockets close failed
+}
+
+func (st Stat) String() string {
+ return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
+}
+
+type stats map[Cookie]*Stat
+
+func (st stats) getLocked(c Cookie) *Stat {
+ s, ok := st[c]
+ if !ok {
+ s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()}
+ st[c] = s
+ }
+ return s
+}
+
+// A FilterType represents a filter type.
+type FilterType int
+
+const (
+ FilterSocket FilterType = iota // for Socket
+ FilterConnect // for Connect or ConnectEx
+ FilterListen // for Listen
+ FilterAccept // for Accept, Accept4 or AcceptEx
+ FilterGetsockoptInt // for GetsockoptInt
+ FilterClose // for Close or Closesocket
+)
+
+// A Filter represents a socket system call filter.
+//
+// It is executed only before a system call for a socket that has an
+// entry in the internal table.
+// If the filter returns a non-nil error, the system call is not
+// executed and the system call function returns that error
+// instead.
+// It can return a non-nil AfterFilter for filtering after the
+// execution of the system call.
+type Filter func(*Status) (AfterFilter, error)
+
+func (f Filter) apply(st *Status) (AfterFilter, error) {
+ if f == nil {
+ return nil, nil
+ }
+ return f(st)
+}
+
+// An AfterFilter represents a socket system call filter after an
+// execution of a system call.
+//
+// It is executed only after a system call for a socket that has an
+// entry in the internal table.
+// If the filter returns a non-nil error, the system call function
+// returns that error.
+type AfterFilter func(*Status) error
+
+func (f AfterFilter) apply(st *Status) error {
+ if f == nil {
+ return nil
+ }
+ return f(st)
+}
+
+// Set deploys the socket system call filter f for the filter type t.
+func (sw *Switch) Set(t FilterType, f Filter) {
+ sw.once.Do(sw.init)
+ sw.fmu.Lock()
+ sw.fltab[t] = f
+ sw.fmu.Unlock()
+}
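
cookie packs the address family into the top 16 bits, the socket type into the middle 32 bits, and the protocol number into the low 8 bits; the Cookie methods unpack the same fields. A standalone sketch that mirrors the packing (the socktest package itself is internal and cannot be imported from user code; this is an illustration, not part of the patch):

package main

import (
	"fmt"
	"syscall"
)

type cookie uint64

func makeCookie(family, sotype, proto int) cookie {
	return cookie(family)<<48 | cookie(sotype)&0xffffffff<<16 | cookie(proto)&0xff
}

func (c cookie) family() int   { return int(c >> 48) }
func (c cookie) sotype() int   { return int(c << 16 >> 32) }
func (c cookie) protocol() int { return int(c & 0xff) }

func main() {
	c := makeCookie(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
	fmt.Printf("%#x family=%d type=%d proto=%d\n", uint64(c), c.family(), c.sotype(), c.protocol())
	// On Linux this prints 0x2000000010006 family=2 type=1 proto=6.
}
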
diff --git a/src/net/internal/socktest/switch_posix.go b/src/net/internal/socktest/switch_posix.go
new file mode 100644
index 0000000..fcad4ce
--- /dev/null
+++ b/src/net/internal/socktest/switch_posix.go
@@ -0,0 +1,58 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !plan9
+
+package socktest
+
+import (
+ "fmt"
+ "syscall"
+)
+
+func familyString(family int) string {
+ switch family {
+ case syscall.AF_INET:
+ return "inet4"
+ case syscall.AF_INET6:
+ return "inet6"
+ case syscall.AF_UNIX:
+ return "local"
+ default:
+ return fmt.Sprintf("%d", family)
+ }
+}
+
+func typeString(sotype int) string {
+ var s string
+ switch sotype & 0xff {
+ case syscall.SOCK_STREAM:
+ s = "stream"
+ case syscall.SOCK_DGRAM:
+ s = "datagram"
+ case syscall.SOCK_RAW:
+ s = "raw"
+ case syscall.SOCK_SEQPACKET:
+ s = "seqpacket"
+ default:
+ s = fmt.Sprintf("%d", sotype&0xff)
+ }
+ if flags := uint(sotype) & ^uint(0xff); flags != 0 {
+ s += fmt.Sprintf("|%#x", flags)
+ }
+ return s
+}
+
+func protocolString(proto int) string {
+ switch proto {
+ case 0:
+ return "default"
+ case syscall.IPPROTO_TCP:
+ return "tcp"
+ case syscall.IPPROTO_UDP:
+ return "udp"
+ default:
+ return fmt.Sprintf("%d", proto)
+ }
+}
diff --git a/src/net/internal/socktest/switch_stub.go b/src/net/internal/socktest/switch_stub.go
new file mode 100644
index 0000000..8a2fc35
--- /dev/null
+++ b/src/net/internal/socktest/switch_stub.go
@@ -0,0 +1,16 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build plan9
+
+package socktest
+
+// Sockets maps a socket descriptor to the status of the socket.
+type Sockets map[int]Status
+
+func familyString(family int) string { return "<nil>" }
+
+func typeString(sotype int) string { return "<nil>" }
+
+func protocolString(proto int) string { return "<nil>" }
diff --git a/src/net/internal/socktest/switch_unix.go b/src/net/internal/socktest/switch_unix.go
new file mode 100644
index 0000000..ff92877
--- /dev/null
+++ b/src/net/internal/socktest/switch_unix.go
@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1
+
+package socktest
+
+// Sockets maps a socket descriptor to the status of the socket.
+type Sockets map[int]Status
+
+func (sw *Switch) sockso(s int) *Status {
+ sw.smu.RLock()
+ defer sw.smu.RUnlock()
+ so, ok := sw.sotab[s]
+ if !ok {
+ return nil
+ }
+ return &so
+}
+
+// addLocked returns a new Status without locking.
+// sw.smu must be held by the caller.
+func (sw *Switch) addLocked(s, family, sotype, proto int) *Status {
+ sw.once.Do(sw.init)
+ so := Status{Cookie: cookie(family, sotype, proto)}
+ sw.sotab[s] = so
+ return &so
+}
diff --git a/src/net/internal/socktest/switch_windows.go b/src/net/internal/socktest/switch_windows.go
new file mode 100644
index 0000000..4f1d597
--- /dev/null
+++ b/src/net/internal/socktest/switch_windows.go
@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest
+
+import "syscall"
+
+// Sockets maps a socket descriptor to the status of the socket.
+type Sockets map[syscall.Handle]Status
+
+func (sw *Switch) sockso(s syscall.Handle) *Status {
+ sw.smu.RLock()
+ defer sw.smu.RUnlock()
+ so, ok := sw.sotab[s]
+ if !ok {
+ return nil
+ }
+ return &so
+}
+
+// addLocked returns a new Status without locking.
+// sw.smu must be held by the caller.
+func (sw *Switch) addLocked(s syscall.Handle, family, sotype, proto int) *Status {
+ sw.once.Do(sw.init)
+ so := Status{Cookie: cookie(family, sotype, proto)}
+ sw.sotab[s] = so
+ return &so
+}
diff --git a/src/net/internal/socktest/sys_cloexec.go b/src/net/internal/socktest/sys_cloexec.go
new file mode 100644
index 0000000..d57f44d
--- /dev/null
+++ b/src/net/internal/socktest/sys_cloexec.go
@@ -0,0 +1,42 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package socktest
+
+import "syscall"
+
+// Accept4 wraps syscall.Accept4.
+func (sw *Switch) Accept4(s, flags int) (ns int, sa syscall.Sockaddr, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Accept4(s, flags)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, nil, err
+ }
+ ns, sa, so.Err = syscall.Accept4(s, flags)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(ns)
+ }
+ return -1, nil, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return -1, nil, so.Err
+ }
+ nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return ns, sa, nil
+}
diff --git a/src/net/internal/socktest/sys_unix.go b/src/net/internal/socktest/sys_unix.go
new file mode 100644
index 0000000..712462a
--- /dev/null
+++ b/src/net/internal/socktest/sys_unix.go
@@ -0,0 +1,193 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1
+
+package socktest
+
+import "syscall"
+
+// Socket wraps syscall.Socket.
+func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(s)
+ }
+ return -1, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return -1, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// Close wraps syscall.Close.
+func (sw *Switch) Close(s int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Close(s)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterClose]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Close(s)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).CloseFailed++
+ return so.Err
+ }
+ delete(sw.sotab, s)
+ sw.stats.getLocked(so.Cookie).Closed++
+ return nil
+}
+
+// Connect wraps syscall.Connect.
+func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Connect(s, sa)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Connect(s, sa)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// Listen wraps syscall.Listen.
+func (sw *Switch) Listen(s, backlog int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Listen(s, backlog)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterListen]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Listen(s, backlog)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ListenFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Listened++
+ return nil
+}
+
+// Accept wraps syscall.Accept.
+func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Accept(s)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, nil, err
+ }
+ ns, sa, so.Err = syscall.Accept(s)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(ns)
+ }
+ return -1, nil, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return -1, nil, so.Err
+ }
+ nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return ns, sa, nil
+}
+
+// GetsockoptInt wraps syscall.GetsockoptInt.
+func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.GetsockoptInt(s, level, opt)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterGetsockoptInt]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, err
+ }
+ soerr, so.Err = syscall.GetsockoptInt(s, level, opt)
+ so.SocketErr = syscall.Errno(soerr)
+ if err = af.apply(so); err != nil {
+ return -1, err
+ }
+
+ if so.Err != nil {
+ return -1, so.Err
+ }
+ if opt == syscall.SO_ERROR && (so.SocketErr == syscall.Errno(0) || so.SocketErr == syscall.EISCONN) {
+ sw.smu.Lock()
+ sw.stats.getLocked(so.Cookie).Connected++
+ sw.smu.Unlock()
+ }
+ return soerr, nil
+}
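
Each wrapper above follows the same shape: look up the socket's Status, run the before-filter, make the real system call, run the after-filter, then update the per-cookie statistics. A standalone sketch of just the before/after filter plumbing, with the system call replaced by a plain function so it runs anywhere (illustration only, not part of the patch):

package main

import (
	"errors"
	"fmt"
)

type status struct{ err error }

type afterFilter func(*status) error
type filter func(*status) (afterFilter, error)

func (f filter) apply(st *status) (afterFilter, error) {
	if f == nil {
		return nil, nil
	}
	return f(st)
}

func (f afterFilter) apply(st *status) error {
	if f == nil {
		return nil
	}
	return f(st)
}

// wrappedCall mirrors the wrappers above: before-filter, real call, after-filter.
func wrappedCall(f filter, real func() error) error {
	st := new(status)
	af, err := f.apply(st)
	if err != nil {
		return err // the before-filter vetoed the call
	}
	st.err = real()
	if err := af.apply(st); err != nil {
		return err
	}
	return st.err
}

func main() {
	inject := filter(func(st *status) (afterFilter, error) {
		return func(st *status) error {
			if st.err == nil {
				return errors.New("injected failure after a successful call")
			}
			return nil
		}, nil
	})
	fmt.Println(wrappedCall(inject, func() error { return nil }))
}
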
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
new file mode 100644
index 0000000..8c1c862
--- /dev/null
+++ b/src/net/internal/socktest/sys_windows.go
@@ -0,0 +1,221 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest
+
+import (
+ "internal/syscall/windows"
+ "syscall"
+)
+
+// Socket wraps syscall.Socket.
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// WSASocket wraps syscall.WSASocket.
+func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))}
+ sw.fmu.RLock()
+ f := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, int(family), int(sotype), int(proto))
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// Closesocket wraps syscall.Closesocket.
+func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Closesocket(s)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterClose]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Closesocket(s)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).CloseFailed++
+ return so.Err
+ }
+ delete(sw.sotab, s)
+ sw.stats.getLocked(so.Cookie).Closed++
+ return nil
+}
+
+// Connect wraps syscall.Connect.
+func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Connect(s, sa)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Connect(s, sa)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// ConnectEx wraps syscall.ConnectEx.
+func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.ConnectEx(s, sa, b, n, nwr, o)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.ConnectEx(s, sa, b, n, nwr, o)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// Listen wraps syscall.Listen.
+func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Listen(s, backlog)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterListen]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Listen(s, backlog)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ListenFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Listened++
+ return nil
+}
+
+// AcceptEx wraps syscall.AcceptEx.
+func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error {
+ so := sw.sockso(ls)
+ if so == nil {
+ return syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return so.Err
+ }
+ nso := sw.addLocked(as, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return nil
+}
diff --git a/src/net/ip.go b/src/net/ip.go
new file mode 100644
index 0000000..d51ba10
--- /dev/null
+++ b/src/net/ip.go
@@ -0,0 +1,542 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// IP address manipulations
+//
+// IPv4 addresses are 4 bytes; IPv6 addresses are 16 bytes.
+// An IPv4 address can be converted to an IPv6 address by
+// adding a canonical prefix (10 zeros, 2 0xFFs).
+// This library accepts either size of byte slice but always
+// returns 16-byte addresses.
+
+package net
+
+import (
+ "internal/bytealg"
+ "internal/itoa"
+ "net/netip"
+)
+
+// IP address lengths (bytes).
+const (
+ IPv4len = 4
+ IPv6len = 16
+)
+
+// An IP is a single IP address, a slice of bytes.
+// Functions in this package accept either 4-byte (IPv4)
+// or 16-byte (IPv6) slices as input.
+//
+// Note that in this documentation, referring to an
+// IP address as an IPv4 address or an IPv6 address
+// is a semantic property of the address, not just the
+// length of the byte slice: a 16-byte slice can still
+// be an IPv4 address.
+type IP []byte
+
+// An IPMask is a bitmask that can be used to manipulate
+// IP addresses for IP addressing and routing.
+//
+// See type IPNet and func ParseCIDR for details.
+type IPMask []byte
+
+// An IPNet represents an IP network.
+type IPNet struct {
+ IP IP // network number
+ Mask IPMask // network mask
+}
+
+// IPv4 returns the IP address (in 16-byte form) of the
+// IPv4 address a.b.c.d.
+func IPv4(a, b, c, d byte) IP {
+ p := make(IP, IPv6len)
+ copy(p, v4InV6Prefix)
+ p[12] = a
+ p[13] = b
+ p[14] = c
+ p[15] = d
+ return p
+}
+
+var v4InV6Prefix = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}
+
+// IPv4Mask returns the IP mask (in 4-byte form) of the
+// IPv4 mask a.b.c.d.
+func IPv4Mask(a, b, c, d byte) IPMask {
+ p := make(IPMask, IPv4len)
+ p[0] = a
+ p[1] = b
+ p[2] = c
+ p[3] = d
+ return p
+}
+
+// CIDRMask returns an IPMask consisting of 'ones' 1 bits
+// followed by 0s up to a total length of 'bits' bits.
+// For a mask of this form, CIDRMask is the inverse of IPMask.Size.
+func CIDRMask(ones, bits int) IPMask {
+ if bits != 8*IPv4len && bits != 8*IPv6len {
+ return nil
+ }
+ if ones < 0 || ones > bits {
+ return nil
+ }
+ l := bits / 8
+ m := make(IPMask, l)
+ n := uint(ones)
+ for i := 0; i < l; i++ {
+ if n >= 8 {
+ m[i] = 0xff
+ n -= 8
+ continue
+ }
+ m[i] = ^byte(0xff >> n)
+ n = 0
+ }
+ return m
+}
+
+// Well-known IPv4 addresses
+var (
+ IPv4bcast = IPv4(255, 255, 255, 255) // limited broadcast
+ IPv4allsys = IPv4(224, 0, 0, 1) // all systems
+ IPv4allrouter = IPv4(224, 0, 0, 2) // all routers
+ IPv4zero = IPv4(0, 0, 0, 0) // all zeros
+)
+
+// Well-known IPv6 addresses
+var (
+ IPv6zero = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ IPv6unspecified = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
+ IPv6loopback = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ IPv6interfacelocalallnodes = IP{0xff, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
+ IPv6linklocalallnodes = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
+ IPv6linklocalallrouters = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x02}
+)
+
+// IsUnspecified reports whether ip is an unspecified address, either
+// the IPv4 address "0.0.0.0" or the IPv6 address "::".
+func (ip IP) IsUnspecified() bool {
+ return ip.Equal(IPv4zero) || ip.Equal(IPv6unspecified)
+}
+
+// IsLoopback reports whether ip is a loopback address.
+func (ip IP) IsLoopback() bool {
+ if ip4 := ip.To4(); ip4 != nil {
+ return ip4[0] == 127
+ }
+ return ip.Equal(IPv6loopback)
+}
+
+// IsPrivate reports whether ip is a private address, according to
+// RFC 1918 (IPv4 addresses) and RFC 4193 (IPv6 addresses).
+func (ip IP) IsPrivate() bool {
+ if ip4 := ip.To4(); ip4 != nil {
+ // Following RFC 1918, Section 3. Private Address Space which says:
+ // The Internet Assigned Numbers Authority (IANA) has reserved the
+ // following three blocks of the IP address space for private internets:
+ // 10.0.0.0 - 10.255.255.255 (10/8 prefix)
+ // 172.16.0.0 - 172.31.255.255 (172.16/12 prefix)
+ // 192.168.0.0 - 192.168.255.255 (192.168/16 prefix)
+ return ip4[0] == 10 ||
+ (ip4[0] == 172 && ip4[1]&0xf0 == 16) ||
+ (ip4[0] == 192 && ip4[1] == 168)
+ }
+ // Following RFC 4193, Section 8. IANA Considerations which says:
+ // The IANA has assigned the FC00::/7 prefix to "Unique Local Unicast".
+ return len(ip) == IPv6len && ip[0]&0xfe == 0xfc
+}
+
+// IsMulticast reports whether ip is a multicast address.
+func (ip IP) IsMulticast() bool {
+ if ip4 := ip.To4(); ip4 != nil {
+ return ip4[0]&0xf0 == 0xe0
+ }
+ return len(ip) == IPv6len && ip[0] == 0xff
+}
+
+// IsInterfaceLocalMulticast reports whether ip is
+// an interface-local multicast address.
+func (ip IP) IsInterfaceLocalMulticast() bool {
+ return len(ip) == IPv6len && ip[0] == 0xff && ip[1]&0x0f == 0x01
+}
+
+// IsLinkLocalMulticast reports whether ip is a link-local
+// multicast address.
+func (ip IP) IsLinkLocalMulticast() bool {
+ if ip4 := ip.To4(); ip4 != nil {
+ return ip4[0] == 224 && ip4[1] == 0 && ip4[2] == 0
+ }
+ return len(ip) == IPv6len && ip[0] == 0xff && ip[1]&0x0f == 0x02
+}
+
+// IsLinkLocalUnicast reports whether ip is a link-local
+// unicast address.
+func (ip IP) IsLinkLocalUnicast() bool {
+ if ip4 := ip.To4(); ip4 != nil {
+ return ip4[0] == 169 && ip4[1] == 254
+ }
+ return len(ip) == IPv6len && ip[0] == 0xfe && ip[1]&0xc0 == 0x80
+}
+
+// IsGlobalUnicast reports whether ip is a global unicast
+// address.
+//
+// The identification of global unicast addresses uses address type
+// identification as defined in RFC 1122, RFC 4632 and RFC 4291 with
+// the exception of IPv4 directed broadcast addresses.
+// It returns true even if ip is in IPv4 private address space or
+// local IPv6 unicast address space.
+func (ip IP) IsGlobalUnicast() bool {
+ return (len(ip) == IPv4len || len(ip) == IPv6len) &&
+ !ip.Equal(IPv4bcast) &&
+ !ip.IsUnspecified() &&
+ !ip.IsLoopback() &&
+ !ip.IsMulticast() &&
+ !ip.IsLinkLocalUnicast()
+}
+
+// Is p all zeros?
+func isZeros(p IP) bool {
+ for i := 0; i < len(p); i++ {
+ if p[i] != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// To4 converts the IPv4 address ip to a 4-byte representation.
+// If ip is not an IPv4 address, To4 returns nil.
+func (ip IP) To4() IP {
+ if len(ip) == IPv4len {
+ return ip
+ }
+ if len(ip) == IPv6len &&
+ isZeros(ip[0:10]) &&
+ ip[10] == 0xff &&
+ ip[11] == 0xff {
+ return ip[12:16]
+ }
+ return nil
+}
+
+// To16 converts the IP address ip to a 16-byte representation.
+// If ip is not an IP address (it is the wrong length), To16 returns nil.
+func (ip IP) To16() IP {
+ if len(ip) == IPv4len {
+ return IPv4(ip[0], ip[1], ip[2], ip[3])
+ }
+ if len(ip) == IPv6len {
+ return ip
+ }
+ return nil
+}
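+
+// sketchTo4To16 is a hypothetical helper (purely illustrative) sketching the
+// two conversions: an IPv4-mapped IPv6 address collapses to its 4-byte form,
+// and a 4-byte IPv4 address expands to its 16-byte form.
+func sketchTo4To16() (IP, IP) {
+ mapped := IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 0, 2, 1}
+ return mapped.To4(), IP{192, 0, 2, 1}.To16() // 4-byte 192.0.2.1, 16-byte 192.0.2.1
+}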
+
+// Default route masks for IPv4.
+var (
+ classAMask = IPv4Mask(0xff, 0, 0, 0)
+ classBMask = IPv4Mask(0xff, 0xff, 0, 0)
+ classCMask = IPv4Mask(0xff, 0xff, 0xff, 0)
+)
+
+// DefaultMask returns the default IP mask for the IP address ip.
+// Only IPv4 addresses have default masks; DefaultMask returns
+// nil if ip is not a valid IPv4 address.
+func (ip IP) DefaultMask() IPMask {
+ if ip = ip.To4(); ip == nil {
+ return nil
+ }
+ switch {
+ case ip[0] < 0x80:
+ return classAMask
+ case ip[0] < 0xC0:
+ return classBMask
+ default:
+ return classCMask
+ }
+}
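+
+// sketchDefaultMask is a hypothetical helper (purely illustrative): the
+// historical classful rules give a 10.x.x.x address the class A /8 mask.
+func sketchDefaultMask() IPMask {
+ return IPv4(10, 0, 0, 1).DefaultMask() // IPv4Mask(255, 0, 0, 0)
+}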
+
+func allFF(b []byte) bool {
+ for _, c := range b {
+ if c != 0xff {
+ return false
+ }
+ }
+ return true
+}
+
+// Mask returns the result of masking the IP address ip with mask.
+func (ip IP) Mask(mask IPMask) IP {
+ if len(mask) == IPv6len && len(ip) == IPv4len && allFF(mask[:12]) {
+ mask = mask[12:]
+ }
+ if len(mask) == IPv4len && len(ip) == IPv6len && bytealg.Equal(ip[:12], v4InV6Prefix) {
+ ip = ip[12:]
+ }
+ n := len(ip)
+ if n != len(mask) {
+ return nil
+ }
+ out := make(IP, n)
+ for i := 0; i < n; i++ {
+ out[i] = ip[i] & mask[i]
+ }
+ return out
+}
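+
+// sketchMask is a hypothetical helper (purely illustrative): applying a /24
+// mask to 192.0.2.130 yields the network address 192.0.2.0.
+func sketchMask() IP {
+ return IPv4(192, 0, 2, 130).Mask(CIDRMask(24, 32)) // 192.0.2.0
+}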
+
+// String returns the string form of the IP address ip.
+// It returns one of 4 forms:
+// - "<nil>", if ip has length 0
+// - dotted decimal ("192.0.2.1"), if ip is an IPv4 or IPv4-mapped IPv6 address
+// - IPv6 conforming to RFC 5952 ("2001:db8::1"), if ip is a valid IPv6 address
+// - the hexadecimal form of ip, without punctuation, if no other cases apply
+func (ip IP) String() string {
+ if len(ip) == 0 {
+ return "<nil>"
+ }
+
+ if len(ip) != IPv4len && len(ip) != IPv6len {
+ return "?" + hexString(ip)
+ }
+ // If IPv4, use dotted notation.
+ if p4 := ip.To4(); len(p4) == IPv4len {
+ return netip.AddrFrom4([4]byte(p4)).String()
+ }
+ return netip.AddrFrom16([16]byte(ip)).String()
+}
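+
+// sketchIPString is a hypothetical helper (purely illustrative) sketching the
+// forms listed above: an unset IP, an IPv4 address and an IPv6 address.
+func sketchIPString() [3]string {
+ var none IP
+ return [3]string{
+ none.String(), // "<nil>"
+ IPv4(192, 0, 2, 1).String(), // "192.0.2.1"
+ ParseIP("2001:db8::1").String(), // "2001:db8::1"
+ }
+}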
+
+func hexString(b []byte) string {
+ s := make([]byte, len(b)*2)
+ for i, tn := range b {
+ s[i*2], s[i*2+1] = hexDigit[tn>>4], hexDigit[tn&0xf]
+ }
+ return string(s)
+}
+
+// ipEmptyString is like ip.String except that it returns
+// an empty string when ip is unset.
+func ipEmptyString(ip IP) string {
+ if len(ip) == 0 {
+ return ""
+ }
+ return ip.String()
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+// The encoding is the same as returned by String, with one exception:
+// When len(ip) is zero, it returns an empty slice.
+func (ip IP) MarshalText() ([]byte, error) {
+ if len(ip) == 0 {
+ return []byte(""), nil
+ }
+ if len(ip) != IPv4len && len(ip) != IPv6len {
+ return nil, &AddrError{Err: "invalid IP address", Addr: hexString(ip)}
+ }
+ return []byte(ip.String()), nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// The IP address is expected in a form accepted by ParseIP.
+func (ip *IP) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *ip = nil
+ return nil
+ }
+ s := string(text)
+ x := ParseIP(s)
+ if x == nil {
+ return &ParseError{Type: "IP address", Text: s}
+ }
+ *ip = x
+ return nil
+}
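+
+// sketchIPTextRoundTrip is a hypothetical helper (purely illustrative)
+// sketching a MarshalText/UnmarshalText round trip; the errors are ignored
+// here only to keep the sketch short.
+func sketchIPTextRoundTrip() IP {
+ text, _ := IPv4(192, 0, 2, 1).MarshalText() // []byte("192.0.2.1")
+ var ip IP
+ _ = ip.UnmarshalText(text)
+ return ip // equal to IPv4(192, 0, 2, 1)
+}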
+
+// Equal reports whether ip and x are the same IP address.
+// An IPv4 address and that same address in IPv6 form are
+// considered to be equal.
+func (ip IP) Equal(x IP) bool {
+ if len(ip) == len(x) {
+ return bytealg.Equal(ip, x)
+ }
+ if len(ip) == IPv4len && len(x) == IPv6len {
+ return bytealg.Equal(x[0:12], v4InV6Prefix) && bytealg.Equal(ip, x[12:])
+ }
+ if len(ip) == IPv6len && len(x) == IPv4len {
+ return bytealg.Equal(ip[0:12], v4InV6Prefix) && bytealg.Equal(ip[12:], x)
+ }
+ return false
+}
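+
+// sketchIPEqual is a hypothetical helper (purely illustrative): Equal treats
+// a 4-byte IPv4 address and its 16-byte IPv4-in-IPv6 form as the same address.
+func sketchIPEqual() bool {
+ return IP{192, 0, 2, 1}.Equal(IPv4(192, 0, 2, 1)) // true
+}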
+
+func (ip IP) matchAddrFamily(x IP) bool {
+ return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil
+}
+
+// simpleMaskLength returns the number of leading 1 bits in mask
+// if mask is a sequence of 1 bits followed by 0 bits, and -1 otherwise.
+func simpleMaskLength(mask IPMask) int {
+ var n int
+ for i, v := range mask {
+ if v == 0xff {
+ n += 8
+ continue
+ }
+ // found non-ff byte
+ // count 1 bits
+ for v&0x80 != 0 {
+ n++
+ v <<= 1
+ }
+ // rest must be 0 bits
+ if v != 0 {
+ return -1
+ }
+ for i++; i < len(mask); i++ {
+ if mask[i] != 0 {
+ return -1
+ }
+ }
+ break
+ }
+ return n
+}
+
+// Size returns the number of leading ones and total bits in the mask.
+// If the mask is not in the canonical form--ones followed by zeros--then
+// Size returns 0, 0.
+func (m IPMask) Size() (ones, bits int) {
+ ones, bits = simpleMaskLength(m), len(m)*8
+ if ones == -1 {
+ return 0, 0
+ }
+ return
+}
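+
+// sketchMaskSize is a hypothetical helper (purely illustrative): Size reports
+// the prefix length of a canonical mask and (0, 0) for a non-contiguous one.
+func sketchMaskSize() (int, int, int, int) {
+ ones, bits := CIDRMask(24, 32).Size() // 24, 32
+ badOnes, badBits := IPv4Mask(255, 0, 255, 0).Size() // 0, 0: not ones-then-zeros
+ return ones, bits, badOnes, badBits
+}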
+
+// String returns the hexadecimal form of m, with no punctuation.
+func (m IPMask) String() string {
+ if len(m) == 0 {
+ return "<nil>"
+ }
+ return hexString(m)
+}
+
+func networkNumberAndMask(n *IPNet) (ip IP, m IPMask) {
+ if ip = n.IP.To4(); ip == nil {
+ ip = n.IP
+ if len(ip) != IPv6len {
+ return nil, nil
+ }
+ }
+ m = n.Mask
+ switch len(m) {
+ case IPv4len:
+ if len(ip) != IPv4len {
+ return nil, nil
+ }
+ case IPv6len:
+ if len(ip) == IPv4len {
+ m = m[12:]
+ }
+ default:
+ return nil, nil
+ }
+ return
+}
+
+// Contains reports whether the network includes ip.
+func (n *IPNet) Contains(ip IP) bool {
+ nn, m := networkNumberAndMask(n)
+ if x := ip.To4(); x != nil {
+ ip = x
+ }
+ l := len(ip)
+ if l != len(nn) {
+ return false
+ }
+ for i := 0; i < l; i++ {
+ if nn[i]&m[i] != ip[i]&m[i] {
+ return false
+ }
+ }
+ return true
+}
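+
+// sketchIPNetContains is a hypothetical helper (purely illustrative):
+// membership tests against a parsed CIDR block. The ParseCIDR error is
+// ignored only because the literal is known to be valid.
+func sketchIPNetContains() (bool, bool) {
+ _, n, _ := ParseCIDR("192.0.2.0/24")
+ return n.Contains(IPv4(192, 0, 2, 42)), n.Contains(IPv4(198, 51, 100, 1)) // true, false
+}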
+
+// Network returns the address's network name, "ip+net".
+func (n *IPNet) Network() string { return "ip+net" }
+
+// String returns the CIDR notation of n like "192.0.2.0/24"
+// or "2001:db8::/48" as defined in RFC 4632 and RFC 4291.
+// If the mask is not in the canonical form, it returns the
+// string consisting of the IP address, followed by a slash
+// character and the mask in hexadecimal form with no
+// punctuation, like "198.51.100.0/c000ff00".
+func (n *IPNet) String() string {
+ if n == nil {
+ return "<nil>"
+ }
+ nn, m := networkNumberAndMask(n)
+ if nn == nil || m == nil {
+ return "<nil>"
+ }
+ l := simpleMaskLength(m)
+ if l == -1 {
+ return nn.String() + "/" + m.String()
+ }
+ return nn.String() + "/" + itoa.Uitoa(uint(l))
+}
+
+// ParseIP parses s as an IP address, returning the result.
+// The string s can be in IPv4 dotted decimal ("192.0.2.1"), IPv6
+// ("2001:db8::68"), or IPv4-mapped IPv6 ("::ffff:192.0.2.1") form.
+// If s is not a valid textual representation of an IP address,
+// ParseIP returns nil.
+func ParseIP(s string) IP {
+ if addr, valid := parseIP(s); valid {
+ return IP(addr[:])
+ }
+ return nil
+}
+
+func parseIP(s string) ([16]byte, bool) {
+ ip, err := netip.ParseAddr(s)
+ if err != nil || ip.Zone() != "" {
+ return [16]byte{}, false
+ }
+ return ip.As16(), true
+}
+
+// ParseCIDR parses s as a CIDR notation IP address and prefix length,
+// like "192.0.2.0/24" or "2001:db8::/32", as defined in
+// RFC 4632 and RFC 4291.
+//
+// It returns the IP address and the network implied by the IP and
+// prefix length.
+// For example, ParseCIDR("192.0.2.1/24") returns the IP address
+// 192.0.2.1 and the network 192.0.2.0/24.
+func ParseCIDR(s string) (IP, *IPNet, error) {
+ i := bytealg.IndexByteString(s, '/')
+ if i < 0 {
+ return nil, nil, &ParseError{Type: "CIDR address", Text: s}
+ }
+ addr, mask := s[:i], s[i+1:]
+
+ ipAddr, err := netip.ParseAddr(addr)
+ if err != nil || ipAddr.Zone() != "" {
+ return nil, nil, &ParseError{Type: "CIDR address", Text: s}
+ }
+
+ n, i, ok := dtoi(mask)
+ if !ok || i != len(mask) || n < 0 || n > ipAddr.BitLen() {
+ return nil, nil, &ParseError{Type: "CIDR address", Text: s}
+ }
+ m := CIDRMask(n, ipAddr.BitLen())
+ addr16 := ipAddr.As16()
+ return IP(addr16[:]), &IPNet{IP: IP(addr16[:]).Mask(m), Mask: m}, nil
+}
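+
+// sketchParseCIDR is a hypothetical helper (purely illustrative): ParseCIDR
+// splits "192.0.2.1/24" into the host address and the enclosing network.
+func sketchParseCIDR() (string, string) {
+ ip, n, _ := ParseCIDR("192.0.2.1/24") // error ignored: the literal is valid
+ return ip.String(), n.String() // "192.0.2.1", "192.0.2.0/24"
+}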
+
+func copyIP(x IP) IP {
+ y := make(IP, len(x))
+ copy(y, x)
+ return y
+}
diff --git a/src/net/ip_test.go b/src/net/ip_test.go
new file mode 100644
index 0000000..1373059
--- /dev/null
+++ b/src/net/ip_test.go
@@ -0,0 +1,784 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "bytes"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "testing"
+)
+
+var parseIPTests = []struct {
+ in string
+ out IP
+}{
+ {"127.0.1.2", IPv4(127, 0, 1, 2)},
+ {"127.0.0.1", IPv4(127, 0, 0, 1)},
+ {"::ffff:127.1.2.3", IPv4(127, 1, 2, 3)},
+ {"::ffff:7f01:0203", IPv4(127, 1, 2, 3)},
+ {"0:0:0:0:0000:ffff:127.1.2.3", IPv4(127, 1, 2, 3)},
+ {"0:0:0:0:000000:ffff:127.1.2.3", IPv4(127, 1, 2, 3)},
+ {"0:0:0:0::ffff:127.1.2.3", IPv4(127, 1, 2, 3)},
+
+ {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}},
+ {"2001:4860:0000:2001:0000:0000:0000:0068", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}},
+
+ {"-0.0.0.0", nil},
+ {"0.-1.0.0", nil},
+ {"0.0.-2.0", nil},
+ {"0.0.0.-3", nil},
+ {"127.0.0.256", nil},
+ {"abc", nil},
+ {"123:", nil},
+ {"fe80::1%lo0", nil},
+ {"fe80::1%911", nil},
+ {"", nil},
+ {"a1:a2:a3:a4::b1:b2:b3:b4", nil}, // Issue 6628
+ {"127.001.002.003", nil},
+ {"::ffff:127.001.002.003", nil},
+ {"123.000.000.000", nil},
+ {"1.2..4", nil},
+ {"0123.0.0.1", nil},
+}
+
+func TestParseIP(t *testing.T) {
+ for _, tt := range parseIPTests {
+ if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) {
+ t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out)
+ }
+ if tt.in == "" {
+ // Tested in TestMarshalEmptyIP below.
+ continue
+ }
+ var out IP
+ if err := out.UnmarshalText([]byte(tt.in)); !reflect.DeepEqual(out, tt.out) || (tt.out == nil) != (err != nil) {
+ t.Errorf("IP.UnmarshalText(%q) = %v, %v, want %v", tt.in, out, err, tt.out)
+ }
+ }
+}
+
+func TestLookupWithIP(t *testing.T) {
+ _, err := LookupIP("")
+ if err == nil {
+ t.Errorf(`LookupIP("") succeeded, should fail`)
+ }
+ _, err = LookupHost("")
+ if err == nil {
+ t.Errorf(`LookupHost("") succeeded, should fail`)
+ }
+
+ // Test that LookupHost and LookupIP, which normally
+ // expect host names, work with IP addresses.
+ for _, tt := range parseIPTests {
+ if tt.out != nil {
+ addrs, err := LookupHost(tt.in)
+ if len(addrs) != 1 || addrs[0] != tt.in || err != nil {
+ t.Errorf("LookupHost(%q) = %v, %v, want %v, nil", tt.in, addrs, err, []string{tt.in})
+ }
+ } else if !testing.Short() {
+ // We can't control what the host resolver does; if it can resolve, say,
+ // 127.0.0.256 or fe80::1%911 or a host named 'abc', who are we to judge?
+ // Warn about these discrepancies but don't fail the test.
+ addrs, err := LookupHost(tt.in)
+ if err == nil {
+ t.Logf("warning: LookupHost(%q) = %v, want error", tt.in, addrs)
+ }
+ }
+
+ if tt.out != nil {
+ ips, err := LookupIP(tt.in)
+ if len(ips) != 1 || !reflect.DeepEqual(ips[0], tt.out) || err != nil {
+ t.Errorf("LookupIP(%q) = %v, %v, want %v, nil", tt.in, ips, err, []IP{tt.out})
+ }
+ } else if !testing.Short() {
+ ips, err := LookupIP(tt.in)
+ // We can't control what the host resolver does. See above.
+ if err == nil {
+ t.Logf("warning: LookupIP(%q) = %v, want error", tt.in, ips)
+ }
+ }
+ }
+}
+
+func BenchmarkParseIP(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ for _, tt := range parseIPTests {
+ ParseIP(tt.in)
+ }
+ }
+}
+
+func BenchmarkParseIPValidIPv4(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ ParseIP("192.0.2.1")
+ }
+}
+
+func BenchmarkParseIPValidIPv6(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ ParseIP("2001:DB8::1")
+ }
+}
+
+// Issue 6339
+func TestMarshalEmptyIP(t *testing.T) {
+ for _, in := range [][]byte{nil, []byte("")} {
+ var out = IP{1, 2, 3, 4}
+ if err := out.UnmarshalText(in); err != nil || out != nil {
+ t.Errorf("UnmarshalText(%v) = %v, %v; want nil, nil", in, out, err)
+ }
+ }
+ var ip IP
+ got, err := ip.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(got, []byte("")) {
+ t.Errorf(`got %#v, want []byte("")`, got)
+ }
+}
+
+var ipStringTests = []*struct {
+ in IP // see RFC 791 and RFC 4291
+ str string // see RFC 791, RFC 4291 and RFC 5952
+ byt []byte
+ error
+}{
+ // IPv4 address
+ {
+ IP{192, 0, 2, 1},
+ "192.0.2.1",
+ []byte("192.0.2.1"),
+ nil,
+ },
+ {
+ IP{0, 0, 0, 0},
+ "0.0.0.0",
+ []byte("0.0.0.0"),
+ nil,
+ },
+
+ // IPv4-mapped IPv6 address
+ {
+ IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 0, 2, 1},
+ "192.0.2.1",
+ []byte("192.0.2.1"),
+ nil,
+ },
+ {
+ IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 0, 0, 0, 0},
+ "0.0.0.0",
+ []byte("0.0.0.0"),
+ nil,
+ },
+
+ // IPv6 address
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1},
+ "2001:db8::123:12:1",
+ []byte("2001:db8::123:12:1"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1},
+ "2001:db8::1",
+ []byte("2001:db8::1"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1},
+ "2001:db8:0:1:0:1:0:1",
+ []byte("2001:db8:0:1:0:1:0:1"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0},
+ "2001:db8:1:0:1:0:1:0",
+ []byte("2001:db8:1:0:1:0:1:0"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0, 0, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1},
+ "2001::1:0:0:1",
+ []byte("2001::1:0:0:1"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0},
+ "2001:db8:0:0:1::",
+ []byte("2001:db8:0:0:1::"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0x1, 0, 0, 0, 0, 0, 0x1},
+ "2001:db8::1:0:0:1",
+ []byte("2001:db8::1:0:0:1"),
+ nil,
+ },
+ {
+ IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0xa, 0, 0xb, 0, 0xc, 0, 0xd},
+ "2001:db8::a:b:c:d",
+ []byte("2001:db8::a:b:c:d"),
+ nil,
+ },
+ {
+ IPv6unspecified,
+ "::",
+ []byte("::"),
+ nil,
+ },
+
+ // IP wildcard equivalent address in Dial/Listen API
+ {
+ nil,
+ "<nil>",
+ nil,
+ nil,
+ },
+
+ // Opaque byte sequence
+ {
+ IP{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef},
+ "?0123456789abcdef",
+ nil,
+ &AddrError{Err: "invalid IP address", Addr: "0123456789abcdef"},
+ },
+}
+
+func TestIPString(t *testing.T) {
+ for _, tt := range ipStringTests {
+ if out := tt.in.String(); out != tt.str {
+ t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.str)
+ }
+ if out, err := tt.in.MarshalText(); !bytes.Equal(out, tt.byt) || !reflect.DeepEqual(err, tt.error) {
+ t.Errorf("IP.MarshalText(%v) = %v, %v, want %v, %v", tt.in, out, err, tt.byt, tt.error)
+ }
+ }
+}
+
+var sink string
+
+func BenchmarkIPString(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ b.Run("IPv4", func(b *testing.B) {
+ benchmarkIPString(b, IPv4len)
+ })
+
+ b.Run("IPv6", func(b *testing.B) {
+ benchmarkIPString(b, IPv6len)
+ })
+}
+
+func benchmarkIPString(b *testing.B, size int) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ for _, tt := range ipStringTests {
+ if tt.in != nil && len(tt.in) == size {
+ sink = tt.in.String()
+ }
+ }
+ }
+}
+
+var ipMaskTests = []struct {
+ in IP
+ mask IPMask
+ out IP
+}{
+ {IPv4(192, 168, 1, 127), IPv4Mask(255, 255, 255, 128), IPv4(192, 168, 1, 0)},
+ {IPv4(192, 168, 1, 127), IPMask(ParseIP("255.255.255.192")), IPv4(192, 168, 1, 64)},
+ {IPv4(192, 168, 1, 127), IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffe0")), IPv4(192, 168, 1, 96)},
+ {IPv4(192, 168, 1, 127), IPv4Mask(255, 0, 255, 0), IPv4(192, 0, 1, 0)},
+ {ParseIP("2001:db8::1"), IPMask(ParseIP("ffff:ff80::")), ParseIP("2001:d80::")},
+ {ParseIP("2001:db8::1"), IPMask(ParseIP("f0f0:0f0f::")), ParseIP("2000:d08::")},
+}
+
+func TestIPMask(t *testing.T) {
+ for _, tt := range ipMaskTests {
+ if out := tt.in.Mask(tt.mask); out == nil || !tt.out.Equal(out) {
+ t.Errorf("IP(%v).Mask(%v) = %v, want %v", tt.in, tt.mask, out, tt.out)
+ }
+ }
+}
+
+var ipMaskStringTests = []struct {
+ in IPMask
+ out string
+}{
+ {IPv4Mask(255, 255, 255, 240), "fffffff0"},
+ {IPv4Mask(255, 0, 128, 0), "ff008000"},
+ {IPMask(ParseIP("ffff:ff80::")), "ffffff80000000000000000000000000"},
+ {IPMask(ParseIP("ef00:ff80::cafe:0")), "ef00ff800000000000000000cafe0000"},
+ {nil, "<nil>"},
+}
+
+func TestIPMaskString(t *testing.T) {
+ for _, tt := range ipMaskStringTests {
+ if out := tt.in.String(); out != tt.out {
+ t.Errorf("IPMask.String(%v) = %q, want %q", tt.in, out, tt.out)
+ }
+ }
+}
+
+func BenchmarkIPMaskString(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ for i := 0; i < b.N; i++ {
+ for _, tt := range ipMaskStringTests {
+ sink = tt.in.String()
+ }
+ }
+}
+
+var parseCIDRTests = []struct {
+ in string
+ ip IP
+ net *IPNet
+ err error
+}{
+ {"135.104.0.0/32", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 255)}, nil},
+ {"0.0.0.0/24", IPv4(0, 0, 0, 0), &IPNet{IP: IPv4(0, 0, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"135.104.0.0/24", IPv4(135, 104, 0, 0), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"135.104.0.1/32", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 1), Mask: IPv4Mask(255, 255, 255, 255)}, nil},
+ {"135.104.0.1/24", IPv4(135, 104, 0, 1), &IPNet{IP: IPv4(135, 104, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)}, nil},
+ {"::1/128", ParseIP("::1"), &IPNet{IP: ParseIP("::1"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))}, nil},
+ {"abcd:2345::/127", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe"))}, nil},
+ {"abcd:2345::/65", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:8000::"))}, nil},
+ {"abcd:2345::/64", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff::"))}, nil},
+ {"abcd:2345::/63", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:fffe::"))}, nil},
+ {"abcd:2345::/33", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff:8000::"))}, nil},
+ {"abcd:2345::/32", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2345::"), Mask: IPMask(ParseIP("ffff:ffff::"))}, nil},
+ {"abcd:2344::/31", ParseIP("abcd:2344::"), &IPNet{IP: ParseIP("abcd:2344::"), Mask: IPMask(ParseIP("ffff:fffe::"))}, nil},
+ {"abcd:2300::/24", ParseIP("abcd:2300::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil},
+ {"abcd:2345::/24", ParseIP("abcd:2345::"), &IPNet{IP: ParseIP("abcd:2300::"), Mask: IPMask(ParseIP("ffff:ff00::"))}, nil},
+ {"2001:DB8::/48", ParseIP("2001:DB8::"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
+ {"2001:DB8::1/48", ParseIP("2001:DB8::1"), &IPNet{IP: ParseIP("2001:DB8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff::"))}, nil},
+ {"192.168.1.1/255.255.255.0", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/255.255.255.0"}},
+ {"192.168.1.1/35", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/35"}},
+ {"2001:db8::1/-1", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-1"}},
+ {"2001:db8::1/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-0"}},
+ {"-0.0.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "-0.0.0.0/32"}},
+ {"0.-1.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.-1.0.0/32"}},
+ {"0.0.-2.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.-2.0/32"}},
+ {"0.0.0.-3/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.-3/32"}},
+ {"0.0.0.0/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.0/-0"}},
+ {"127.000.000.001/32", nil, nil, &ParseError{Type: "CIDR address", Text: "127.000.000.001/32"}},
+ {"", nil, nil, &ParseError{Type: "CIDR address", Text: ""}},
+}
+
+func TestParseCIDR(t *testing.T) {
+ for _, tt := range parseCIDRTests {
+ ip, net, err := ParseCIDR(tt.in)
+ if !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ParseCIDR(%q) = %v, %v; want %v, %v", tt.in, ip, net, tt.ip, tt.net)
+ }
+ if err == nil && (!tt.ip.Equal(ip) || !tt.net.IP.Equal(net.IP) || !reflect.DeepEqual(net.Mask, tt.net.Mask)) {
+ t.Errorf("ParseCIDR(%q) = %v, {%v, %v}; want %v, {%v, %v}", tt.in, ip, net.IP, net.Mask, tt.ip, tt.net.IP, tt.net.Mask)
+ }
+ }
+}
+
+var ipNetContainsTests = []struct {
+ ip IP
+ net *IPNet
+ ok bool
+}{
+ {IPv4(172, 16, 1, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(12, 32)}, true},
+ {IPv4(172, 24, 0, 1), &IPNet{IP: IPv4(172, 16, 0, 0), Mask: CIDRMask(13, 32)}, false},
+ {IPv4(192, 168, 0, 3), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 0, 255, 252)}, true},
+ {IPv4(192, 168, 0, 4), &IPNet{IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(0, 255, 0, 252)}, false},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: CIDRMask(47, 128)}, true},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:2::"), Mask: CIDRMask(47, 128)}, false},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("ffff:0:ffff::"))}, true},
+ {ParseIP("2001:db8:1:2::1"), &IPNet{IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("0:0:0:ffff::"))}, false},
+}
+
+func TestIPNetContains(t *testing.T) {
+ for _, tt := range ipNetContainsTests {
+ if ok := tt.net.Contains(tt.ip); ok != tt.ok {
+ t.Errorf("IPNet(%v).Contains(%v) = %v, want %v", tt.net, tt.ip, ok, tt.ok)
+ }
+ }
+}
+
+var ipNetStringTests = []struct {
+ in *IPNet
+ out string
+}{
+ {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: CIDRMask(26, 32)}, "192.168.1.0/26"},
+ {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"},
+ {&IPNet{IP: ParseIP("2001:db8::"), Mask: CIDRMask(55, 128)}, "2001:db8::/55"},
+ {&IPNet{IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"},
+ {nil, "<nil>"},
+}
+
+func TestIPNetString(t *testing.T) {
+ for _, tt := range ipNetStringTests {
+ if out := tt.in.String(); out != tt.out {
+ t.Errorf("IPNet.String(%v) = %q, want %q", tt.in, out, tt.out)
+ }
+ }
+}
+
+var cidrMaskTests = []struct {
+ ones int
+ bits int
+ out IPMask
+}{
+ {0, 32, IPv4Mask(0, 0, 0, 0)},
+ {12, 32, IPv4Mask(255, 240, 0, 0)},
+ {24, 32, IPv4Mask(255, 255, 255, 0)},
+ {32, 32, IPv4Mask(255, 255, 255, 255)},
+ {0, 128, IPMask{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+ {4, 128, IPMask{0xf0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+ {48, 128, IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+ {128, 128, IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {33, 32, nil},
+ {32, 33, nil},
+ {-1, 128, nil},
+ {128, -1, nil},
+}
+
+func TestCIDRMask(t *testing.T) {
+ for _, tt := range cidrMaskTests {
+ if out := CIDRMask(tt.ones, tt.bits); !reflect.DeepEqual(out, tt.out) {
+ t.Errorf("CIDRMask(%v, %v) = %v, want %v", tt.ones, tt.bits, out, tt.out)
+ }
+ }
+}
+
+var (
+ v4addr = IP{192, 168, 0, 1}
+ v4mappedv6addr = IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 0, 1}
+ v6addr = IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}
+ v4mask = IPMask{255, 255, 255, 0}
+ v4mappedv6mask = IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 255, 255, 255, 0}
+ v6mask = IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}
+ badaddr = IP{192, 168, 0}
+ badmask = IPMask{255, 255, 0}
+ v4maskzero = IPMask{0, 0, 0, 0}
+)
+
+var networkNumberAndMaskTests = []struct {
+ in IPNet
+ out IPNet
+}{
+ {IPNet{IP: v4addr, Mask: v4mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4mappedv6addr, Mask: v4mappedv6mask}, IPNet{IP: v4addr, Mask: v4mask}},
+ {IPNet{IP: v4mappedv6addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}},
+ {IPNet{IP: v4addr, Mask: v6mask}, IPNet{IP: v4addr, Mask: v4maskzero}},
+ {IPNet{IP: v6addr, Mask: v6mask}, IPNet{IP: v6addr, Mask: v6mask}},
+ {IPNet{IP: v6addr, Mask: v4mappedv6mask}, IPNet{IP: v6addr, Mask: v4mappedv6mask}},
+ {in: IPNet{IP: v6addr, Mask: v4mask}},
+ {in: IPNet{IP: v4addr, Mask: badmask}},
+ {in: IPNet{IP: v4mappedv6addr, Mask: badmask}},
+ {in: IPNet{IP: v6addr, Mask: badmask}},
+ {in: IPNet{IP: badaddr, Mask: v4mask}},
+ {in: IPNet{IP: badaddr, Mask: v4mappedv6mask}},
+ {in: IPNet{IP: badaddr, Mask: v6mask}},
+ {in: IPNet{IP: badaddr, Mask: badmask}},
+}
+
+func TestNetworkNumberAndMask(t *testing.T) {
+ for _, tt := range networkNumberAndMaskTests {
+ ip, m := networkNumberAndMask(&tt.in)
+ out := &IPNet{IP: ip, Mask: m}
+ if !reflect.DeepEqual(&tt.out, out) {
+ t.Errorf("networkNumberAndMask(%v) = %v, want %v", tt.in, out, &tt.out)
+ }
+ }
+}
+
+func TestSplitHostPort(t *testing.T) {
+ for _, tt := range []struct {
+ hostPort string
+ host string
+ port string
+ }{
+ // Host name
+ {"localhost:http", "localhost", "http"},
+ {"localhost:80", "localhost", "80"},
+
+ // Go-specific host name with zone identifier
+ {"localhost%lo0:http", "localhost%lo0", "http"},
+ {"localhost%lo0:80", "localhost%lo0", "80"},
+ {"[localhost%lo0]:http", "localhost%lo0", "http"}, // Go 1 behavior
+ {"[localhost%lo0]:80", "localhost%lo0", "80"}, // Go 1 behavior
+
+ // IP literal
+ {"127.0.0.1:http", "127.0.0.1", "http"},
+ {"127.0.0.1:80", "127.0.0.1", "80"},
+ {"[::1]:http", "::1", "http"},
+ {"[::1]:80", "::1", "80"},
+
+ // IP literal with zone identifier
+ {"[::1%lo0]:http", "::1%lo0", "http"},
+ {"[::1%lo0]:80", "::1%lo0", "80"},
+
+ // Go-specific wildcard for host name
+ {":http", "", "http"}, // Go 1 behavior
+ {":80", "", "80"}, // Go 1 behavior
+
+ // Go-specific wildcard for service name or transport port number
+ {"golang.org:", "golang.org", ""}, // Go 1 behavior
+ {"127.0.0.1:", "127.0.0.1", ""}, // Go 1 behavior
+ {"[::1]:", "::1", ""}, // Go 1 behavior
+
+ // Opaque service name
+ {"golang.org:https%foo", "golang.org", "https%foo"}, // Go 1 behavior
+ } {
+ if host, port, err := SplitHostPort(tt.hostPort); host != tt.host || port != tt.port || err != nil {
+ t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.hostPort, host, port, err, tt.host, tt.port)
+ }
+ }
+
+ for _, tt := range []struct {
+ hostPort string
+ err string
+ }{
+ {"golang.org", "missing port in address"},
+ {"127.0.0.1", "missing port in address"},
+ {"[::1]", "missing port in address"},
+ {"[fe80::1%lo0]", "missing port in address"},
+ {"[localhost%lo0]", "missing port in address"},
+ {"localhost%lo0", "missing port in address"},
+
+ {"::1", "too many colons in address"},
+ {"fe80::1%lo0", "too many colons in address"},
+ {"fe80::1%lo0:80", "too many colons in address"},
+
+ // Test cases that didn't fail in Go 1
+
+ {"[foo:bar]", "missing port in address"},
+ {"[foo:bar]baz", "missing port in address"},
+ {"[foo]bar:baz", "missing port in address"},
+
+ {"[foo]:[bar]:baz", "too many colons in address"},
+
+ {"[foo]:[bar]baz", "unexpected '[' in address"},
+ {"foo[bar]:baz", "unexpected '[' in address"},
+
+ {"foo]bar:baz", "unexpected ']' in address"},
+ } {
+ if host, port, err := SplitHostPort(tt.hostPort); err == nil {
+ t.Errorf("SplitHostPort(%q) should have failed", tt.hostPort)
+ } else {
+ e := err.(*AddrError)
+ if e.Err != tt.err {
+ t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.hostPort, e.Err, tt.err)
+ }
+ if host != "" || port != "" {
+ t.Errorf("SplitHostPort(%q) = %q, %q, err; want %q, %q, err on failure", tt.hostPort, host, port, "", "")
+ }
+ }
+ }
+}
+
+func TestJoinHostPort(t *testing.T) {
+ for _, tt := range []struct {
+ host string
+ port string
+ hostPort string
+ }{
+ // Host name
+ {"localhost", "http", "localhost:http"},
+ {"localhost", "80", "localhost:80"},
+
+ // Go-specific host name with zone identifier
+ {"localhost%lo0", "http", "localhost%lo0:http"},
+ {"localhost%lo0", "80", "localhost%lo0:80"},
+
+ // IP literal
+ {"127.0.0.1", "http", "127.0.0.1:http"},
+ {"127.0.0.1", "80", "127.0.0.1:80"},
+ {"::1", "http", "[::1]:http"},
+ {"::1", "80", "[::1]:80"},
+
+ // IP literal with zone identifier
+ {"::1%lo0", "http", "[::1%lo0]:http"},
+ {"::1%lo0", "80", "[::1%lo0]:80"},
+
+ // Go-specific wildcard for host name
+ {"", "http", ":http"}, // Go 1 behavior
+ {"", "80", ":80"}, // Go 1 behavior
+
+ // Go-specific wildcard for service name or transport port number
+ {"golang.org", "", "golang.org:"}, // Go 1 behavior
+ {"127.0.0.1", "", "127.0.0.1:"}, // Go 1 behavior
+ {"::1", "", "[::1]:"}, // Go 1 behavior
+
+ // Opaque service name
+ {"golang.org", "https%foo", "golang.org:https%foo"}, // Go 1 behavior
+ } {
+ if hostPort := JoinHostPort(tt.host, tt.port); hostPort != tt.hostPort {
+ t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.host, tt.port, hostPort, tt.hostPort)
+ }
+ }
+}
+
+var ipAddrFamilyTests = []struct {
+ in IP
+ af4 bool
+ af6 bool
+}{
+ {IPv4bcast, true, false},
+ {IPv4allsys, true, false},
+ {IPv4allrouter, true, false},
+ {IPv4zero, true, false},
+ {IPv4(224, 0, 0, 1), true, false},
+ {IPv4(127, 0, 0, 1), true, false},
+ {IPv4(240, 0, 0, 1), true, false},
+ {IPv6unspecified, false, true},
+ {IPv6loopback, false, true},
+ {IPv6interfacelocalallnodes, false, true},
+ {IPv6linklocalallnodes, false, true},
+ {IPv6linklocalallrouters, false, true},
+ {ParseIP("ff05::a:b:c:d"), false, true},
+ {ParseIP("fe80::1:2:3:4"), false, true},
+ {ParseIP("2001:db8::123:12:1"), false, true},
+}
+
+func TestIPAddrFamily(t *testing.T) {
+ for _, tt := range ipAddrFamilyTests {
+ if af := tt.in.To4() != nil; af != tt.af4 {
+ t.Errorf("verifying IPv4 address family for %q = %v, want %v", tt.in, af, tt.af4)
+ }
+ if af := len(tt.in) == IPv6len && tt.in.To4() == nil; af != tt.af6 {
+ t.Errorf("verifying IPv6 address family for %q = %v, want %v", tt.in, af, tt.af6)
+ }
+ }
+}
+
+var ipAddrScopeTests = []struct {
+ scope func(IP) bool
+ in IP
+ ok bool
+}{
+ {IP.IsUnspecified, IPv4zero, true},
+ {IP.IsUnspecified, IPv4(127, 0, 0, 1), false},
+ {IP.IsUnspecified, IPv6unspecified, true},
+ {IP.IsUnspecified, IPv6interfacelocalallnodes, false},
+ {IP.IsUnspecified, nil, false},
+ {IP.IsLoopback, IPv4(127, 0, 0, 1), true},
+ {IP.IsLoopback, IPv4(127, 255, 255, 254), true},
+ {IP.IsLoopback, IPv4(128, 1, 2, 3), false},
+ {IP.IsLoopback, IPv6loopback, true},
+ {IP.IsLoopback, IPv6linklocalallrouters, false},
+ {IP.IsLoopback, nil, false},
+ {IP.IsMulticast, IPv4(224, 0, 0, 0), true},
+ {IP.IsMulticast, IPv4(239, 0, 0, 0), true},
+ {IP.IsMulticast, IPv4(240, 0, 0, 0), false},
+ {IP.IsMulticast, IPv6linklocalallnodes, true},
+ {IP.IsMulticast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true},
+ {IP.IsMulticast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+ {IP.IsMulticast, nil, false},
+ {IP.IsInterfaceLocalMulticast, IPv4(224, 0, 0, 0), false},
+ {IP.IsInterfaceLocalMulticast, IPv4(0xff, 0x01, 0, 0), false},
+ {IP.IsInterfaceLocalMulticast, IPv6interfacelocalallnodes, true},
+ {IP.IsInterfaceLocalMulticast, nil, false},
+ {IP.IsLinkLocalMulticast, IPv4(224, 0, 0, 0), true},
+ {IP.IsLinkLocalMulticast, IPv4(239, 0, 0, 0), false},
+ {IP.IsLinkLocalMulticast, IPv4(0xff, 0x02, 0, 0), false},
+ {IP.IsLinkLocalMulticast, IPv6linklocalallrouters, true},
+ {IP.IsLinkLocalMulticast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+ {IP.IsLinkLocalMulticast, nil, false},
+ {IP.IsLinkLocalUnicast, IPv4(169, 254, 0, 0), true},
+ {IP.IsLinkLocalUnicast, IPv4(169, 255, 0, 0), false},
+ {IP.IsLinkLocalUnicast, IPv4(0xfe, 0x80, 0, 0), false},
+ {IP.IsLinkLocalUnicast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true},
+ {IP.IsLinkLocalUnicast, IP{0xfe, 0xc0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+ {IP.IsLinkLocalUnicast, nil, false},
+ {IP.IsGlobalUnicast, IPv4(240, 0, 0, 0), true},
+ {IP.IsGlobalUnicast, IPv4(232, 0, 0, 0), false},
+ {IP.IsGlobalUnicast, IPv4(169, 254, 0, 0), false},
+ {IP.IsGlobalUnicast, IPv4bcast, false},
+ {IP.IsGlobalUnicast, IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, true},
+ {IP.IsGlobalUnicast, IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+ {IP.IsGlobalUnicast, IP{0xff, 0x05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+ {IP.IsGlobalUnicast, nil, false},
+ {IP.IsPrivate, nil, false},
+ {IP.IsPrivate, IPv4(1, 1, 1, 1), false},
+ {IP.IsPrivate, IPv4(9, 255, 255, 255), false},
+ {IP.IsPrivate, IPv4(10, 0, 0, 0), true},
+ {IP.IsPrivate, IPv4(10, 255, 255, 255), true},
+ {IP.IsPrivate, IPv4(11, 0, 0, 0), false},
+ {IP.IsPrivate, IPv4(172, 15, 255, 255), false},
+ {IP.IsPrivate, IPv4(172, 16, 0, 0), true},
+ {IP.IsPrivate, IPv4(172, 16, 255, 255), true},
+ {IP.IsPrivate, IPv4(172, 23, 18, 255), true},
+ {IP.IsPrivate, IPv4(172, 31, 255, 255), true},
+ {IP.IsPrivate, IPv4(172, 31, 0, 0), true},
+ {IP.IsPrivate, IPv4(172, 32, 0, 0), false},
+ {IP.IsPrivate, IPv4(192, 167, 255, 255), false},
+ {IP.IsPrivate, IPv4(192, 168, 0, 0), true},
+ {IP.IsPrivate, IPv4(192, 168, 255, 255), true},
+ {IP.IsPrivate, IPv4(192, 169, 0, 0), false},
+ {IP.IsPrivate, IP{0xfb, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, false},
+ {IP.IsPrivate, IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true},
+ {IP.IsPrivate, IP{0xfc, 0xff, 0x12, 0, 0, 0, 0, 0x44, 0, 0, 0, 0, 0, 0, 0, 0}, true},
+ {IP.IsPrivate, IP{0xfd, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, true},
+ {IP.IsPrivate, IP{0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false},
+}
+
+func name(f any) string {
+ return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
+}
+
+func TestIPAddrScope(t *testing.T) {
+ for _, tt := range ipAddrScopeTests {
+ if ok := tt.scope(tt.in); ok != tt.ok {
+ t.Errorf("%s(%q) = %v, want %v", name(tt.scope), tt.in, ok, tt.ok)
+ }
+ ip := tt.in.To4()
+ if ip == nil {
+ continue
+ }
+ if ok := tt.scope(ip); ok != tt.ok {
+ t.Errorf("%s(%q) = %v, want %v", name(tt.scope), ip, ok, tt.ok)
+ }
+ }
+}
+
+func BenchmarkIPEqual(b *testing.B) {
+ b.Run("IPv4", func(b *testing.B) {
+ benchmarkIPEqual(b, IPv4len)
+ })
+ b.Run("IPv6", func(b *testing.B) {
+ benchmarkIPEqual(b, IPv6len)
+ })
+}
+
+func benchmarkIPEqual(b *testing.B, size int) {
+ ips := make([]IP, 1000)
+ for i := range ips {
+ ips[i] = make(IP, size)
+ rand.Read(ips[i])
+ }
+ // Half of the N are equal.
+ for i := 0; i < b.N/2; i++ {
+ x := ips[i%len(ips)]
+ y := ips[i%len(ips)]
+ x.Equal(y)
+ }
+ // The other half are not equal.
+ for i := 0; i < b.N/2; i++ {
+ x := ips[i%len(ips)]
+ y := ips[(i+1)%len(ips)]
+ x.Equal(y)
+ }
+}
diff --git a/src/net/iprawsock.go b/src/net/iprawsock.go
new file mode 100644
index 0000000..f18331a
--- /dev/null
+++ b/src/net/iprawsock.go
@@ -0,0 +1,240 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "syscall"
+)
+
+// BUG(mikio): On every POSIX platform, reads from the "ip4" network
+// using the ReadFrom or ReadFromIP method might not return a complete
+// IPv4 packet, including its header, even if there is space
+// available. This can occur even in cases where Read or ReadMsgIP
+// could return a complete packet. For this reason, it is recommended
+// that you do not use these methods if it is important to receive a
+// full packet.
+//
+// The Go 1 compatibility guidelines make it impossible for us to
+// change the behavior of these methods; use Read or ReadMsgIP
+// instead.
+
+// BUG(mikio): On JS and Plan 9, methods and functions related
+// to IPConn are not implemented.
+
+// BUG(mikio): On Windows, the File method of IPConn is not
+// implemented.
+
+// IPAddr represents the address of an IP end point.
+type IPAddr struct {
+ IP IP
+ Zone string // IPv6 scoped addressing zone
+}
+
+// Network returns the address's network name, "ip".
+func (a *IPAddr) Network() string { return "ip" }
+
+func (a *IPAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ ip := ipEmptyString(a.IP)
+ if a.Zone != "" {
+ return ip + "%" + a.Zone
+ }
+ return ip
+}
+
+func (a *IPAddr) isWildcard() bool {
+ if a == nil || a.IP == nil {
+ return true
+ }
+ return a.IP.IsUnspecified()
+}
+
+func (a *IPAddr) opAddr() Addr {
+ if a == nil {
+ return nil
+ }
+ return a
+}
+
+// ResolveIPAddr returns an address of IP end point.
+//
+// The network must be an IP network name.
+//
+// If the host in the address parameter is not a literal IP address,
+// ResolveIPAddr resolves the address to an address of IP end point.
+// Otherwise, it parses the address as a literal IP address.
+// The address parameter can use a host name, but this is not
+// recommended, because it will return at most one of the host name's
+// IP addresses.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func ResolveIPAddr(network, address string) (*IPAddr, error) {
+ if network == "" { // a hint wildcard for Go 1.0 undocumented behavior
+ network = "ip"
+ }
+ afnet, _, err := parseNetwork(context.Background(), network, false)
+ if err != nil {
+ return nil, err
+ }
+ switch afnet {
+ case "ip", "ip4", "ip6":
+ default:
+ return nil, UnknownNetworkError(network)
+ }
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, address)
+ if err != nil {
+ return nil, err
+ }
+ return addrs.forResolve(network, address).(*IPAddr), nil
+}
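+
+// sketchResolveIPAddr is a hypothetical helper (purely illustrative): with a
+// literal address ResolveIPAddr parses rather than resolves, so no lookup of
+// host names is involved.
+func sketchResolveIPAddr() (*IPAddr, error) {
+ return ResolveIPAddr("ip4:icmp", "127.0.0.1") // &IPAddr{IP: IPv4(127, 0, 0, 1)}
+}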
+
+// IPConn is the implementation of the Conn and PacketConn interfaces
+// for IP network connections.
+type IPConn struct {
+ conn
+}
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+func (c *IPConn) SyscallConn() (syscall.RawConn, error) {
+ if !c.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawConn(c.fd)
+}
+
+// ReadFromIP acts like ReadFrom but returns an IPAddr.
+func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) {
+ if !c.ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ n, addr, err := c.readFrom(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, addr, err
+}
+
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) {
+ if !c.ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ n, addr, err := c.readFrom(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ if addr == nil {
+ return n, nil, err
+ }
+ return n, addr, err
+}
+
+// ReadMsgIP reads a message from c, copying the payload into b and
+// the associated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the message and the source address of the message.
+//
+// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
+// used to manipulate IP-level socket options in oob.
+func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
+ if !c.ok() {
+ return 0, 0, 0, nil, syscall.EINVAL
+ }
+ n, oobn, flags, addr, err = c.readMsg(b, oob)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return
+}
+
+// WriteToIP acts like WriteTo but takes an IPAddr.
+func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeTo(b, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteTo implements the PacketConn WriteTo method.
+func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ a, ok := addr.(*IPAddr)
+ if !ok {
+ return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
+ }
+ n, err := c.writeTo(b, a)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteMsgIP writes a message to addr via c, copying the payload from
+// b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+//
+// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
+// used to manipulate IP-level socket options in oob.
+func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ n, oobn, err = c.writeMsg(b, oob, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return
+}
+
+func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
+
+// DialIP acts like Dial for IP networks.
+//
+// The network must be an IP network name; see func Dial for details.
+//
+// If laddr is nil, a local address is automatically chosen.
+// If the IP field of raddr is nil or an unspecified IP address, the
+// local system is assumed.
+func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
+ if raddr == nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
+ }
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialIP(context.Background(), laddr, raddr)
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
+
+// ListenIP acts like ListenPacket for IP networks.
+//
+// The network must be an IP network name; see func Dial for details.
+//
+// If the IP field of laddr is nil or an unspecified IP address,
+// ListenIP listens on all available IP addresses of the local system
+// except multicast IP addresses.
+func ListenIP(network string, laddr *IPAddr) (*IPConn, error) {
+ if laddr == nil {
+ laddr = &IPAddr{}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenIP(context.Background(), laddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
diff --git a/src/net/iprawsock_plan9.go b/src/net/iprawsock_plan9.go
new file mode 100644
index 0000000..ebe5808
--- /dev/null
+++ b/src/net/iprawsock_plan9.go
@@ -0,0 +1,34 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "syscall"
+)
+
+func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
+ return 0, nil, syscall.EPLAN9
+}
+
+func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
+ return 0, 0, 0, nil, syscall.EPLAN9
+}
+
+func (c *IPConn) writeTo(b []byte, addr *IPAddr) (int, error) {
+ return 0, syscall.EPLAN9
+}
+
+func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
+func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
+ return nil, syscall.EPLAN9
+}
+
+func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
+ return nil, syscall.EPLAN9
+}
diff --git a/src/net/iprawsock_posix.go b/src/net/iprawsock_posix.go
new file mode 100644
index 0000000..59967eb
--- /dev/null
+++ b/src/net/iprawsock_posix.go
@@ -0,0 +1,159 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "context"
+ "syscall"
+)
+
+func sockaddrToIP(sa syscall.Sockaddr) Addr {
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ return &IPAddr{IP: sa.Addr[0:]}
+ case *syscall.SockaddrInet6:
+ return &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
+ }
+ return nil
+}
+
+func (a *IPAddr) family() int {
+ if a == nil || len(a.IP) <= IPv4len {
+ return syscall.AF_INET
+ }
+ if a.IP.To4() != nil {
+ return syscall.AF_INET
+ }
+ return syscall.AF_INET6
+}
+
+func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
+ if a == nil {
+ return nil, nil
+ }
+ return ipToSockaddr(family, a.IP, 0, a.Zone)
+}
+
+func (a *IPAddr) toLocal(net string) sockaddr {
+ return &IPAddr{loopbackIP(net), a.Zone}
+}
+
+func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
+ // TODO(cw,rsc): consider using readv if we know the family
+ // type to avoid the header trim/copy
+ var addr *IPAddr
+ n, sa, err := c.fd.readFrom(b)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ addr = &IPAddr{IP: sa.Addr[0:]}
+ n = stripIPv4Header(n, b)
+ case *syscall.SockaddrInet6:
+ addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
+ }
+ return n, addr, err
+}
+
+func stripIPv4Header(n int, b []byte) int {
+ if len(b) < 20 {
+ return n
+ }
+ l := int(b[0]&0x0f) << 2
+ if 20 > l || l > len(b) {
+ return n
+ }
+ if b[0]>>4 != 4 {
+ return n
+ }
+ copy(b, b[l:])
+ return n - l
+}
+
+func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) {
+ var sa syscall.Sockaddr
+ n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ addr = &IPAddr{IP: sa.Addr[0:]}
+ case *syscall.SockaddrInet6:
+ addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))}
+ }
+ return
+}
+
+func (c *IPConn) writeTo(b []byte, addr *IPAddr) (int, error) {
+ if c.fd.isConnected {
+ return 0, ErrWriteToConnected
+ }
+ if addr == nil {
+ return 0, errMissingAddress
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeTo(b, sa)
+}
+
+func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) {
+ if c.fd.isConnected {
+ return 0, 0, ErrWriteToConnected
+ }
+ if addr == nil {
+ return 0, 0, errMissingAddress
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsg(b, oob, sa)
+}
+
+func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
+ network, proto, err := parseNetwork(ctx, sd.network, true)
+ if err != nil {
+ return nil, err
+ }
+ switch network {
+ case "ip", "ip4", "ip6":
+ default:
+ return nil, UnknownNetworkError(sd.network)
+ }
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newIPConn(fd), nil
+}
+
+func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
+ network, proto, err := parseNetwork(ctx, sl.network, true)
+ if err != nil {
+ return nil, err
+ }
+ switch network {
+ case "ip", "ip4", "ip6":
+ default:
+ return nil, UnknownNetworkError(sl.network)
+ }
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newIPConn(fd), nil
+}
diff --git a/src/net/iprawsock_test.go b/src/net/iprawsock_test.go
new file mode 100644
index 0000000..14c03a1
--- /dev/null
+++ b/src/net/iprawsock_test.go
@@ -0,0 +1,202 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "internal/testenv"
+ "reflect"
+ "testing"
+)
+
+// The full stack test cases for IPConn have been moved to the
+// following:
+// golang.org/x/net/ipv4
+// golang.org/x/net/ipv6
+// golang.org/x/net/icmp
+
+type resolveIPAddrTest struct {
+ network string
+ litAddrOrName string
+ addr *IPAddr
+ err error
+}
+
+var resolveIPAddrTests = []resolveIPAddrTest{
+ {"ip", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ {"ip4", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ {"ip4:icmp", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+
+ {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+ {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+ {"ip6:ipv6-icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+ {"ip6:IPv6-ICMP", "::1", &IPAddr{IP: ParseIP("::1")}, nil},
+
+ {"ip", "::1%en0", &IPAddr{IP: ParseIP("::1"), Zone: "en0"}, nil},
+ {"ip6", "::1%911", &IPAddr{IP: ParseIP("::1"), Zone: "911"}, nil},
+
+ {"", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, // Go 1.0 behavior
+ {"", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, // Go 1.0 behavior
+
+ {"ip4:icmp", "", &IPAddr{}, nil},
+
+ {"l2tp", "127.0.0.1", nil, UnknownNetworkError("l2tp")},
+ {"l2tp:gre", "127.0.0.1", nil, UnknownNetworkError("l2tp:gre")},
+ {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")},
+
+ {"ip4", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"ip4:icmp", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"ip6", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"ip6", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
+ {"ip6:ipv6-icmp", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"ip6:ipv6-icmp", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
+}
+
+func TestResolveIPAddr(t *testing.T) {
+ if !testableNetwork("ip+nopriv") {
+ t.Skip("ip+nopriv test")
+ }
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+
+ for _, tt := range resolveIPAddrTests {
+ addr, err := ResolveIPAddr(tt.network, tt.litAddrOrName)
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
+ continue
+ }
+ if err == nil {
+ addr2, err := ResolveIPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
+ }
+ }
+}
+
+var ipConnLocalNameTests = []struct {
+ net string
+ laddr *IPAddr
+}{
+ {"ip4:icmp", &IPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"ip4:icmp", &IPAddr{}},
+ {"ip4:icmp", nil},
+}
+
+func TestIPConnLocalName(t *testing.T) {
+ for _, tt := range ipConnLocalNameTests {
+ if !testableNetwork(tt.net) {
+ t.Logf("skipping %s test", tt.net)
+ continue
+ }
+ c, err := ListenIP(tt.net, tt.laddr)
+ if testenv.SyscallIsNotSupported(err) {
+ // May be inside a container that disallows creating a socket.
+ t.Logf("skipping %s test: %v", tt.net, err)
+ continue
+ } else if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if la := c.LocalAddr(); la == nil {
+ t.Fatal("should not fail")
+ }
+ }
+}
+
+func TestIPConnRemoteName(t *testing.T) {
+ network := "ip:tcp"
+ if !testableNetwork(network) {
+ t.Skipf("skipping %s test", network)
+ }
+
+ raddr := &IPAddr{IP: IPv4(127, 0, 0, 1).To4()}
+ c, err := DialIP(network, &IPAddr{IP: IPv4(127, 0, 0, 1)}, raddr)
+ if testenv.SyscallIsNotSupported(err) {
+ // May be inside a container that disallows creating a socket.
+ t.Skipf("skipping %s test: %v", network, err)
+ } else if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if !reflect.DeepEqual(raddr, c.RemoteAddr()) {
+ t.Fatalf("got %#v; want %#v", c.RemoteAddr(), raddr)
+ }
+}
+
+func TestDialListenIPArgs(t *testing.T) {
+ type test struct {
+ argLists [][2]string
+ shouldFail bool
+ }
+ tests := []test{
+ {
+ argLists: [][2]string{
+ {"ip", "127.0.0.1"},
+ {"ip:", "127.0.0.1"},
+ {"ip::", "127.0.0.1"},
+ {"ip", "::1"},
+ {"ip:", "::1"},
+ {"ip::", "::1"},
+ {"ip4", "127.0.0.1"},
+ {"ip4:", "127.0.0.1"},
+ {"ip4::", "127.0.0.1"},
+ {"ip6", "::1"},
+ {"ip6:", "::1"},
+ {"ip6::", "::1"},
+ },
+ shouldFail: true,
+ },
+ }
+ if testableNetwork("ip") {
+ priv := test{shouldFail: false}
+ for _, tt := range []struct {
+ network, address string
+ args [2]string
+ }{
+ {"ip4:47", "127.0.0.1", [2]string{"ip4:47", "127.0.0.1"}},
+ {"ip6:47", "::1", [2]string{"ip6:47", "::1"}},
+ } {
+ c, err := ListenPacket(tt.network, tt.address)
+ if err != nil {
+ continue
+ }
+ c.Close()
+ priv.argLists = append(priv.argLists, tt.args)
+ }
+ if len(priv.argLists) > 0 {
+ tests = append(tests, priv)
+ }
+ }
+
+ for _, tt := range tests {
+ for _, args := range tt.argLists {
+ _, err := Dial(args[0], args[1])
+ if tt.shouldFail != (err != nil) {
+ t.Errorf("Dial(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail)
+ }
+ _, err = ListenPacket(args[0], args[1])
+ if tt.shouldFail != (err != nil) {
+ t.Errorf("ListenPacket(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail)
+ }
+ a, err := ResolveIPAddr("ip", args[1])
+ if err != nil {
+ t.Errorf("ResolveIPAddr(\"ip\", %q) = %v", args[1], err)
+ continue
+ }
+ _, err = DialIP(args[0], nil, a)
+ if tt.shouldFail != (err != nil) {
+ t.Errorf("DialIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail)
+ }
+ _, err = ListenIP(args[0], a)
+ if tt.shouldFail != (err != nil) {
+ t.Errorf("ListenIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail)
+ }
+ }
+ }
+}
diff --git a/src/net/ipsock.go b/src/net/ipsock.go
new file mode 100644
index 0000000..0f5da25
--- /dev/null
+++ b/src/net/ipsock.go
@@ -0,0 +1,315 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/bytealg"
+ "runtime"
+ "sync"
+)
+
+// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
+// "tcp" and "udp" networks does not listen for both IPv4 and IPv6
+// connections. This is due to the fact that IPv4 traffic will not be
+// routed to an IPv6 socket - two separate sockets are required if
+// both address families are to be supported.
+// See inet6(4) for details.
+
+type ipStackCapabilities struct {
+ sync.Once // guards following
+ ipv4Enabled bool
+ ipv6Enabled bool
+ ipv4MappedIPv6Enabled bool
+}
+
+var ipStackCaps ipStackCapabilities
+
+// supportsIPv4 reports whether the platform supports IPv4 networking
+// functionality.
+func supportsIPv4() bool {
+ ipStackCaps.Once.Do(ipStackCaps.probe)
+ return ipStackCaps.ipv4Enabled
+}
+
+// supportsIPv6 reports whether the platform supports IPv6 networking
+// functionality.
+func supportsIPv6() bool {
+ ipStackCaps.Once.Do(ipStackCaps.probe)
+ return ipStackCaps.ipv6Enabled
+}
+
+// supportsIPv4map reports whether the platform supports mapping an
+// IPv4 address inside an IPv6 address at transport layer
+// protocols. See RFC 4291, RFC 4038 and RFC 3493.
+func supportsIPv4map() bool {
+ // Some operating systems provide no support for mapping IPv4
+ // addresses to IPv6, and a runtime check is unnecessary.
+ switch runtime.GOOS {
+ case "dragonfly", "openbsd":
+ return false
+ }
+
+ ipStackCaps.Once.Do(ipStackCaps.probe)
+ return ipStackCaps.ipv4MappedIPv6Enabled
+}
+
+// An addrList represents a list of network endpoint addresses.
+type addrList []Addr
+
+// isIPv4 reports whether addr contains an IPv4 address.
+func isIPv4(addr Addr) bool {
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ return addr.IP.To4() != nil
+ case *UDPAddr:
+ return addr.IP.To4() != nil
+ case *IPAddr:
+ return addr.IP.To4() != nil
+ }
+ return false
+}
+
+// isNotIPv4 reports whether addr does not contain an IPv4 address.
+func isNotIPv4(addr Addr) bool { return !isIPv4(addr) }
+
+// forResolve returns the most appropriate address in address for
+// a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr.
+// IPv4 is preferred, unless addr contains an IPv6 literal.
+func (addrs addrList) forResolve(network, addr string) Addr {
+ var want6 bool
+ switch network {
+ case "ip":
+ // IPv6 literal (addr does NOT contain a port)
+ want6 = count(addr, ':') > 0
+ case "tcp", "udp":
+ // IPv6 literal. (addr contains a port, so look for '[')
+ want6 = count(addr, '[') > 0
+ }
+ if want6 {
+ return addrs.first(isNotIPv4)
+ }
+ return addrs.first(isIPv4)
+}
+
+// first returns the first address which satisfies strategy, or if
+// none do, then the first address of any kind.
+func (addrs addrList) first(strategy func(Addr) bool) Addr {
+ for _, addr := range addrs {
+ if strategy(addr) {
+ return addr
+ }
+ }
+ return addrs[0]
+}
+
+// partition divides an address list into two categories, using a
+// strategy function to assign a boolean label to each address.
+// The first address, and any with a matching label, are returned as
+// primaries, while addresses with the opposite label are returned
+// as fallbacks. For non-empty inputs, primaries is guaranteed to be
+// non-empty.
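+//
+// For example, partitioning {IPv4, IPv6, IPv4} with isIPv4 yields
+// primaries {IPv4, IPv4} and fallbacks {IPv6}. The first address
+// always sets the primary label, so inverting the strategy's output
+// produces the same split.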
+func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks addrList) {
+ var primaryLabel bool
+ for i, addr := range addrs {
+ label := strategy(addr)
+ if i == 0 || label == primaryLabel {
+ primaryLabel = label
+ primaries = append(primaries, addr)
+ } else {
+ fallbacks = append(fallbacks, addr)
+ }
+ }
+ return
+}
+
+// filterAddrList applies a filter to a list of IP addresses,
+// yielding a list of Addr objects. Known filters are nil, ipv4only,
+// and ipv6only. It returns every address when the filter is nil.
+// The result contains at least one address when error is nil.
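+//
+// For example, filtering {127.0.0.1, ::1} with ipv6only keeps only
+// ::1, while a nil filter keeps both; filtering {127.0.0.1} with
+// ipv6only yields an errNoSuitableAddress error.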
+func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr, originalAddr string) (addrList, error) {
+ var addrs addrList
+ for _, ip := range ips {
+ if filter == nil || filter(ip) {
+ addrs = append(addrs, inetaddr(ip))
+ }
+ }
+ if len(addrs) == 0 {
+ return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: originalAddr}
+ }
+ return addrs, nil
+}
+
+// ipv4only reports whether addr is an IPv4 address.
+func ipv4only(addr IPAddr) bool {
+ return addr.IP.To4() != nil
+}
+
+// ipv6only reports whether addr is an IPv6 address except IPv4-mapped IPv6 address.
+func ipv6only(addr IPAddr) bool {
+ return len(addr.IP) == IPv6len && addr.IP.To4() == nil
+}
+
+// SplitHostPort splits a network address of the form "host:port",
+// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or
+// host%zone and port.
+//
+// A literal IPv6 address in hostport must be enclosed in square
+// brackets, as in "[::1]:80", "[::1%lo0]:80".
+//
+// See func Dial for a description of the hostport parameter, and host
+// and port results.
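+//
+// For example:
+//
+//	SplitHostPort("192.0.2.1:80")     // "192.0.2.1", "80", nil
+//	SplitHostPort("[2001:db8::1]:80") // "2001:db8::1", "80", nil
+//	SplitHostPort("[::1%lo0]:80")     // "::1%lo0", "80", nil
+//	SplitHostPort("192.0.2.1")        // "", "", missing port error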
+func SplitHostPort(hostport string) (host, port string, err error) {
+ const (
+ missingPort = "missing port in address"
+ tooManyColons = "too many colons in address"
+ )
+ addrErr := func(addr, why string) (host, port string, err error) {
+ return "", "", &AddrError{Err: why, Addr: addr}
+ }
+ j, k := 0, 0
+
+ // The port starts after the last colon.
+ i := last(hostport, ':')
+ if i < 0 {
+ return addrErr(hostport, missingPort)
+ }
+
+ if hostport[0] == '[' {
+ // Expect the first ']' just before the last ':'.
+ end := bytealg.IndexByteString(hostport, ']')
+ if end < 0 {
+ return addrErr(hostport, "missing ']' in address")
+ }
+ switch end + 1 {
+ case len(hostport):
+ // There can't be a ':' behind the ']' now.
+ return addrErr(hostport, missingPort)
+ case i:
+ // The expected result.
+ default:
+ // Either ']' isn't followed by a colon, or it is
+ // followed by a colon that is not the last one.
+ if hostport[end+1] == ':' {
+ return addrErr(hostport, tooManyColons)
+ }
+ return addrErr(hostport, missingPort)
+ }
+ host = hostport[1:end]
+ j, k = 1, end+1 // there can't be a '[' before j or a ']' before k
+ } else {
+ host = hostport[:i]
+ if bytealg.IndexByteString(host, ':') >= 0 {
+ return addrErr(hostport, tooManyColons)
+ }
+ }
+ if bytealg.IndexByteString(hostport[j:], '[') >= 0 {
+ return addrErr(hostport, "unexpected '[' in address")
+ }
+ if bytealg.IndexByteString(hostport[k:], ']') >= 0 {
+ return addrErr(hostport, "unexpected ']' in address")
+ }
+
+ port = hostport[i+1:]
+ return host, port, nil
+}
+
+func splitHostZone(s string) (host, zone string) {
+ // The IPv6 scoped addressing zone identifier starts after the
+ // last percent sign.
+ if i := last(s, '%'); i > 0 {
+ host, zone = s[:i], s[i+1:]
+ } else {
+ host = s
+ }
+ return
+}
+
+// JoinHostPort combines host and port into a network address of the
+// form "host:port". If host contains a colon, as found in literal
+// IPv6 addresses, then JoinHostPort returns "[host]:port".
+//
+// See func Dial for a description of the host and port parameters.
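+//
+// For example:
+//
+//	JoinHostPort("example.com", "80") // "example.com:80"
+//	JoinHostPort("2001:db8::1", "80") // "[2001:db8::1]:80"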
+func JoinHostPort(host, port string) string {
+ // We assume that host is a literal IPv6 address if host has
+ // colons.
+ if bytealg.IndexByteString(host, ':') >= 0 {
+ return "[" + host + "]:" + port
+ }
+ return host + ":" + port
+}
+
+// internetAddrList resolves addr, which may be a literal IP
+// address or a DNS name, and returns a list of internet protocol
+// family addresses. The result contains at least one address when
+// error is nil.
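+//
+// For example, ("tcp", "localhost:http") typically resolves to
+// addresses such as 127.0.0.1:80 and [::1]:80, while ("tcp4", ":80")
+// short-circuits to a single wildcard address with port 80 and no
+// DNS lookup.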
+func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
+ var (
+ err error
+ host, port string
+ portnum int
+ )
+ switch net {
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+ if addr != "" {
+ if host, port, err = SplitHostPort(addr); err != nil {
+ return nil, err
+ }
+ if portnum, err = r.LookupPort(ctx, net, port); err != nil {
+ return nil, err
+ }
+ }
+ case "ip", "ip4", "ip6":
+ if addr != "" {
+ host = addr
+ }
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ inetaddr := func(ip IPAddr) Addr {
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ return &TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}
+ case "udp", "udp4", "udp6":
+ return &UDPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}
+ case "ip", "ip4", "ip6":
+ return &IPAddr{IP: ip.IP, Zone: ip.Zone}
+ default:
+ panic("unexpected network: " + net)
+ }
+ }
+ if host == "" {
+ return addrList{inetaddr(IPAddr{})}, nil
+ }
+
+ // Try as a literal IP address, then as a DNS name.
+ ips, err := r.lookupIPAddr(ctx, net, host)
+ if err != nil {
+ return nil, err
+ }
+ // Issue 18806: if the machine has halfway configured
+ // IPv6 such that it can bind on "::" (IPv6unspecified)
+ // but not connect back to that same address, fall
+ // back to dialing 0.0.0.0.
+ if len(ips) == 1 && ips[0].IP.Equal(IPv6unspecified) {
+ ips = append(ips, IPAddr{IP: IPv4zero})
+ }
+
+ var filter func(IPAddr) bool
+ if net != "" && net[len(net)-1] == '4' {
+ filter = ipv4only
+ }
+ if net != "" && net[len(net)-1] == '6' {
+ filter = ipv6only
+ }
+ return filterAddrList(filter, ips, inetaddr, host)
+}
+
+func loopbackIP(net string) IP {
+ if net != "" && net[len(net)-1] == '6' {
+ return IPv6loopback
+ }
+ return IP{127, 0, 0, 1}
+}
diff --git a/src/net/ipsock_plan9.go b/src/net/ipsock_plan9.go
new file mode 100644
index 0000000..4328743
--- /dev/null
+++ b/src/net/ipsock_plan9.go
@@ -0,0 +1,367 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/bytealg"
+ "internal/itoa"
+ "io/fs"
+ "os"
+ "syscall"
+)
+
+// probe probes IPv4, IPv6 and IPv4-mapped IPv6 communication
+// capabilities.
+//
+// Plan 9 uses IPv6 natively, see ip(3).
+func (p *ipStackCapabilities) probe() {
+ p.ipv4Enabled = probe(netdir+"/iproute", "4i")
+ p.ipv6Enabled = probe(netdir+"/iproute", "6i")
+ if p.ipv4Enabled && p.ipv6Enabled {
+ p.ipv4MappedIPv6Enabled = true
+ }
+}
+
+func probe(filename, query string) bool {
+ var file *file
+ var err error
+ if file, err = open(filename); err != nil {
+ return false
+ }
+ defer file.close()
+
+ r := false
+ for line, ok := file.readLine(); ok && !r; line, ok = file.readLine() {
+ f := getFields(line)
+ if len(f) < 3 {
+ continue
+ }
+ for i := 0; i < len(f); i++ {
+ if query == f[i] {
+ r = true
+ break
+ }
+ }
+ }
+ return r
+}
+
+// parsePlan9Addr parses an address of the form [ip!]port (e.g. 127.0.0.1!80).
+func parsePlan9Addr(s string) (ip IP, iport int, err error) {
+ addr := IPv4zero // address contains port only
+ i := bytealg.IndexByteString(s, '!')
+ if i >= 0 {
+ addr = ParseIP(s[:i])
+ if addr == nil {
+ return nil, 0, &ParseError{Type: "IP address", Text: s}
+ }
+ }
+ p, plen, ok := dtoi(s[i+1:])
+ if !ok {
+ return nil, 0, &ParseError{Type: "port", Text: s}
+ }
+ if p < 0 || p > 0xFFFF {
+ return nil, 0, &AddrError{Err: "invalid port", Addr: s[i+1 : i+1+plen]}
+ }
+ return addr, p, nil
+}
+
+func readPlan9Addr(net, filename string) (addr Addr, err error) {
+ var buf [128]byte
+
+ f, err := os.Open(filename)
+ if err != nil {
+ return
+ }
+ defer f.Close()
+ n, err := f.Read(buf[:])
+ if err != nil {
+ return
+ }
+ ip, port, err := parsePlan9Addr(string(buf[:n]))
+ if err != nil {
+ return
+ }
+ switch net {
+ case "tcp4", "udp4":
+ if ip.Equal(IPv6zero) {
+ ip = ip[:IPv4len]
+ }
+ }
+ switch net {
+ case "tcp", "tcp4", "tcp6":
+ addr = &TCPAddr{IP: ip, Port: port}
+ case "udp", "udp4", "udp6":
+ addr = &UDPAddr{IP: ip, Port: port}
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+ return addr, nil
+}
+
+func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, proto, name string, err error) {
+ var (
+ ip IP
+ port int
+ )
+ switch a := addr.(type) {
+ case *TCPAddr:
+ proto = "tcp"
+ ip = a.IP
+ port = a.Port
+ case *UDPAddr:
+ proto = "udp"
+ ip = a.IP
+ port = a.Port
+ default:
+ err = UnknownNetworkError(net)
+ return
+ }
+
+ if port > 65535 {
+ err = InvalidAddrError("port should be < 65536")
+ return
+ }
+
+ clone, dest, err := queryCS1(ctx, proto, ip, port)
+ if err != nil {
+ return
+ }
+ f, err := os.OpenFile(clone, os.O_RDWR, 0)
+ if err != nil {
+ return
+ }
+ var buf [16]byte
+ n, err := f.Read(buf[:])
+ if err != nil {
+ f.Close()
+ return
+ }
+ return f, dest, proto, string(buf[:n]), nil
+}
+
+func fixErr(err error) {
+ oe, ok := err.(*OpError)
+ if !ok {
+ return
+ }
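+ // nonNilInterface reports whether a is a non-nil interface value
+ // holding a nil concrete *TCPAddr, *UDPAddr, or *IPAddr, so the
+ // OpError fields can be normalized to a plain nil below.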
+ nonNilInterface := func(a Addr) bool {
+ switch a := a.(type) {
+ case *TCPAddr:
+ return a == nil
+ case *UDPAddr:
+ return a == nil
+ case *IPAddr:
+ return a == nil
+ default:
+ return false
+ }
+ }
+ if nonNilInterface(oe.Source) {
+ oe.Source = nil
+ }
+ if nonNilInterface(oe.Addr) {
+ oe.Addr = nil
+ }
+ if pe, ok := oe.Err.(*fs.PathError); ok {
+ if _, ok = pe.Err.(syscall.ErrorString); ok {
+ oe.Err = pe.Err
+ }
+ }
+}
+
+func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
+ defer func() { fixErr(err) }()
+ type res struct {
+ fd *netFD
+ err error
+ }
+ resc := make(chan res)
+ go func() {
+ testHookDialChannel()
+ fd, err := dialPlan9Blocking(ctx, net, laddr, raddr)
+ select {
+ case resc <- res{fd, err}:
+ case <-ctx.Done():
+ if fd != nil {
+ fd.Close()
+ }
+ }
+ }()
+ select {
+ case res := <-resc:
+ return res.fd, res.err
+ case <-ctx.Done():
+ return nil, mapErr(ctx.Err())
+ }
+}
+
+func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
+ if isWildcard(raddr) {
+ raddr = toLocal(raddr, net)
+ }
+ f, dest, proto, name, err := startPlan9(ctx, net, raddr)
+ if err != nil {
+ return nil, err
+ }
+ if la := plan9LocalAddr(laddr); la == "" {
+ err = hangupCtlWrite(ctx, proto, f, "connect "+dest)
+ } else {
+ err = hangupCtlWrite(ctx, proto, f, "connect "+dest+" "+la)
+ }
+ if err != nil {
+ f.Close()
+ return nil, err
+ }
+ data, err := os.OpenFile(netdir+"/"+proto+"/"+name+"/data", os.O_RDWR, 0)
+ if err != nil {
+ f.Close()
+ return nil, err
+ }
+ laddr, err = readPlan9Addr(net, netdir+"/"+proto+"/"+name+"/local")
+ if err != nil {
+ data.Close()
+ f.Close()
+ return nil, err
+ }
+ return newFD(proto, name, nil, f, data, laddr, raddr)
+}
+
+func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) {
+ defer func() { fixErr(err) }()
+ f, dest, proto, name, err := startPlan9(ctx, net, laddr)
+ if err != nil {
+ return nil, err
+ }
+ _, err = f.WriteString("announce " + dest)
+ if err != nil {
+ f.Close()
+ return nil, &OpError{Op: "announce", Net: net, Source: laddr, Addr: nil, Err: err}
+ }
+ laddr, err = readPlan9Addr(net, netdir+"/"+proto+"/"+name+"/local")
+ if err != nil {
+ f.Close()
+ return nil, err
+ }
+ return newFD(proto, name, nil, f, nil, laddr, nil)
+}
+
+func (fd *netFD) netFD() (*netFD, error) {
+ return newFD(fd.net, fd.n, fd.listen, fd.ctl, fd.data, fd.laddr, fd.raddr)
+}
+
+func (fd *netFD) acceptPlan9() (nfd *netFD, err error) {
+ defer func() { fixErr(err) }()
+ if err := fd.pfd.ReadLock(); err != nil {
+ return nil, err
+ }
+ defer fd.pfd.ReadUnlock()
+ listen, err := os.Open(fd.dir + "/listen")
+ if err != nil {
+ return nil, err
+ }
+ var buf [16]byte
+ n, err := listen.Read(buf[:])
+ if err != nil {
+ listen.Close()
+ return nil, err
+ }
+ name := string(buf[:n])
+ ctl, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/ctl", os.O_RDWR, 0)
+ if err != nil {
+ listen.Close()
+ return nil, err
+ }
+ data, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/data", os.O_RDWR, 0)
+ if err != nil {
+ listen.Close()
+ ctl.Close()
+ return nil, err
+ }
+ raddr, err := readPlan9Addr(fd.net, netdir+"/"+fd.net+"/"+name+"/remote")
+ if err != nil {
+ listen.Close()
+ ctl.Close()
+ data.Close()
+ return nil, err
+ }
+ return newFD(fd.net, name, listen, ctl, data, fd.laddr, raddr)
+}
+
+func isWildcard(a Addr) bool {
+ var wildcard bool
+ switch a := a.(type) {
+ case *TCPAddr:
+ wildcard = a.isWildcard()
+ case *UDPAddr:
+ wildcard = a.isWildcard()
+ case *IPAddr:
+ wildcard = a.isWildcard()
+ }
+ return wildcard
+}
+
+func toLocal(a Addr, net string) Addr {
+ switch a := a.(type) {
+ case *TCPAddr:
+ a.IP = loopbackIP(net)
+ case *UDPAddr:
+ a.IP = loopbackIP(net)
+ case *IPAddr:
+ a.IP = loopbackIP(net)
+ }
+ return a
+}
+
+// plan9LocalAddr returns a Plan 9 local address string.
+// See setladdrport at https://9p.io/sources/plan9/sys/src/9/ip/devip.c.
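+//
+// For example, a *TCPAddr of 127.0.0.1:80 yields "127.0.0.1!80", an
+// unspecified address with port 564 yields "564", and a nil or
+// zero-port unspecified address yields "".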
+func plan9LocalAddr(addr Addr) string {
+ var ip IP
+ port := 0
+ switch a := addr.(type) {
+ case *TCPAddr:
+ if a != nil {
+ ip = a.IP
+ port = a.Port
+ }
+ case *UDPAddr:
+ if a != nil {
+ ip = a.IP
+ port = a.Port
+ }
+ }
+ if len(ip) == 0 || ip.IsUnspecified() {
+ if port == 0 {
+ return ""
+ }
+ return itoa.Itoa(port)
+ }
+ return ip.String() + "!" + itoa.Itoa(port)
+}
+
+func hangupCtlWrite(ctx context.Context, proto string, ctl *os.File, msg string) error {
+ if proto != "tcp" {
+ _, err := ctl.WriteString(msg)
+ return err
+ }
+ written := make(chan struct{})
+ errc := make(chan error)
+ go func() {
+ select {
+ case <-ctx.Done():
+ ctl.WriteString("hangup")
+ errc <- mapErr(ctx.Err())
+ case <-written:
+ errc <- nil
+ }
+ }()
+ _, err := ctl.WriteString(msg)
+ close(written)
+ if e := <-errc; err == nil && e != nil { // we hung up
+ return e
+ }
+ return err
+}
diff --git a/src/net/ipsock_plan9_test.go b/src/net/ipsock_plan9_test.go
new file mode 100644
index 0000000..e5fb9ff
--- /dev/null
+++ b/src/net/ipsock_plan9_test.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "testing"
+
+func TestTCP4ListenZero(t *testing.T) {
+ l, err := Listen("tcp4", "0.0.0.0:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ if a := l.Addr(); isNotIPv4(a) {
+ t.Errorf("address does not contain IPv4: %v", a)
+ }
+}
+
+func TestUDP4ListenZero(t *testing.T) {
+ c, err := ListenPacket("udp4", "0.0.0.0:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if a := c.LocalAddr(); isNotIPv4(a) {
+ t.Errorf("address does not contain IPv4: %v", a)
+ }
+}
diff --git a/src/net/ipsock_posix.go b/src/net/ipsock_posix.go
new file mode 100644
index 0000000..b0a00a6
--- /dev/null
+++ b/src/net/ipsock_posix.go
@@ -0,0 +1,232 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+ "net/netip"
+ "runtime"
+ "syscall"
+)
+
+// probe probes IPv4, IPv6 and IPv4-mapped IPv6 communication
+// capabilities which are controlled by the IPV6_V6ONLY socket option
+// and kernel configuration.
+//
+// Should we try to use the IPv4 socket interface if we're only
+// dealing with IPv4 sockets? As long as the host system understands
+// IPv4-mapped IPv6, it's okay to pass IPv4-mapped IPv6 addresses to
+// the IPv6 interface. That simplifies our code and is most
+// general. Unfortunately, we need to run on kernels built without
+// IPv6 support too. So probe the kernel to figure it out.
+func (p *ipStackCapabilities) probe() {
+ s, err := sysSocket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ switch err {
+ case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT:
+ case nil:
+ poll.CloseFunc(s)
+ p.ipv4Enabled = true
+ }
+ var probes = []struct {
+ laddr TCPAddr
+ value int
+ }{
+ // IPv6 communication capability
+ {laddr: TCPAddr{IP: ParseIP("::1")}, value: 1},
+ // IPv4-mapped IPv6 address communication capability
+ {laddr: TCPAddr{IP: IPv4(127, 0, 0, 1)}, value: 0},
+ }
+ switch runtime.GOOS {
+ case "dragonfly", "openbsd":
+ // The latest DragonFly BSD and OpenBSD kernels don't
+ // support IPV6_V6ONLY=0. They always return an error
+ // and we don't need to probe the capability.
+ probes = probes[:1]
+ }
+ for i := range probes {
+ s, err := sysSocket(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ if err != nil {
+ continue
+ }
+ defer poll.CloseFunc(s)
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value)
+ sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6)
+ if err != nil {
+ continue
+ }
+ if err := syscall.Bind(s, sa); err != nil {
+ continue
+ }
+ if i == 0 {
+ p.ipv6Enabled = true
+ } else {
+ p.ipv4MappedIPv6Enabled = true
+ }
+ }
+}
+
+// favoriteAddrFamily returns the appropriate address family for the
+// given network, laddr, raddr and mode.
+//
+// If mode indicates "listen" and laddr is a wildcard, we assume that
+// the user wants to make a passive-open connection with a wildcard
+// address family, both AF_INET and AF_INET6, and a wildcard address
+// like the following:
+//
+// - A listen for a wildcard communication domain, "tcp" or
+// "udp", with a wildcard address: If the platform supports
+// both IPv6 and IPv4-mapped IPv6 communication capabilities,
+// or does not support IPv4, we use a dual stack, AF_INET6 and
+// IPV6_V6ONLY=0, wildcard address listen. The dual stack
+// wildcard address listen may fall back to an IPv6-only,
+// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen.
+// Otherwise we prefer an IPv4-only, AF_INET, wildcard address
+// listen.
+//
+// - A listen for a wildcard communication domain, "tcp" or
+// "udp", with an IPv4 wildcard address: same as above.
+//
+// - A listen for a wildcard communication domain, "tcp" or
+// "udp", with an IPv6 wildcard address: same as above.
+//
+// - A listen for an IPv4 communication domain, "tcp4" or "udp4",
+// with an IPv4 wildcard address: We use an IPv4-only, AF_INET,
+// wildcard address listen.
+//
+// - A listen for an IPv6 communication domain, "tcp6" or "udp6",
+// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6
+// and IPV6_V6ONLY=1, wildcard address listen.
+//
+// Otherwise guess: if the addresses are IPv4 then return AF_INET,
+// or else return AF_INET6. It also returns a boolean value that
+// reports whether the IPV6_V6ONLY socket option should be set.
+//
+// Note that the latest DragonFly BSD and OpenBSD kernels allow
+// neither "net.inet6.ip6.v6only=1" change nor IPPROTO_IPV6 level
+// IPV6_V6ONLY socket option setting.
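+//
+// For example, ("tcp6", nil, nil, "listen") yields (AF_INET6, true),
+// while ("tcp", nil, nil, "listen") on a typical dual-stack kernel
+// yields (AF_INET6, false), producing a wildcard listener that
+// accepts both IPv4 and IPv6 connections.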
+func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) {
+ switch network[len(network)-1] {
+ case '4':
+ return syscall.AF_INET, false
+ case '6':
+ return syscall.AF_INET6, true
+ }
+
+ if mode == "listen" && (laddr == nil || laddr.isWildcard()) {
+ if supportsIPv4map() || !supportsIPv4() {
+ return syscall.AF_INET6, false
+ }
+ if laddr == nil {
+ return syscall.AF_INET, false
+ }
+ return laddr.family(), false
+ }
+
+ if (laddr == nil || laddr.family() == syscall.AF_INET) &&
+ (raddr == nil || raddr.family() == syscall.AF_INET) {
+ return syscall.AF_INET, false
+ }
+ return syscall.AF_INET6, false
+}
+
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
+ if (runtime.GOOS == "aix" || runtime.GOOS == "windows" || runtime.GOOS == "openbsd") && mode == "dial" && raddr.isWildcard() {
+ raddr = raddr.toLocal(net)
+ }
+ family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
+ return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlCtxFn)
+}
+
+func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) {
+ if len(ip) == 0 {
+ ip = IPv4zero
+ }
+ ip4 := ip.To4()
+ if ip4 == nil {
+ return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: ip.String()}
+ }
+ sa := syscall.SockaddrInet4{Port: port}
+ copy(sa.Addr[:], ip4)
+ return sa, nil
+}
+
+func ipToSockaddrInet6(ip IP, port int, zone string) (syscall.SockaddrInet6, error) {
+ // In general, an IP wildcard address, which is either
+ // "0.0.0.0" or "::", means the entire IP addressing
+ // space. For some historical reason, it is used to
+ // specify "any available address" on some operations
+ // of an IP node.
+ //
+ // When the IP node supports IPv4-mapped IPv6 address,
+ // we allow a listener to listen to the wildcard
+ // address of both IP addressing spaces by specifying
+ // IPv6 wildcard address.
+ if len(ip) == 0 || ip.Equal(IPv4zero) {
+ ip = IPv6zero
+ }
+ // We accept any IPv6 address including IPv4-mapped
+ // IPv6 address.
+ ip6 := ip.To16()
+ if ip6 == nil {
+ return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: ip.String()}
+ }
+ sa := syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))}
+ copy(sa.Addr[:], ip6)
+ return sa, nil
+}
+
+func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
+ switch family {
+ case syscall.AF_INET:
+ sa, err := ipToSockaddrInet4(ip, port)
+ if err != nil {
+ return nil, err
+ }
+ return &sa, nil
+ case syscall.AF_INET6:
+ sa, err := ipToSockaddrInet6(ip, port, zone)
+ if err != nil {
+ return nil, err
+ }
+ return &sa, nil
+ }
+ return nil, &AddrError{Err: "invalid address family", Addr: ip.String()}
+}
+
+func addrPortToSockaddrInet4(ap netip.AddrPort) (syscall.SockaddrInet4, error) {
+ // ipToSockaddrInet4 has special handling here for zero length slices.
+ // We do not, because netip has no concept of a generic zero IP address.
+ addr := ap.Addr()
+ if !addr.Is4() {
+ return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: addr.String()}
+ }
+ sa := syscall.SockaddrInet4{
+ Addr: addr.As4(),
+ Port: int(ap.Port()),
+ }
+ return sa, nil
+}
+
+func addrPortToSockaddrInet6(ap netip.AddrPort) (syscall.SockaddrInet6, error) {
+ // ipToSockaddrInet6 has special handling here for zero length slices.
+ // We do not, because netip has no concept of a generic zero IP address.
+ //
+ // addr is allowed to be an IPv4 address, because As16 will convert it
+ // to an IPv4-mapped IPv6 address.
+ // The error message is kept consistent with ipToSockaddrInet6.
+ addr := ap.Addr()
+ if !addr.IsValid() {
+ return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: addr.String()}
+ }
+ sa := syscall.SockaddrInet6{
+ Addr: addr.As16(),
+ Port: int(ap.Port()),
+ ZoneId: uint32(zoneCache.index(addr.Zone())),
+ }
+ return sa, nil
+}
diff --git a/src/net/ipsock_test.go b/src/net/ipsock_test.go
new file mode 100644
index 0000000..aede354
--- /dev/null
+++ b/src/net/ipsock_test.go
@@ -0,0 +1,282 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "reflect"
+ "testing"
+)
+
+var testInetaddr = func(ip IPAddr) Addr { return &TCPAddr{IP: ip.IP, Port: 5682, Zone: ip.Zone} }
+
+var addrListTests = []struct {
+ filter func(IPAddr) bool
+ ips []IPAddr
+ inetaddr func(IPAddr) Addr
+ first Addr
+ primaries addrList
+ fallbacks addrList
+ err error
+}{
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{&TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}},
+ addrList{&TCPAddr{IP: IPv6loopback, Port: 5682}},
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: IPv4(127, 0, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{&TCPAddr{IP: IPv6loopback, Port: 5682}},
+ addrList{&TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}},
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv4(192, 168, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ &TCPAddr{IP: IPv4(192, 168, 0, 1), Port: 5682},
+ },
+ nil,
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: ParseIP("fe80::1"), Zone: "eth0"},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ &TCPAddr{IP: ParseIP("fe80::1"), Port: 5682, Zone: "eth0"},
+ },
+ nil,
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv4(192, 168, 0, 1)},
+ {IP: IPv6loopback},
+ {IP: ParseIP("fe80::1"), Zone: "eth0"},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ &TCPAddr{IP: IPv4(192, 168, 0, 1), Port: 5682},
+ },
+ addrList{
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ &TCPAddr{IP: ParseIP("fe80::1"), Port: 5682, Zone: "eth0"},
+ },
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: ParseIP("fe80::1"), Zone: "eth0"},
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv4(192, 168, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ &TCPAddr{IP: ParseIP("fe80::1"), Port: 5682, Zone: "eth0"},
+ },
+ addrList{
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ &TCPAddr{IP: IPv4(192, 168, 0, 1), Port: 5682},
+ },
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ {IP: IPv4(192, 168, 0, 1)},
+ {IP: ParseIP("fe80::1"), Zone: "eth0"},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ &TCPAddr{IP: IPv4(192, 168, 0, 1), Port: 5682},
+ },
+ addrList{
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ &TCPAddr{IP: ParseIP("fe80::1"), Port: 5682, Zone: "eth0"},
+ },
+ nil,
+ },
+ {
+ nil,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: ParseIP("fe80::1"), Zone: "eth0"},
+ {IP: IPv4(192, 168, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ &TCPAddr{IP: ParseIP("fe80::1"), Port: 5682, Zone: "eth0"},
+ },
+ addrList{
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ &TCPAddr{IP: IPv4(192, 168, 0, 1), Port: 5682},
+ },
+ nil,
+ },
+
+ {
+ ipv4only,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{&TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}},
+ nil,
+ nil,
+ },
+ {
+ ipv4only,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: IPv4(127, 0, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682},
+ addrList{&TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 5682}},
+ nil,
+ nil,
+ },
+
+ {
+ ipv6only,
+ []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ addrList{&TCPAddr{IP: IPv6loopback, Port: 5682}},
+ nil,
+ nil,
+ },
+ {
+ ipv6only,
+ []IPAddr{
+ {IP: IPv6loopback},
+ {IP: IPv4(127, 0, 0, 1)},
+ },
+ testInetaddr,
+ &TCPAddr{IP: IPv6loopback, Port: 5682},
+ addrList{&TCPAddr{IP: IPv6loopback, Port: 5682}},
+ nil,
+ nil,
+ },
+
+ {nil, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+
+ {ipv4only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+ {ipv4only, []IPAddr{{IP: IPv6loopback}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+
+ {ipv6only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+ {ipv6only, []IPAddr{{IP: IPv4(127, 0, 0, 1)}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+}
+
+func TestAddrList(t *testing.T) {
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ for i, tt := range addrListTests {
+ addrs, err := filterAddrList(tt.filter, tt.ips, tt.inetaddr, "ADDR")
+ if !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("#%v: got %v; want %v", i, err, tt.err)
+ }
+ if tt.err != nil {
+ if len(addrs) != 0 {
+ t.Errorf("#%v: got %v; want 0", i, len(addrs))
+ }
+ continue
+ }
+ first := addrs.first(isIPv4)
+ if !reflect.DeepEqual(first, tt.first) {
+ t.Errorf("#%v: got %v; want %v", i, first, tt.first)
+ }
+ primaries, fallbacks := addrs.partition(isIPv4)
+ if !reflect.DeepEqual(primaries, tt.primaries) {
+ t.Errorf("#%v: got %v; want %v", i, primaries, tt.primaries)
+ }
+ if !reflect.DeepEqual(fallbacks, tt.fallbacks) {
+ t.Errorf("#%v: got %v; want %v", i, fallbacks, tt.fallbacks)
+ }
+ expectedLen := len(primaries) + len(fallbacks)
+ if len(addrs) != expectedLen {
+ t.Errorf("#%v: got %v; want %v", i, len(addrs), expectedLen)
+ }
+ }
+}
+
+func TestAddrListPartition(t *testing.T) {
+ addrs := addrList{
+ &IPAddr{IP: ParseIP("fe80::"), Zone: "eth0"},
+ &IPAddr{IP: ParseIP("fe80::1"), Zone: "eth0"},
+ &IPAddr{IP: ParseIP("fe80::2"), Zone: "eth0"},
+ }
+ cases := []struct {
+ lastByte byte
+ primaries addrList
+ fallbacks addrList
+ }{
+ {0, addrList{addrs[0]}, addrList{addrs[1], addrs[2]}},
+ {1, addrList{addrs[0], addrs[2]}, addrList{addrs[1]}},
+ {2, addrList{addrs[0], addrs[1]}, addrList{addrs[2]}},
+ {3, addrList{addrs[0], addrs[1], addrs[2]}, nil},
+ }
+ for i, tt := range cases {
+ // Inverting the function's output should not affect the outcome.
+ for _, invert := range []bool{false, true} {
+ primaries, fallbacks := addrs.partition(func(a Addr) bool {
+ ip := a.(*IPAddr).IP
+ return (ip[len(ip)-1] == tt.lastByte) != invert
+ })
+ if !reflect.DeepEqual(primaries, tt.primaries) {
+ t.Errorf("#%v: got %v; want %v", i, primaries, tt.primaries)
+ }
+ if !reflect.DeepEqual(fallbacks, tt.fallbacks) {
+ t.Errorf("#%v: got %v; want %v", i, fallbacks, tt.fallbacks)
+ }
+ }
+ }
+}
diff --git a/src/net/listen_test.go b/src/net/listen_test.go
new file mode 100644
index 0000000..f0a8861
--- /dev/null
+++ b/src/net/listen_test.go
@@ -0,0 +1,750 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1
+
+package net
+
+import (
+ "fmt"
+ "internal/testenv"
+ "os"
+ "runtime"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func (ln *TCPListener) port() string {
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ return ""
+ }
+ return port
+}
+
+func (c *UDPConn) port() string {
+ _, port, err := SplitHostPort(c.LocalAddr().String())
+ if err != nil {
+ return ""
+ }
+ return port
+}
+
+var tcpListenerTests = []struct {
+ network string
+ address string
+}{
+ {"tcp", ""},
+ {"tcp", "0.0.0.0"},
+ {"tcp", "::ffff:0.0.0.0"},
+ {"tcp", "::"},
+
+ {"tcp", "127.0.0.1"},
+ {"tcp", "::ffff:127.0.0.1"},
+ {"tcp", "::1"},
+
+ {"tcp4", ""},
+ {"tcp4", "0.0.0.0"},
+ {"tcp4", "::ffff:0.0.0.0"},
+
+ {"tcp4", "127.0.0.1"},
+ {"tcp4", "::ffff:127.0.0.1"},
+
+ {"tcp6", ""},
+ {"tcp6", "::"},
+
+ {"tcp6", "::1"},
+}
+
+// TestTCPListener tests both single and double listens on a test
+// listener with the same address family, listening address, and
+// port.
+func TestTCPListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range tcpListenerTests {
+ if !testableListenArgs(tt.network, JoinHostPort(tt.address, "0"), "") {
+ t.Logf("skipping %s test", tt.network+" "+tt.address)
+ continue
+ }
+
+ ln1, err := Listen(tt.network, JoinHostPort(tt.address, "0"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := checkFirstListener(tt.network, ln1); err != nil {
+ ln1.Close()
+ t.Fatal(err)
+ }
+ ln2, err := Listen(tt.network, JoinHostPort(tt.address, ln1.(*TCPListener).port()))
+ if err == nil {
+ ln2.Close()
+ }
+ if err := checkSecondListener(tt.network, tt.address, err); err != nil {
+ ln1.Close()
+ t.Fatal(err)
+ }
+ ln1.Close()
+ }
+}
+
+var udpListenerTests = []struct {
+ network string
+ address string
+}{
+ {"udp", ""},
+ {"udp", "0.0.0.0"},
+ {"udp", "::ffff:0.0.0.0"},
+ {"udp", "::"},
+
+ {"udp", "127.0.0.1"},
+ {"udp", "::ffff:127.0.0.1"},
+ {"udp", "::1"},
+
+ {"udp4", ""},
+ {"udp4", "0.0.0.0"},
+ {"udp4", "::ffff:0.0.0.0"},
+
+ {"udp4", "127.0.0.1"},
+ {"udp4", "::ffff:127.0.0.1"},
+
+ {"udp6", ""},
+ {"udp6", "::"},
+
+ {"udp6", "::1"},
+}
+
+// TestUDPListener tests both single and double listens on a test
+// listener with the same address family, listening address, and
+// port.
+func TestUDPListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range udpListenerTests {
+ if !testableListenArgs(tt.network, JoinHostPort(tt.address, "0"), "") {
+ t.Logf("skipping %s test", tt.network+" "+tt.address)
+ continue
+ }
+
+ c1, err := ListenPacket(tt.network, JoinHostPort(tt.address, "0"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := checkFirstListener(tt.network, c1); err != nil {
+ c1.Close()
+ t.Fatal(err)
+ }
+ c2, err := ListenPacket(tt.network, JoinHostPort(tt.address, c1.(*UDPConn).port()))
+ if err == nil {
+ c2.Close()
+ }
+ if err := checkSecondListener(tt.network, tt.address, err); err != nil {
+ c1.Close()
+ t.Fatal(err)
+ }
+ c1.Close()
+ }
+}
+
+var dualStackTCPListenerTests = []struct {
+ network1, address1 string // first listener
+ network2, address2 string // second listener
+ xerr error // expected error value, nil or other
+}{
+ // Test cases and expected results for attempting a second listen on the same port
+ // 1st listen 2nd listen darwin freebsd linux openbsd
+ // ------------------------------------------------------------------------------------
+ // "tcp" "" "tcp" "" - - - -
+ // "tcp" "" "tcp" "0.0.0.0" - - - -
+ // "tcp" "0.0.0.0" "tcp" "" - - - -
+ // ------------------------------------------------------------------------------------
+ // "tcp" "" "tcp" "[::]" - - - ok
+ // "tcp" "[::]" "tcp" "" - - - ok
+ // "tcp" "0.0.0.0" "tcp" "[::]" - - - ok
+ // "tcp" "[::]" "tcp" "0.0.0.0" - - - ok
+ // "tcp" "[::ffff:0.0.0.0]" "tcp" "[::]" - - - ok
+ // "tcp" "[::]" "tcp" "[::ffff:0.0.0.0]" - - - ok
+ // ------------------------------------------------------------------------------------
+ // "tcp4" "" "tcp6" "" ok ok ok ok
+ // "tcp6" "" "tcp4" "" ok ok ok ok
+ // "tcp4" "0.0.0.0" "tcp6" "[::]" ok ok ok ok
+ // "tcp6" "[::]" "tcp4" "0.0.0.0" ok ok ok ok
+ // ------------------------------------------------------------------------------------
+ // "tcp" "127.0.0.1" "tcp" "[::1]" ok ok ok ok
+ // "tcp" "[::1]" "tcp" "127.0.0.1" ok ok ok ok
+ // "tcp4" "127.0.0.1" "tcp6" "[::1]" ok ok ok ok
+ // "tcp6" "[::1]" "tcp4" "127.0.0.1" ok ok ok ok
+ //
+ // Platform default configurations:
+ // darwin, kernel version 11.3.0
+ // net.inet6.ip6.v6only=0 (overridable by sysctl or IPV6_V6ONLY option)
+ // freebsd, kernel version 8.2
+ // net.inet6.ip6.v6only=1 (overridable by sysctl or IPV6_V6ONLY option)
+ // linux, kernel version 3.0.0
+ // net.ipv6.bindv6only=0 (overridable by sysctl or IPV6_V6ONLY option)
+ // openbsd, kernel version 5.0
+ // net.inet6.ip6.v6only=1 (overriding is prohibited)
+
+ {"tcp", "", "tcp", "", syscall.EADDRINUSE},
+ {"tcp", "", "tcp", "0.0.0.0", syscall.EADDRINUSE},
+ {"tcp", "0.0.0.0", "tcp", "", syscall.EADDRINUSE},
+
+ {"tcp", "", "tcp", "::", syscall.EADDRINUSE},
+ {"tcp", "::", "tcp", "", syscall.EADDRINUSE},
+ {"tcp", "0.0.0.0", "tcp", "::", syscall.EADDRINUSE},
+ {"tcp", "::", "tcp", "0.0.0.0", syscall.EADDRINUSE},
+ {"tcp", "::ffff:0.0.0.0", "tcp", "::", syscall.EADDRINUSE},
+ {"tcp", "::", "tcp", "::ffff:0.0.0.0", syscall.EADDRINUSE},
+
+ {"tcp4", "", "tcp6", "", nil},
+ {"tcp6", "", "tcp4", "", nil},
+ {"tcp4", "0.0.0.0", "tcp6", "::", nil},
+ {"tcp6", "::", "tcp4", "0.0.0.0", nil},
+
+ {"tcp", "127.0.0.1", "tcp", "::1", nil},
+ {"tcp", "::1", "tcp", "127.0.0.1", nil},
+ {"tcp4", "127.0.0.1", "tcp6", "::1", nil},
+ {"tcp6", "::1", "tcp4", "127.0.0.1", nil},
+}
+
+// TestDualStackTCPListener tests both single and double listens
+// on a test listener with various address families, different
+// listening addresses, and the same port.
+//
+// On DragonFly BSD, we expect the kernel version of the node under
+// test to be greater than or equal to 4.4.
+func TestDualStackTCPListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ for _, tt := range dualStackTCPListenerTests {
+ if !testableListenArgs(tt.network1, JoinHostPort(tt.address1, "0"), "") {
+ t.Logf("skipping %s test", tt.network1+" "+tt.address1)
+ continue
+ }
+
+ if !supportsIPv4map() && differentWildcardAddr(tt.address1, tt.address2) {
+ tt.xerr = nil
+ }
+ var firstErr, secondErr error
+ for i := 0; i < 5; i++ {
+ lns, err := newDualStackListener()
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := lns[0].port()
+ for _, ln := range lns {
+ ln.Close()
+ }
+ var ln1 Listener
+ ln1, firstErr = Listen(tt.network1, JoinHostPort(tt.address1, port))
+ if firstErr != nil {
+ continue
+ }
+ if err := checkFirstListener(tt.network1, ln1); err != nil {
+ ln1.Close()
+ t.Fatal(err)
+ }
+ ln2, err := Listen(tt.network2, JoinHostPort(tt.address2, ln1.(*TCPListener).port()))
+ if err == nil {
+ ln2.Close()
+ }
+ if secondErr = checkDualStackSecondListener(tt.network2, tt.address2, err, tt.xerr); secondErr != nil {
+ ln1.Close()
+ continue
+ }
+ ln1.Close()
+ break
+ }
+ if firstErr != nil {
+ t.Error(firstErr)
+ }
+ if secondErr != nil {
+ t.Error(secondErr)
+ }
+ }
+}
+
+var dualStackUDPListenerTests = []struct {
+ network1, address1 string // first listener
+ network2, address2 string // second listener
+ xerr error // expected error value, nil or other
+}{
+ {"udp", "", "udp", "", syscall.EADDRINUSE},
+ {"udp", "", "udp", "0.0.0.0", syscall.EADDRINUSE},
+ {"udp", "0.0.0.0", "udp", "", syscall.EADDRINUSE},
+
+ {"udp", "", "udp", "::", syscall.EADDRINUSE},
+ {"udp", "::", "udp", "", syscall.EADDRINUSE},
+ {"udp", "0.0.0.0", "udp", "::", syscall.EADDRINUSE},
+ {"udp", "::", "udp", "0.0.0.0", syscall.EADDRINUSE},
+ {"udp", "::ffff:0.0.0.0", "udp", "::", syscall.EADDRINUSE},
+ {"udp", "::", "udp", "::ffff:0.0.0.0", syscall.EADDRINUSE},
+
+ {"udp4", "", "udp6", "", nil},
+ {"udp6", "", "udp4", "", nil},
+ {"udp4", "0.0.0.0", "udp6", "::", nil},
+ {"udp6", "::", "udp4", "0.0.0.0", nil},
+
+ {"udp", "127.0.0.1", "udp", "::1", nil},
+ {"udp", "::1", "udp", "127.0.0.1", nil},
+ {"udp4", "127.0.0.1", "udp6", "::1", nil},
+ {"udp6", "::1", "udp4", "127.0.0.1", nil},
+}
+
+// TestDualStackUDPListener tests both single and double listens
+// on a test listener with various address families, different
+// listening addresses, and the same port.
+//
+// On DragonFly BSD, we expect the kernel version of the node under
+// test to be greater than or equal to 4.4.
+func TestDualStackUDPListener(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv4() || !supportsIPv6() {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ for _, tt := range dualStackUDPListenerTests {
+ if !testableListenArgs(tt.network1, JoinHostPort(tt.address1, "0"), "") {
+ t.Logf("skipping %s test", tt.network1+" "+tt.address1)
+ continue
+ }
+
+ if !supportsIPv4map() && differentWildcardAddr(tt.address1, tt.address2) {
+ tt.xerr = nil
+ }
+ var firstErr, secondErr error
+ for i := 0; i < 5; i++ {
+ cs, err := newDualStackPacketListener()
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := cs[0].port()
+ for _, c := range cs {
+ c.Close()
+ }
+ var c1 PacketConn
+ c1, firstErr = ListenPacket(tt.network1, JoinHostPort(tt.address1, port))
+ if firstErr != nil {
+ continue
+ }
+ if err := checkFirstListener(tt.network1, c1); err != nil {
+ c1.Close()
+ t.Fatal(err)
+ }
+ c2, err := ListenPacket(tt.network2, JoinHostPort(tt.address2, c1.(*UDPConn).port()))
+ if err == nil {
+ c2.Close()
+ }
+ if secondErr = checkDualStackSecondListener(tt.network2, tt.address2, err, tt.xerr); secondErr != nil {
+ c1.Close()
+ continue
+ }
+ c1.Close()
+ break
+ }
+ if firstErr != nil {
+ t.Error(firstErr)
+ }
+ if secondErr != nil {
+ t.Error(secondErr)
+ }
+ }
+}
+
+func differentWildcardAddr(i, j string) bool {
+ if (i == "" || i == "0.0.0.0" || i == "::ffff:0.0.0.0") && (j == "" || j == "0.0.0.0" || j == "::ffff:0.0.0.0") {
+ return false
+ }
+ if i == "[::]" && j == "[::]" {
+ return false
+ }
+ return true
+}
+
+func checkFirstListener(network string, ln any) error {
+ switch network {
+ case "tcp":
+ fd := ln.(*TCPListener).fd
+ if err := checkDualStackAddrFamily(fd); err != nil {
+ return err
+ }
+ case "tcp4":
+ fd := ln.(*TCPListener).fd
+ if fd.family != syscall.AF_INET {
+ return fmt.Errorf("%v got %v; want %v", fd.laddr, fd.family, syscall.AF_INET)
+ }
+ case "tcp6":
+ fd := ln.(*TCPListener).fd
+ if fd.family != syscall.AF_INET6 {
+ return fmt.Errorf("%v got %v; want %v", fd.laddr, fd.family, syscall.AF_INET6)
+ }
+ case "udp":
+ fd := ln.(*UDPConn).fd
+ if err := checkDualStackAddrFamily(fd); err != nil {
+ return err
+ }
+ case "udp4":
+ fd := ln.(*UDPConn).fd
+ if fd.family != syscall.AF_INET {
+ return fmt.Errorf("%v got %v; want %v", fd.laddr, fd.family, syscall.AF_INET)
+ }
+ case "udp6":
+ fd := ln.(*UDPConn).fd
+ if fd.family != syscall.AF_INET6 {
+ return fmt.Errorf("%v got %v; want %v", fd.laddr, fd.family, syscall.AF_INET6)
+ }
+ default:
+ return UnknownNetworkError(network)
+ }
+ return nil
+}
+
+func checkSecondListener(network, address string, err error) error {
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ if err == nil {
+ return fmt.Errorf("%s should fail", network+" "+address)
+ }
+ case "udp", "udp4", "udp6":
+ if err == nil {
+ return fmt.Errorf("%s should fail", network+" "+address)
+ }
+ default:
+ return UnknownNetworkError(network)
+ }
+ return nil
+}
+
+func checkDualStackSecondListener(network, address string, err, xerr error) error {
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ if xerr == nil && err != nil || xerr != nil && err == nil {
+ return fmt.Errorf("%s got %v; want %v", network+" "+address, err, xerr)
+ }
+ case "udp", "udp4", "udp6":
+ if xerr == nil && err != nil || xerr != nil && err == nil {
+ return fmt.Errorf("%s got %v; want %v", network+" "+address, err, xerr)
+ }
+ default:
+ return UnknownNetworkError(network)
+ }
+ return nil
+}
+
+func checkDualStackAddrFamily(fd *netFD) error {
+ switch a := fd.laddr.(type) {
+ case *TCPAddr:
+ // If a node under test supports both IPv6 capability
+ // and IPv6 IPv4-mapping capability, we can assume
+ // that the node listens on a wildcard address with an
+ // AF_INET6 socket.
+ if supportsIPv4map() && fd.laddr.(*TCPAddr).isWildcard() {
+ if fd.family != syscall.AF_INET6 {
+ return fmt.Errorf("Listen(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, syscall.AF_INET6)
+ }
+ } else {
+ if fd.family != a.family() {
+ return fmt.Errorf("Listen(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, a.family())
+ }
+ }
+ case *UDPAddr:
+ // If a node under test supports both IPv6 capability
+ // and IPv6 IPv4-mapping capability, we can assume
+ // that the node listens on a wildcard address with an
+ // AF_INET6 socket.
+ if supportsIPv4map() && fd.laddr.(*UDPAddr).isWildcard() {
+ if fd.family != syscall.AF_INET6 {
+ return fmt.Errorf("ListenPacket(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, syscall.AF_INET6)
+ }
+ } else {
+ if fd.family != a.family() {
+ return fmt.Errorf("ListenPacket(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, a.family())
+ }
+ }
+ default:
+ return fmt.Errorf("unexpected protocol address type: %T", a)
+ }
+ return nil
+}
+
+func TestWildWildcardListener(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ t.Fatalf("panicked: %v", p)
+ }
+ }()
+
+ if ln, err := Listen("tcp", ""); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenPacket("udp", ""); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenTCP("tcp", nil); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenUDP("udp", nil); err == nil {
+ ln.Close()
+ }
+ if ln, err := ListenIP("ip:icmp", nil); err == nil {
+ ln.Close()
+ }
+}
+
+var ipv4MulticastListenerTests = []struct {
+ net string
+ gaddr *UDPAddr // see RFC 4727
+}{
+ {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}},
+
+ {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}},
+}
+
+// TestIPv4MulticastListener tests both single and double listens on a
+// test listener with the same address family, group address, and
+// port.
+func TestIPv4MulticastListener(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ switch runtime.GOOS {
+ case "android", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv4() {
+ t.Skip("IPv4 is not supported")
+ }
+
+ closer := func(cs []*UDPConn) {
+ for _, c := range cs {
+ if c != nil {
+ c.Close()
+ }
+ }
+ }
+
+ for _, ifi := range []*Interface{loopbackInterface(), nil} {
+ // Note that letting the system choose the multicast
+ // interface is not recommended, because doing so
+ // usually relies on routing state to find an
+ // appropriate next hop with both network- and
+ // link-layer adjacencies.
+ if ifi == nil || !*testIPv4 {
+ continue
+ }
+ for _, tt := range ipv4MulticastListenerTests {
+ var err error
+ cs := make([]*UDPConn, 2)
+ if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil {
+ t.Fatal(err)
+ }
+ if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ closer(cs)
+ }
+ }
+}
+
+var ipv6MulticastListenerTests = []struct {
+ net string
+ gaddr *UDPAddr // see RFC 4727
+}{
+ {"udp", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}},
+ {"udp", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}},
+ {"udp", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}},
+ {"udp", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}},
+ {"udp", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}},
+ {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}},
+
+ {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}},
+ {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}},
+ {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}},
+ {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}},
+ {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}},
+ {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}},
+}
+
+// TestIPv6MulticastListener tests both single and double listens on a
+// test listener with the same address family, group address, and
+// port.
+func TestIPv6MulticastListener(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !supportsIPv6() {
+ t.Skip("IPv6 is not supported")
+ }
+ if os.Getuid() != 0 {
+ t.Skip("must be root")
+ }
+
+ closer := func(cs []*UDPConn) {
+ for _, c := range cs {
+ if c != nil {
+ c.Close()
+ }
+ }
+ }
+
+ for _, ifi := range []*Interface{loopbackInterface(), nil} {
+ // Note that letting the system choose the multicast
+ // interface is not recommended, because doing so
+ // usually relies on routing state to find an
+ // appropriate next hop with both network- and
+ // link-layer adjacencies.
+ if ifi == nil && !*testIPv6 {
+ continue
+ }
+ for _, tt := range ipv6MulticastListenerTests {
+ var err error
+ cs := make([]*UDPConn, 2)
+ if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil {
+ t.Fatal(err)
+ }
+ if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil {
+ closer(cs)
+ t.Fatal(err)
+ }
+ closer(cs)
+ }
+ }
+}
+
+func checkMulticastListener(c *UDPConn, ip IP) error {
+ if ok, err := multicastRIBContains(ip); err != nil {
+ return err
+ } else if !ok {
+ return fmt.Errorf("%s not found in multicast rib", ip.String())
+ }
+ la := c.LocalAddr()
+ if la, ok := la.(*UDPAddr); !ok || la.Port == 0 {
+ return fmt.Errorf("got %v; want a proper address with non-zero port number", la)
+ }
+ return nil
+}
+
+func multicastRIBContains(ip IP) (bool, error) {
+ switch runtime.GOOS {
+ case "aix", "dragonfly", "netbsd", "openbsd", "plan9", "solaris", "illumos", "windows":
+ return true, nil // not implemented yet
+ case "linux":
+ if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" {
+ return true, nil // not implemented yet
+ }
+ }
+ ift, err := Interfaces()
+ if err != nil {
+ return false, err
+ }
+ for _, ifi := range ift {
+ ifmat, err := ifi.MulticastAddrs()
+ if err != nil {
+ return false, err
+ }
+ for _, ifma := range ifmat {
+ if ifma.(*IPAddr).IP.Equal(ip) {
+ return true, nil
+ }
+ }
+ }
+ return false, nil
+}
+
+// Issue 21856.
+func TestClosingListener(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ addr := ln.Addr()
+
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ c.Close()
+ }
+ }()
+
+ // Let the goroutine start. We don't sleep long: if the
+ // goroutine doesn't start, the test will pass without really
+ // testing anything, which is OK.
+ time.Sleep(time.Millisecond)
+
+ ln.Close()
+
+ ln2, err := Listen("tcp", addr.String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln2.Close()
+}
+
+func TestListenConfigControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamListen", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln := newLocalListener(t, network, &ListenConfig{Control: controlOnConnSetup})
+ ln.Close()
+ }
+ })
+ t.Run("PacketListen", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c := newLocalPacketListener(t, network, &ListenConfig{Control: controlOnConnSetup})
+ c.Close()
+ }
+ })
+}
diff --git a/src/net/lookup.go b/src/net/lookup.go
new file mode 100644
index 0000000..a7133b5
--- /dev/null
+++ b/src/net/lookup.go
@@ -0,0 +1,908 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "internal/nettrace"
+ "internal/singleflight"
+ "net/netip"
+ "sync"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// protocols contains minimal mappings between internet protocol
+// names and numbers for platforms that don't have a complete list of
+// protocol numbers.
+//
+// See https://www.iana.org/assignments/protocol-numbers
+//
+// On Unix, this map is augmented by readProtocols via lookupProtocol.
+var protocols = map[string]int{
+ "icmp": 1,
+ "igmp": 2,
+ "tcp": 6,
+ "udp": 17,
+ "ipv6-icmp": 58,
+}
+
+// services contains minimal mappings between services names and port
+// numbers for platforms that don't have a complete list of port numbers.
+//
+// See https://www.iana.org/assignments/service-names-port-numbers
+//
+// On Unix, this map is augmented by readServices via goLookupPort.
+var services = map[string]map[string]int{
+ "udp": {
+ "domain": 53,
+ },
+ "tcp": {
+ "ftp": 21,
+ "ftps": 990,
+ "gopher": 70, // ʕ◔ϖ◔ʔ
+ "http": 80,
+ "https": 443,
+ "imap2": 143,
+ "imap3": 220,
+ "imaps": 993,
+ "pop3": 110,
+ "pop3s": 995,
+ "smtp": 25,
+ "ssh": 22,
+ "telnet": 23,
+ },
+}
+
+// dnsWaitGroup can be used by tests to wait for all DNS goroutines to
+// complete. This avoids races on the test hooks.
+var dnsWaitGroup sync.WaitGroup
+
+const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
+
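+// lookupProtocolMap looks up name in the baked-in protocols map,
+// ignoring ASCII case. For example, "ICMP" yields 1 and "udp"
+// yields 17; unknown names return an AddrError.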
+func lookupProtocolMap(name string) (int, error) {
+ var lowerProtocol [maxProtoLength]byte
+ n := copy(lowerProtocol[:], name)
+ lowerASCIIBytes(lowerProtocol[:n])
+ proto, found := protocols[string(lowerProtocol[:n])]
+ if !found || n != len(name) {
+ return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name}
+ }
+ return proto, nil
+}
+
+// maxPortBufSize is the longest reasonable name of a service
+// (non-numeric port).
+// Currently the longest known IANA-unregistered name is
+// "mobility-header", so we use that length, plus some slop in case
+// something longer is added in the future.
+const maxPortBufSize = len("mobility-header") + 10
+
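+// lookupPortMap looks up service in the baked-in services map for
+// network, ignoring ASCII case. For example, ("tcp4", "HTTPS")
+// yields 443; unknown services return an AddrError.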
+func lookupPortMap(network, service string) (port int, error error) {
+ switch network {
+ case "tcp4", "tcp6":
+ network = "tcp"
+ case "udp4", "udp6":
+ network = "udp"
+ }
+
+ if m, ok := services[network]; ok {
+ var lowerService [maxPortBufSize]byte
+ n := copy(lowerService[:], service)
+ lowerASCIIBytes(lowerService[:n])
+ if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
+ return port, nil
+ }
+ }
+ return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
+}
+
+// ipVersion returns the provided network's IP version: '4', '6' or 0
+// if network does not end in a '4' or '6' byte.
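+//
+// For example, ipVersion("tcp6") is '6', ipVersion("udp4") is '4',
+// and ipVersion("ip") is 0.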
+func ipVersion(network string) byte {
+ if network == "" {
+ return 0
+ }
+ n := network[len(network)-1]
+ if n != '4' && n != '6' {
+ n = 0
+ }
+ return n
+}
+
+// DefaultResolver is the resolver used by the package-level Lookup
+// functions and by Dialers without a specified Resolver.
+var DefaultResolver = &Resolver{}
+
+// A Resolver looks up names and numbers.
+//
+// A nil *Resolver is equivalent to a zero Resolver.
+type Resolver struct {
+ // PreferGo controls whether Go's built-in DNS resolver is preferred
+ // on platforms where it's available. It is equivalent to setting
+ // GODEBUG=netdns=go, but scoped to just this resolver.
+ PreferGo bool
+
+ // StrictErrors controls the behavior of temporary errors
+ // (including timeout, socket errors, and SERVFAIL) when using
+ // Go's built-in resolver. For a query composed of multiple
+ // sub-queries (such as an A+AAAA address lookup, or walking the
+ // DNS search list), this option causes such errors to abort the
+ // whole query instead of returning a partial result. This is
+ // not enabled by default because it may affect compatibility
+ // with resolvers that process AAAA queries incorrectly.
+ StrictErrors bool
+
+ // Dial optionally specifies an alternate dialer for use by
+ // Go's built-in DNS resolver to make TCP and UDP connections
+ // to DNS services. The host in the address parameter will
+ // always be a literal IP address and not a host name, and the
+ // port in the address parameter will be a literal port number
+ // and not a service name.
+ // If the Conn returned is also a PacketConn, sent and received DNS
+ // messages must adhere to RFC 1035 section 4.2.1, "UDP usage".
+ // Otherwise, DNS messages transmitted over Conn must adhere
+ // to RFC 7766 section 5, "Transport Protocol Selection".
+ // If nil, the default dialer is used.
+ Dial func(ctx context.Context, network, address string) (Conn, error)
+
+ // lookupGroup merges LookupIPAddr calls together for lookups for the same
+ // host. The lookupGroup key is the LookupIPAddr.host argument.
+ // The return values are ([]IPAddr, error).
+ lookupGroup singleflight.Group
+
+ // TODO(bradfitz): optional interface impl override hook
+ // TODO(bradfitz): Timeout time.Duration?
+}
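
The Dial hook is the usual way to point the built-in resolver at a specific DNS server, such as an in-memory cache or a pinned upstream. A minimal usage sketch; 8.8.8.8:53 and example.com are placeholders, not anything the package mandates:

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func main() {
	r := &net.Resolver{
		PreferGo: true, // use the built-in Go resolver
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			// Ignore the configured server and pin all DNS traffic to one
			// address; 8.8.8.8:53 is only an example.
			d := net.Dialer{Timeout: 2 * time.Second}
			return d.DialContext(ctx, network, "8.8.8.8:53")
		},
	}
	addrs, err := r.LookupHost(context.Background(), "example.com")
	fmt.Println(addrs, err)
}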
+
+func (r *Resolver) preferGo() bool { return r != nil && r.PreferGo }
+func (r *Resolver) strictErrors() bool { return r != nil && r.StrictErrors }
+
+func (r *Resolver) getLookupGroup() *singleflight.Group {
+ if r == nil {
+ return &DefaultResolver.lookupGroup
+ }
+ return &r.lookupGroup
+}
+
+// LookupHost looks up the given host using the local resolver.
+// It returns a slice of that host's addresses.
+//
+// LookupHost uses context.Background internally; to specify the context, use
+// Resolver.LookupHost.
+func LookupHost(host string) (addrs []string, err error) {
+ return DefaultResolver.LookupHost(context.Background(), host)
+}
+
+// LookupHost looks up the given host using the local resolver.
+// It returns a slice of that host's addresses.
+func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
+ // Make sure that no matter what we do later, host=="" is rejected.
+ if host == "" {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ if _, err := netip.ParseAddr(host); err == nil {
+ return []string{host}, nil
+ }
+ return r.lookupHost(ctx, host)
+}
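
A short usage sketch of the context-aware method, showing a per-lookup timeout and the IsNotFound check; the host name is a placeholder:

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	addrs, err := net.DefaultResolver.LookupHost(ctx, "example.com")
	if err != nil {
		var dnsErr *net.DNSError
		if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
			fmt.Println("no such host")
			return
		}
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(addrs) // literal IP address strings
}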
+
+// LookupIP looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func LookupIP(host string) ([]IP, error) {
+ addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
+ if err != nil {
+ return nil, err
+ }
+ ips := make([]IP, len(addrs))
+ for i, ia := range addrs {
+ ips[i] = ia.IP
+ }
+ return ips, nil
+}
+
+// LookupIPAddr looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
+ return r.lookupIPAddr(ctx, "ip", host)
+}
+
+// LookupIP looks up host for the given network using the local resolver.
+// It returns a slice of that host's IP addresses of the type specified by
+// network.
+// The network must be one of "ip", "ip4" or "ip6".
+func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, error) {
+ afnet, _, err := parseNetwork(ctx, network, false)
+ if err != nil {
+ return nil, err
+ }
+ switch afnet {
+ case "ip", "ip4", "ip6":
+ default:
+ return nil, UnknownNetworkError(network)
+ }
+
+ if host == "" {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ addrs, err := r.internetAddrList(ctx, afnet, host)
+ if err != nil {
+ return nil, err
+ }
+
+ ips := make([]IP, 0, len(addrs))
+ for _, addr := range addrs {
+ ips = append(ips, addr.(*IPAddr).IP)
+ }
+ return ips, nil
+}
+
+// LookupNetIP looks up host using the local resolver.
+// It returns a slice of that host's IP addresses of the type specified by
+// network.
+// The network must be one of "ip", "ip4" or "ip6".
+func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
+ // TODO(bradfitz): make this efficient, making the internal net package
+ // type throughout be netip.Addr and only converting to the net.IP slice
+ // version at the edge. But for now (2021-10-20), this is a wrapper around
+ // the old way.
+ ips, err := r.LookupIP(ctx, network, host)
+ if err != nil {
+ return nil, err
+ }
+ ret := make([]netip.Addr, 0, len(ips))
+ for _, ip := range ips {
+ if a, ok := netip.AddrFromSlice(ip); ok {
+ ret = append(ret, a)
+ }
+ }
+ return ret, nil
+}
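
For code that prefers netip.Addr over the slice-based IP type, the same lookup reads as follows (example.com is a placeholder):

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	// "ip4" restricts the answer to IPv4; "ip" or "ip6" work the same way.
	addrs, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip4", "example.com")
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, a := range addrs {
		fmt.Println(a) // netip.Addr values, comparable and usable as map keys
	}
}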
+
+// onlyValuesCtx is a context that uses an underlying context
+// for value lookup if the underlying context hasn't yet expired.
+type onlyValuesCtx struct {
+ context.Context
+ lookupValues context.Context
+}
+
+var _ context.Context = (*onlyValuesCtx)(nil)
+
+// Value performs a lookup if the original context hasn't expired.
+func (ovc *onlyValuesCtx) Value(key any) any {
+ select {
+ case <-ovc.lookupValues.Done():
+ return nil
+ default:
+ return ovc.lookupValues.Value(key)
+ }
+}
+
+// withUnexpiredValuesPreserved returns a context.Context that only uses lookupCtx
+// for its values, otherwise it is never canceled and has no deadline.
+// If the lookup context expires, any looked up values will return nil.
+// See Issue 28600.
+func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
+ return &onlyValuesCtx{Context: context.Background(), lookupValues: lookupCtx}
+}
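
A standalone sketch of the values-only trick implemented by onlyValuesCtx and withUnexpiredValuesPreserved: the derived context keeps the parent's values (for trace keys and the like) but drops its cancellation, so a lookup shared between callers is not torn down when one caller's context expires. The type and key names below are made up for illustration:

package main

import (
	"context"
	"fmt"
)

type traceKey struct{}

type valuesOnly struct {
	context.Context                 // never canceled, no deadline
	values          context.Context // original context, used only for Value
}

func (c valuesOnly) Value(key any) any {
	select {
	case <-c.values.Done():
		return nil // parent expired; stop exposing its values
	default:
		return c.values.Value(key)
	}
}

func main() {
	parent, cancel := context.WithCancel(context.WithValue(context.Background(), traceKey{}, "trace-42"))
	shared := valuesOnly{Context: context.Background(), values: parent}

	fmt.Println(shared.Value(traceKey{})) // trace-42
	cancel()
	fmt.Println(shared.Value(traceKey{})) // <nil>: value lookups stop after expiry
	fmt.Println(shared.Err())             // <nil>: the shared context itself is not canceled
}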
+
+// lookupIPAddr looks up host using the local resolver and particular network.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
+ // Make sure that no matter what we do later, host=="" is rejected.
+ if host == "" {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ if ip, err := netip.ParseAddr(host); err == nil {
+ return []IPAddr{{IP: IP(ip.AsSlice()).To16(), Zone: ip.Zone()}}, nil
+ }
+ trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
+ if trace != nil && trace.DNSStart != nil {
+ trace.DNSStart(host)
+ }
+ // The underlying resolver func is lookupIP by default but it
+ // can be overridden by tests. This is needed by net/http, so it
+ // uses a context key instead of unexported variables.
+ resolverFunc := r.lookupIP
+ if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string, string) ([]IPAddr, error)); alt != nil {
+ resolverFunc = alt
+ }
+
+ // We don't want a cancellation of ctx to affect the
+ // lookupGroup operation. Otherwise if our context gets
+ // canceled it might cause an error to be returned to a lookup
+ // using a completely different context. However we need to preserve
+ // only the values in context. See Issue 28600.
+ lookupGroupCtx, lookupGroupCancel := context.WithCancel(withUnexpiredValuesPreserved(ctx))
+
+ lookupKey := network + "\000" + host
+ dnsWaitGroup.Add(1)
+ ch := r.getLookupGroup().DoChan(lookupKey, func() (any, error) {
+ return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host)
+ })
+
+ dnsWaitGroupDone := func(ch <-chan singleflight.Result, cancelFn context.CancelFunc) {
+ <-ch
+ dnsWaitGroup.Done()
+ cancelFn()
+ }
+ select {
+ case <-ctx.Done():
+ // Our context was canceled. If we are the only
+ // goroutine looking up this key, then drop the key
+ // from the lookupGroup and cancel the lookup.
+ // If there are other goroutines looking up this key,
+ // let the lookup continue uncanceled, and let later
+ // lookups with the same key share the result.
+ // See issues 8602, 20703, 22724.
+ if r.getLookupGroup().ForgetUnshared(lookupKey) {
+ lookupGroupCancel()
+ go dnsWaitGroupDone(ch, func() {})
+ } else {
+ go dnsWaitGroupDone(ch, lookupGroupCancel)
+ }
+ ctxErr := ctx.Err()
+ err := &DNSError{
+ Err: mapErr(ctxErr).Error(),
+ Name: host,
+ IsTimeout: ctxErr == context.DeadlineExceeded,
+ }
+ if trace != nil && trace.DNSDone != nil {
+ trace.DNSDone(nil, false, err)
+ }
+ return nil, err
+ case r := <-ch:
+ dnsWaitGroup.Done()
+ lookupGroupCancel()
+ err := r.Err
+ if err != nil {
+ if _, ok := err.(*DNSError); !ok {
+ isTimeout := false
+ if err == context.DeadlineExceeded {
+ isTimeout = true
+ } else if terr, ok := err.(timeout); ok {
+ isTimeout = terr.Timeout()
+ }
+ err = &DNSError{
+ Err: err.Error(),
+ Name: host,
+ IsTimeout: isTimeout,
+ }
+ }
+ }
+ if trace != nil && trace.DNSDone != nil {
+ addrs, _ := r.Val.([]IPAddr)
+ trace.DNSDone(ipAddrsEface(addrs), r.Shared, err)
+ }
+ return lookupIPReturn(r.Val, err, r.Shared)
+ }
+}
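
The DoChan/ForgetUnshared machinery above exists so that concurrent lookups for the same key share a single DNS query while each caller still honors its own context. A sketch of just the deduplication half, using golang.org/x/sync/singleflight (the exported counterpart of the internal package used here); the key and the returned address are invented:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

func main() {
	var g singleflight.Group
	var calls atomic.Int32

	slowLookup := func() (any, error) {
		calls.Add(1)
		time.Sleep(50 * time.Millisecond) // stand-in for a DNS round trip
		return []string{"192.0.2.1"}, nil
	}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Same key: the four goroutines share one execution of slowLookup.
			res := <-g.DoChan("ip\x00example.com", slowLookup)
			fmt.Println(res.Val, res.Shared)
		}()
	}
	wg.Wait()
	fmt.Println("underlying lookups:", calls.Load()) // typically 1
}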
+
+// lookupIPReturn turns the return values from singleflight.Do into
+// the return values from LookupIP.
+func lookupIPReturn(addrsi any, err error, shared bool) ([]IPAddr, error) {
+ if err != nil {
+ return nil, err
+ }
+ addrs := addrsi.([]IPAddr)
+ if shared {
+ clone := make([]IPAddr, len(addrs))
+ copy(clone, addrs)
+ addrs = clone
+ }
+ return addrs, nil
+}
+
+// ipAddrsEface returns an empty interface slice of addrs.
+func ipAddrsEface(addrs []IPAddr) []any {
+ s := make([]any, len(addrs))
+ for i, v := range addrs {
+ s[i] = v
+ }
+ return s
+}
+
+// LookupPort looks up the port for the given network and service.
+//
+// LookupPort uses context.Background internally; to specify the context, use
+// Resolver.LookupPort.
+func LookupPort(network, service string) (port int, err error) {
+ return DefaultResolver.LookupPort(context.Background(), network, service)
+}
+
+// LookupPort looks up the port for the given network and service.
+func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
+ port, needsLookup := parsePort(service)
+ if needsLookup {
+ switch network {
+ case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
+ case "": // a hint wildcard for Go 1.0 undocumented behavior
+ network = "ip"
+ default:
+ return 0, &AddrError{Err: "unknown network", Addr: network}
+ }
+ port, err = r.lookupPort(ctx, network, service)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if 0 > port || port > 65535 {
+ return 0, &AddrError{Err: "invalid port", Addr: service}
+ }
+ return port, nil
+}
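
Usage is the same for numeric and named services; a brief sketch (the named-service result assumes the usual service data):

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	// Numeric services are parsed directly and never hit the lookup path.
	port, err := net.LookupPort("tcp", "443")
	fmt.Println(port, err) // 443 <nil>

	// Named services go through the baked-in table or the system's data.
	port, err = net.DefaultResolver.LookupPort(context.Background(), "udp", "domain")
	fmt.Println(port, err) // typically 53 <nil>
}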
+
+// LookupCNAME returns the canonical name for the given host.
+// Callers that do not care about the canonical name can call
+// LookupHost or LookupIP directly; both take care of resolving
+// the canonical name as part of the lookup.
+//
+// A canonical name is the final name after following zero
+// or more CNAME records.
+// LookupCNAME does not return an error if host does not
+// contain DNS "CNAME" records, as long as host resolves to
+// address records.
+//
+// The returned canonical name is validated to be a properly
+// formatted presentation-format domain name.
+//
+// LookupCNAME uses context.Background internally; to specify the context, use
+// Resolver.LookupCNAME.
+func LookupCNAME(host string) (cname string, err error) {
+ return DefaultResolver.LookupCNAME(context.Background(), host)
+}
+
+// LookupCNAME returns the canonical name for the given host.
+// Callers that do not care about the canonical name can call
+// LookupHost or LookupIP directly; both take care of resolving
+// the canonical name as part of the lookup.
+//
+// A canonical name is the final name after following zero
+// or more CNAME records.
+// LookupCNAME does not return an error if host does not
+// contain DNS "CNAME" records, as long as host resolves to
+// address records.
+//
+// The returned canonical name is validated to be a properly
+// formatted presentation-format domain name.
+func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) {
+ cname, err := r.lookupCNAME(ctx, host)
+ if err != nil {
+ return "", err
+ }
+ if !isDomainName(cname) {
+ return "", &DNSError{Err: errMalformedDNSRecordsDetail, Name: host}
+ }
+ return cname, nil
+}
+
+// LookupSRV tries to resolve an SRV query of the given service,
+// protocol, and domain name. The proto is "tcp" or "udp".
+// The returned records are sorted by priority and randomized
+// by weight within a priority.
+//
+// LookupSRV constructs the DNS name to look up following RFC 2782.
+// That is, it looks up _service._proto.name. To accommodate services
+// publishing SRV records under non-standard names, if both service
+// and proto are empty strings, LookupSRV looks up name directly.
+//
+// The returned service names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
+ return DefaultResolver.LookupSRV(context.Background(), service, proto, name)
+}
+
+// LookupSRV tries to resolve an SRV query of the given service,
+// protocol, and domain name. The proto is "tcp" or "udp".
+// The returned records are sorted by priority and randomized
+// by weight within a priority.
+//
+// LookupSRV constructs the DNS name to look up following RFC 2782.
+// That is, it looks up _service._proto.name. To accommodate services
+// publishing SRV records under non-standard names, if both service
+// and proto are empty strings, LookupSRV looks up name directly.
+//
+// The returned service names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+ cname, addrs, err := r.lookupSRV(ctx, service, proto, name)
+ if err != nil {
+ return "", nil, err
+ }
+ if cname != "" && !isDomainName(cname) {
+ return "", nil, &DNSError{Err: "SRV header name is invalid", Name: name}
+ }
+ filteredAddrs := make([]*SRV, 0, len(addrs))
+ for _, addr := range addrs {
+ if addr == nil {
+ continue
+ }
+ if !isDomainName(addr.Target) {
+ continue
+ }
+ filteredAddrs = append(filteredAddrs, addr)
+ }
+ if len(addrs) != len(filteredAddrs) {
+ return cname, filteredAddrs, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+ }
+ return cname, filteredAddrs, nil
+}
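
A usage sketch that walks the returned records in the priority/weight order described above; the service, protocol and domain are placeholders:

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	cname, srvs, err := net.DefaultResolver.LookupSRV(context.Background(), "xmpp-client", "tcp", "example.com")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("looked up:", cname)
	for _, srv := range srvs {
		// Records arrive sorted by priority, shuffled by weight within a priority.
		fmt.Printf("%s:%d (priority %d, weight %d)\n", srv.Target, srv.Port, srv.Priority, srv.Weight)
	}
}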
+
+// LookupMX returns the DNS MX records for the given domain name sorted by preference.
+//
+// The returned mail server names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+//
+// LookupMX uses context.Background internally; to specify the context, use
+// Resolver.LookupMX.
+func LookupMX(name string) ([]*MX, error) {
+ return DefaultResolver.LookupMX(context.Background(), name)
+}
+
+// LookupMX returns the DNS MX records for the given domain name sorted by preference.
+//
+// The returned mail server names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
+ records, err := r.lookupMX(ctx, name)
+ if err != nil {
+ return nil, err
+ }
+ filteredMX := make([]*MX, 0, len(records))
+ for _, mx := range records {
+ if mx == nil {
+ continue
+ }
+ if !isDomainName(mx.Host) {
+ continue
+ }
+ filteredMX = append(filteredMX, mx)
+ }
+ if len(records) != len(filteredMX) {
+ return filteredMX, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+ }
+ return filteredMX, nil
+}
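
A brief usage sketch; note that, as documented above, filtered-but-nonempty results can accompany an error (the domain is a placeholder):

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	mxs, err := net.DefaultResolver.LookupMX(context.Background(), "example.org")
	if err != nil {
		// mxs may still hold the records that passed validation.
		fmt.Println("lookup error:", err)
	}
	for _, mx := range mxs {
		fmt.Println(mx.Host, mx.Pref) // sorted by preference
	}
}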
+
+// LookupNS returns the DNS NS records for the given domain name.
+//
+// The returned name server names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+//
+// LookupNS uses context.Background internally; to specify the context, use
+// Resolver.LookupNS.
+func LookupNS(name string) ([]*NS, error) {
+ return DefaultResolver.LookupNS(context.Background(), name)
+}
+
+// LookupNS returns the DNS NS records for the given domain name.
+//
+// The returned name server names are validated to be properly
+// formatted presentation-format domain names. If the response contains
+// invalid names, those records are filtered out and an error
+// will be returned alongside the remaining results, if any.
+func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
+ records, err := r.lookupNS(ctx, name)
+ if err != nil {
+ return nil, err
+ }
+ filteredNS := make([]*NS, 0, len(records))
+ for _, ns := range records {
+ if ns == nil {
+ continue
+ }
+ if !isDomainName(ns.Host) {
+ continue
+ }
+ filteredNS = append(filteredNS, ns)
+ }
+ if len(records) != len(filteredNS) {
+ return filteredNS, &DNSError{Err: errMalformedDNSRecordsDetail, Name: name}
+ }
+ return filteredNS, nil
+}
+
+// LookupTXT returns the DNS TXT records for the given domain name.
+//
+// LookupTXT uses context.Background internally; to specify the context, use
+// Resolver.LookupTXT.
+func LookupTXT(name string) ([]string, error) {
+ return DefaultResolver.lookupTXT(context.Background(), name)
+}
+
+// LookupTXT returns the DNS TXT records for the given domain name.
+func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
+ return r.lookupTXT(ctx, name)
+}
+
+// LookupAddr performs a reverse lookup for the given address, returning a list
+// of names mapping to that address.
+//
+// The returned names are validated to be properly formatted presentation-format
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the remaining results, if any.
+//
+// When using the host C library resolver, at most one result will be
+// returned. To bypass the host resolver, use a custom Resolver.
+//
+// LookupAddr uses context.Background internally; to specify the context, use
+// Resolver.LookupAddr.
+func LookupAddr(addr string) (names []string, err error) {
+ return DefaultResolver.LookupAddr(context.Background(), addr)
+}
+
+// LookupAddr performs a reverse lookup for the given address, returning a list
+// of names mapping to that address.
+//
+// The returned names are validated to be properly formatted presentation-format
+// domain names. If the response contains invalid names, those records are filtered
+// out and an error will be returned alongside the remaining results, if any.
+func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
+ names, err := r.lookupAddr(ctx, addr)
+ if err != nil {
+ return nil, err
+ }
+ filteredNames := make([]string, 0, len(names))
+ for _, name := range names {
+ if isDomainName(name) {
+ filteredNames = append(filteredNames, name)
+ }
+ }
+ if len(names) != len(filteredNames) {
+ return filteredNames, &DNSError{Err: errMalformedDNSRecordsDetail, Name: addr}
+ }
+ return filteredNames, nil
+}
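
A reverse-lookup sketch; 192.0.2.1 is a documentation address and will normally have no PTR record, so substitute a real address to see results:

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	names, err := net.DefaultResolver.LookupAddr(context.Background(), "192.0.2.1")
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, name := range names {
		fmt.Println(name) // presentation-format names ending in a dot
	}
}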
+
+// errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup...
+// method receives DNS records which contain invalid DNS names. This may be returned alongside
+// results which have had the malformed records filtered out.
+var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names"
+
+// dial makes a new connection to the provided server (which must be
+// an IP address) with the provided network type, using either r.Dial
+// (if both r and r.Dial are non-nil) or else Dialer.DialContext.
+func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
+ // Calling Dial here is scary -- we have to be sure not to
+ // dial a name that will require a DNS lookup, or Dial will
+ // call back here to translate it. The DNS config parser has
+ // already checked that all the cfg.servers are IP
+ // addresses, which Dial will use without a DNS lookup.
+ var c Conn
+ var err error
+ if r != nil && r.Dial != nil {
+ c, err = r.Dial(ctx, network, server)
+ } else {
+ var d Dialer
+ c, err = d.DialContext(ctx, network, server)
+ }
+ if err != nil {
+ return nil, mapErr(err)
+ }
+ return c, nil
+}
+
+// goLookupSRV returns the SRV records for a target name, built either
+// from its component service ("sip"), protocol ("tcp"), and name
+// ("example.com."), or from name directly (if service and proto are
+// both empty).
+//
+// In either case, the returned target name ("_sip._tcp.example.com.")
+// is also returned on success.
+//
+// The records are sorted by weight.
+func (r *Resolver) goLookupSRV(ctx context.Context, service, proto, name string) (target string, srvs []*SRV, err error) {
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
+ p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV, nil)
+ if err != nil {
+ return "", nil, err
+ }
+ var cname dnsmessage.Name
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeSRV {
+ if err := p.SkipAnswer(); err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
+ srv, err := p.SRVResource()
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
+ }
+ byPriorityWeight(srvs).sort()
+ return cname.String(), srvs, nil
+}
+
+// goLookupMX returns the MX records for name.
+func (r *Resolver) goLookupMX(ctx context.Context, name string) ([]*MX, error) {
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX, nil)
+ if err != nil {
+ return nil, err
+ }
+ var mxs []*MX
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeMX {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ mx, err := p.MXResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
+ }
+ byPref(mxs).sort()
+ return mxs, nil
+}
+
+// goLookupNS returns the NS records for name.
+func (r *Resolver) goLookupNS(ctx context.Context, name string) ([]*NS, error) {
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS, nil)
+ if err != nil {
+ return nil, err
+ }
+ var nss []*NS
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeNS {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ ns, err := p.NSResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ nss = append(nss, &NS{Host: ns.NS.String()})
+ }
+ return nss, nil
+}
+
+// goLookupTXT returns the TXT records from name.
+func (r *Resolver) goLookupTXT(ctx context.Context, name string) ([]string, error) {
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT, nil)
+ if err != nil {
+ return nil, err
+ }
+ var txts []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeTXT {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ txt, err := p.TXTResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ // Multiple strings in one TXT record need to be
+ // concatenated without separator to be consistent
+ // with previous Go resolver.
+ n := 0
+ for _, s := range txt.TXT {
+ n += len(s)
+ }
+ txtJoin := make([]byte, 0, n)
+ for _, s := range txt.TXT {
+ txtJoin = append(txtJoin, s...)
+ }
+ if len(txts) == 0 {
+ txts = make([]string, 0, 1)
+ }
+ txts = append(txts, string(txtJoin))
+ }
+ return txts, nil
+}
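
A tiny standalone illustration of the concatenation rule applied in the loop above: the character-strings of a single TXT record are joined with no separator, so a record stored as three chunks comes back as one string. The input is made up:

package main

import "fmt"

// joinTXT concatenates the character-strings of one TXT record with no
// separator, mirroring the behavior of goLookupTXT above.
func joinTXT(chunks []string) string {
	n := 0
	for _, s := range chunks {
		n += len(s)
	}
	b := make([]byte, 0, n)
	for _, s := range chunks {
		b = append(b, s...)
	}
	return string(b)
}

func main() {
	fmt.Println(joinTXT([]string{"v=spf1 ", "include:example.net ", "~all"}))
	// Output: v=spf1 include:example.net ~all
}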
+
+func parseCNAMEFromResources(resources []dnsmessage.Resource) (string, error) {
+ if len(resources) == 0 {
+ return "", errors.New("no CNAME record received")
+ }
+ c, ok := resources[0].Body.(*dnsmessage.CNAMEResource)
+ if !ok {
+ return "", errors.New("could not parse CNAME record")
+ }
+ return c.CNAME.String(), nil
+}
diff --git a/src/net/lookup_fake.go b/src/net/lookup_fake.go
new file mode 100644
index 0000000..c27eae4
--- /dev/null
+++ b/src/net/lookup_fake.go
@@ -0,0 +1,58 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js && wasm
+
+package net
+
+import (
+ "context"
+ "syscall"
+)
+
+func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
+ return lookupProtocolMap(name)
+}
+
+func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
+ return goLookupPort(network, service)
+}
+
+func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+ return "", syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) {
+ return "", nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupMX(ctx context.Context, name string) (mxs []*MX, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupNS(ctx context.Context, name string) (nss []*NS, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go
new file mode 100644
index 0000000..5404b99
--- /dev/null
+++ b/src/net/lookup_plan9.go
@@ -0,0 +1,389 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "internal/bytealg"
+ "internal/itoa"
+ "io"
+ "os"
+)
+
+// cgoAvailable is set to true to indicate that the cgo resolver
+// is available on Plan 9. Note that on Plan 9 the cgo resolver
+// does not actually use cgo.
+const cgoAvailable = true
+
+func query(ctx context.Context, filename, query string, bufSize int) (addrs []string, err error) {
+ queryAddrs := func() (addrs []string, err error) {
+ file, err := os.OpenFile(filename, os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ _, err = file.Seek(0, io.SeekStart)
+ if err != nil {
+ return nil, err
+ }
+ _, err = file.WriteString(query)
+ if err != nil {
+ return nil, err
+ }
+ _, err = file.Seek(0, io.SeekStart)
+ if err != nil {
+ return nil, err
+ }
+ buf := make([]byte, bufSize)
+ for {
+ n, _ := file.Read(buf)
+ if n <= 0 {
+ break
+ }
+ addrs = append(addrs, string(buf[:n]))
+ }
+ return addrs, nil
+ }
+
+ type ret struct {
+ addrs []string
+ err error
+ }
+
+ ch := make(chan ret, 1)
+ go func() {
+ addrs, err := queryAddrs()
+ ch <- ret{addrs: addrs, err: err}
+ }()
+
+ select {
+ case r := <-ch:
+ return r.addrs, r.err
+ case <-ctx.Done():
+ return nil, &DNSError{
+ Name: query,
+ Err: ctx.Err().Error(),
+ IsTimeout: ctx.Err() == context.DeadlineExceeded,
+ }
+ }
+}
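
The goroutine-plus-buffered-channel shape of query is a general way to make a blocking call respect a context: the select races the result against ctx.Done(), and the buffer of one lets the worker finish and exit even after the caller gives up. A standalone sketch of the same pattern with an invented slow call:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func slowCall() ([]string, error) {
	time.Sleep(200 * time.Millisecond) // stand-in for reading a Plan 9 /net file
	return []string{"example.com 192.0.2.1"}, nil
}

func queryWithContext(ctx context.Context) ([]string, error) {
	type ret struct {
		lines []string
		err   error
	}
	ch := make(chan ret, 1) // buffered so the worker never blocks on send
	go func() {
		lines, err := slowCall()
		ch <- ret{lines, err}
	}()
	select {
	case r := <-ch:
		return r.lines, r.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err := queryWithContext(ctx)
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true: the caller stopped waiting
}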
+
+func queryCS(ctx context.Context, net, host, service string) (res []string, err error) {
+ switch net {
+ case "tcp4", "tcp6":
+ net = "tcp"
+ case "udp4", "udp6":
+ net = "udp"
+ }
+ if host == "" {
+ host = "*"
+ }
+ return query(ctx, netdir+"/cs", net+"!"+host+"!"+service, 128)
+}
+
+func queryCS1(ctx context.Context, net string, ip IP, port int) (clone, dest string, err error) {
+ ips := "*"
+ if len(ip) != 0 && !ip.IsUnspecified() {
+ ips = ip.String()
+ }
+ lines, err := queryCS(ctx, net, ips, itoa.Itoa(port))
+ if err != nil {
+ return
+ }
+ f := getFields(lines[0])
+ if len(f) < 2 {
+ return "", "", errors.New("bad response from ndb/cs")
+ }
+ clone, dest = f[0], f[1]
+ return
+}
+
+func queryDNS(ctx context.Context, addr string, typ string) (res []string, err error) {
+ return query(ctx, netdir+"/dns", addr+" "+typ, 1024)
+}
+
+// toLower returns a lower-case version of in. Restricting us to
+// ASCII is sufficient to handle the IP protocol names and allow
+// us to not depend on the strings and unicode packages.
+func toLower(in string) string {
+ for _, c := range in {
+ if 'A' <= c && c <= 'Z' {
+ // Has upper case; need to fix.
+ out := []byte(in)
+ for i := 0; i < len(in); i++ {
+ c := in[i]
+ if 'A' <= c && c <= 'Z' {
+ c += 'a' - 'A'
+ }
+ out[i] = c
+ }
+ return string(out)
+ }
+ }
+ return in
+}
+
+// lookupProtocol looks up IP protocol name and returns
+// the corresponding protocol number.
+func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
+ lines, err := query(ctx, netdir+"/cs", "!protocol="+toLower(name), 128)
+ if err != nil {
+ return 0, err
+ }
+ if len(lines) == 0 {
+ return 0, UnknownNetworkError(name)
+ }
+ f := getFields(lines[0])
+ if len(f) < 2 {
+ return 0, UnknownNetworkError(name)
+ }
+ s := f[1]
+ if n, _, ok := dtoi(s[bytealg.IndexByteString(s, '=')+1:]); ok {
+ return n, nil
+ }
+ return 0, UnknownNetworkError(name)
+}
+
+func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+ // Use netdir/cs instead of netdir/dns because cs knows about
+ // host names on the local network (e.g. from /lib/ndb/local)
+ lines, err := queryCS(ctx, "net", host, "1")
+ if err != nil {
+ dnsError := &DNSError{Err: err.Error(), Name: host}
+ if stringsHasSuffix(err.Error(), "dns failure") {
+ dnsError.Err = errNoSuchHost.Error()
+ dnsError.IsNotFound = true
+ }
+ return nil, dnsError
+ }
+loop:
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ addr := f[1]
+ if i := bytealg.IndexByteString(addr, '!'); i >= 0 {
+ addr = addr[:i] // remove port
+ }
+ if ParseIP(addr) == nil {
+ continue
+ }
+ // only return unique addresses
+ for _, a := range addrs {
+ if a == addr {
+ continue loop
+ }
+ }
+ addrs = append(addrs, addr)
+ }
+ return
+}
+
+// preferGoOverPlan9 reports whether the resolver should use the
+// "PreferGo" implementation rather than asking plan9 services
+// for the answers.
+func (r *Resolver) preferGoOverPlan9() bool {
+ _, _, res := r.preferGoOverPlan9WithOrderAndConf()
+ return res
+}
+
+func (r *Resolver) preferGoOverPlan9WithOrderAndConf() (hostLookupOrder, *dnsConfig, bool) {
+ order, conf := systemConf().hostLookupOrder(r, "") // name is unused
+
+ // TODO(bradfitz): for now we only permit use of the PreferGo
+ // implementation when there's a non-nil Resolver with a
+ // non-nil Dialer. This is a sign that the caller is trying
+ // to use its own DNS-speaking net.Conn (such as an in-memory
+ // DNS cache) and doesn't want to actually hit the network.
+ // Once we add support for looking up the default DNS servers
+ // from plan9, though, then we can relax this.
+ return order, conf, order != hostLookupCgo && r != nil && r.Dial != nil
+}
+
+func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
+ if r.preferGoOverPlan9() {
+ return r.goLookupIP(ctx, network, host)
+ }
+ lits, err := r.lookupHost(ctx, host)
+ if err != nil {
+ return
+ }
+ for _, lit := range lits {
+ host, zone := splitHostZone(lit)
+ if ip := ParseIP(host); ip != nil {
+ addr := IPAddr{IP: ip, Zone: zone}
+ addrs = append(addrs, addr)
+ }
+ }
+ return
+}
+
+func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
+ switch network {
+ case "tcp4", "tcp6":
+ network = "tcp"
+ case "udp4", "udp6":
+ network = "udp"
+ }
+ lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service))
+ if err != nil {
+ return
+ }
+ unknownPortError := &AddrError{Err: "unknown port", Addr: network + "/" + service}
+ if len(lines) == 0 {
+ return 0, unknownPortError
+ }
+ f := getFields(lines[0])
+ if len(f) < 2 {
+ return 0, unknownPortError
+ }
+ s := f[1]
+ if i := bytealg.IndexByteString(s, '!'); i >= 0 {
+ s = s[i+1:] // remove address
+ }
+ if n, _, ok := dtoi(s); ok {
+ return n, nil
+ }
+ return 0, unknownPortError
+}
+
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+ if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo {
+ return r.goLookupCNAME(ctx, name, order, conf)
+ }
+
+ lines, err := queryDNS(ctx, name, "cname")
+ if err != nil {
+ if stringsHasSuffix(err.Error(), "dns failure") || stringsHasSuffix(err.Error(), "resource does not exist; negrcode 0") {
+ cname = name + "."
+ err = nil
+ }
+ return
+ }
+ if len(lines) > 0 {
+ if f := getFields(lines[0]); len(f) >= 3 {
+ return f[2] + ".", nil
+ }
+ }
+ return "", errors.New("bad response from ndb/dns")
+}
+
+func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+ if r.preferGoOverPlan9() {
+ return r.goLookupSRV(ctx, service, proto, name)
+ }
+ var target string
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
+ lines, err := queryDNS(ctx, target, "srv")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 6 {
+ continue
+ }
+ port, _, portOk := dtoi(f[4])
+ priority, _, priorityOk := dtoi(f[3])
+ weight, _, weightOk := dtoi(f[2])
+ if !(portOk && priorityOk && weightOk) {
+ continue
+ }
+ addrs = append(addrs, &SRV{absDomainName(f[5]), uint16(port), uint16(priority), uint16(weight)})
+ cname = absDomainName(f[0])
+ }
+ byPriorityWeight(addrs).sort()
+ return
+}
+
+func (r *Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
+ if r.preferGoOverPlan9() {
+ return r.goLookupMX(ctx, name)
+ }
+ lines, err := queryDNS(ctx, name, "mx")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 4 {
+ continue
+ }
+ if pref, _, ok := dtoi(f[2]); ok {
+ mx = append(mx, &MX{absDomainName(f[3]), uint16(pref)})
+ }
+ }
+ byPref(mx).sort()
+ return
+}
+
+func (r *Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
+ if r.preferGoOverPlan9() {
+ return r.goLookupNS(ctx, name)
+ }
+ lines, err := queryDNS(ctx, name, "ns")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 3 {
+ continue
+ }
+ ns = append(ns, &NS{absDomainName(f[2])})
+ }
+ return
+}
+
+func (r *Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
+ if r.preferGoOverPlan9() {
+ return r.goLookupTXT(ctx, name)
+ }
+ lines, err := queryDNS(ctx, name, "txt")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ if i := bytealg.IndexByteString(line, '\t'); i >= 0 {
+ txt = append(txt, line[i+1:])
+ }
+ }
+ return
+}
+
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
+ if order, conf, preferGo := r.preferGoOverPlan9WithOrderAndConf(); preferGo {
+ return r.goLookupPTR(ctx, addr, order, conf)
+ }
+ arpa, err := reverseaddr(addr)
+ if err != nil {
+ return
+ }
+ lines, err := queryDNS(ctx, arpa, "ptr")
+ if err != nil {
+ return
+ }
+ for _, line := range lines {
+ f := getFields(line)
+ if len(f) < 3 {
+ continue
+ }
+ name = append(name, absDomainName(f[2]))
+ }
+ return
+}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go
new file mode 100644
index 0000000..0689c19
--- /dev/null
+++ b/src/net/lookup_test.go
@@ -0,0 +1,1464 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "net/netip"
+ "reflect"
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func hasSuffixFold(s, suffix string) bool {
+ return strings.HasSuffix(strings.ToLower(s), strings.ToLower(suffix))
+}
+
+func lookupLocalhost(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ switch host {
+ case "localhost":
+ return []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ }, nil
+ default:
+ return fn(ctx, network, host)
+ }
+}
+
+// The Lookup APIs use various sources such as the local database, DNS or
+// mDNS, and may use a platform-dependent DNS stub resolver if possible.
+// The APIs accept any of several forms for a query: a host name in various
+// encodings, a UTF-8 encoded net name, a domain name, an FQDN or an absolute
+// FQDN. The result is returned in one of these forms, depending on
+// the circumstances.
+
+var lookupGoogleSRVTests = []struct {
+ service, proto, name string
+ cname, target string
+}{
+ {
+ "ldap", "tcp", "google.com",
+ "google.com.", "google.com.",
+ },
+ {
+ "ldap", "tcp", "google.com.",
+ "google.com.", "google.com.",
+ },
+
+ // non-standard back door
+ {
+ "", "", "_ldap._tcp.google.com",
+ "google.com.", "google.com.",
+ },
+ {
+ "", "", "_ldap._tcp.google.com.",
+ "google.com.", "google.com.",
+ },
+}
+
+var backoffDuration = [...]time.Duration{time.Second, 5 * time.Second, 30 * time.Second}
+
+func TestLookupGoogleSRV(t *testing.T) {
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "ios" {
+ t.Skip("no resolv.conf on iOS")
+ }
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ attempts := 0
+ for i := 0; i < len(lookupGoogleSRVTests); i++ {
+ tt := lookupGoogleSRVTests[i]
+ cname, srvs, err := LookupSRV(tt.service, tt.proto, tt.name)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
+ t.Fatal(err)
+ }
+ if len(srvs) == 0 {
+ t.Error("got no record")
+ }
+ if !hasSuffixFold(cname, tt.cname) {
+ t.Errorf("got %s; want %s", cname, tt.cname)
+ }
+ for _, srv := range srvs {
+ if !hasSuffixFold(srv.Target, tt.target) {
+ t.Errorf("got %v; want a record containing %s", srv, tt.target)
+ }
+ }
+ }
+}
+
+var lookupGmailMXTests = []struct {
+ name, host string
+}{
+ {"gmail.com", "google.com."},
+ {"gmail.com.", "google.com."},
+}
+
+func TestLookupGmailMX(t *testing.T) {
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "ios" {
+ t.Skip("no resolv.conf on iOS")
+ }
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ attempts := 0
+ for i := 0; i < len(lookupGmailMXTests); i++ {
+ tt := lookupGmailMXTests[i]
+ mxs, err := LookupMX(tt.name)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
+ t.Fatal(err)
+ }
+ if len(mxs) == 0 {
+ t.Error("got no record")
+ }
+ for _, mx := range mxs {
+ if !hasSuffixFold(mx.Host, tt.host) {
+ t.Errorf("got %v; want a record containing %s", mx, tt.host)
+ }
+ }
+ }
+}
+
+var lookupGmailNSTests = []struct {
+ name, host string
+}{
+ {"gmail.com", "google.com."},
+ {"gmail.com.", "google.com."},
+}
+
+func TestLookupGmailNS(t *testing.T) {
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "ios" {
+ t.Skip("no resolv.conf on iOS")
+ }
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ attempts := 0
+ for i := 0; i < len(lookupGmailNSTests); i++ {
+ tt := lookupGmailNSTests[i]
+ nss, err := LookupNS(tt.name)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
+ t.Fatal(err)
+ }
+ if len(nss) == 0 {
+ t.Error("got no record")
+ }
+ for _, ns := range nss {
+ if !hasSuffixFold(ns.Host, tt.host) {
+ t.Errorf("got %v; want a record containing %s", ns, tt.host)
+ }
+ }
+ }
+}
+
+var lookupGmailTXTTests = []struct {
+ name, txt, host string
+}{
+ {"gmail.com", "spf", "google.com"},
+ {"gmail.com.", "spf", "google.com"},
+}
+
+func TestLookupGmailTXT(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping on plan9; see https://golang.org/issue/29722")
+ }
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "ios" {
+ t.Skip("no resolv.conf on iOS")
+ }
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ attempts := 0
+ for i := 0; i < len(lookupGmailTXTTests); i++ {
+ tt := lookupGmailTXTTests[i]
+ txts, err := LookupTXT(tt.name)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
+ t.Fatal(err)
+ }
+ if len(txts) == 0 {
+ t.Error("got no record")
+ }
+ found := false
+ for _, txt := range txts {
+ if strings.Contains(txt, tt.txt) && (strings.HasSuffix(txt, tt.host) || strings.HasSuffix(txt, tt.host+".")) {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("got %v; want a record containing %s, %s", txts, tt.txt, tt.host)
+ }
+ }
+}
+
+var lookupGooglePublicDNSAddrTests = []string{
+ "8.8.8.8",
+ "8.8.4.4",
+ "2001:4860:4860::8888",
+ "2001:4860:4860::8844",
+}
+
+func TestLookupGooglePublicDNSAddr(t *testing.T) {
+ mustHaveExternalNetwork(t)
+
+ if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 {
+ t.Skip("both IPv4 and IPv6 are required")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ for _, ip := range lookupGooglePublicDNSAddrTests {
+ names, err := LookupAddr(ip)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(names) == 0 {
+ t.Error("got no record")
+ }
+ for _, name := range names {
+ if !hasSuffixFold(name, ".google.com.") && !hasSuffixFold(name, ".google.") {
+ t.Errorf("got %q; want a record ending in .google.com. or .google.", name)
+ }
+ }
+ }
+}
+
+func TestLookupIPv6LinkLocalAddr(t *testing.T) {
+ if !supportsIPv6() || !*testIPv6 {
+ t.Skip("IPv6 is required")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ addrs, err := LookupHost("localhost")
+ if err != nil {
+ t.Fatal(err)
+ }
+ found := false
+ for _, addr := range addrs {
+ if addr == "fe80::1%lo0" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if _, err := LookupAddr("fe80::1%lo0"); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestLookupIPv6LinkLocalAddrWithZone(t *testing.T) {
+ if !supportsIPv6() || !*testIPv6 {
+ t.Skip("IPv6 is required")
+ }
+
+ ipaddrs, err := DefaultResolver.LookupIPAddr(context.Background(), "fe80::1%lo0")
+ if err != nil {
+ t.Error(err)
+ }
+ for _, addr := range ipaddrs {
+ if e, a := "lo0", addr.Zone; e != a {
+ t.Errorf("wrong zone: want %q, got %q", e, a)
+ }
+ }
+
+ addrs, err := DefaultResolver.LookupHost(context.Background(), "fe80::1%lo0")
+ if err != nil {
+ t.Error(err)
+ }
+ for _, addr := range addrs {
+ if e, a := "fe80::1%lo0", addr; e != a {
+ t.Errorf("wrong host: want %q got %q", e, a)
+ }
+ }
+}
+
+var lookupCNAMETests = []struct {
+ name, cname string
+}{
+ {"www.iana.org", "icann.org."},
+ {"www.iana.org.", "icann.org."},
+ {"www.google.com", "google.com."},
+ {"google.com", "google.com."},
+ {"cname-to-txt.go4.org", "test-txt-record.go4.org."},
+}
+
+func TestLookupCNAME(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ attempts := 0
+ for i := 0; i < len(lookupCNAMETests); i++ {
+ tt := lookupCNAMETests[i]
+ cname, err := LookupCNAME(tt.name)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
+ t.Fatal(err)
+ }
+ if !hasSuffixFold(cname, tt.cname) {
+ t.Errorf("got %s; want a record containing %s", cname, tt.cname)
+ }
+ }
+}
+
+var lookupGoogleHostTests = []struct {
+ name string
+}{
+ {"google.com"},
+ {"google.com."},
+}
+
+func TestLookupGoogleHost(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ for _, tt := range lookupGoogleHostTests {
+ addrs, err := LookupHost(tt.name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(addrs) == 0 {
+ t.Error("got no record")
+ }
+ for _, addr := range addrs {
+ if ParseIP(addr) == nil {
+ t.Errorf("got %q; want a literal IP address", addr)
+ }
+ }
+ }
+}
+
+func TestLookupLongTXT(t *testing.T) {
+ testenv.SkipFlaky(t, 22857)
+ mustHaveExternalNetwork(t)
+
+ defer dnsWaitGroup.Wait()
+
+ txts, err := LookupTXT("golang.rsc.io")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sort.Strings(txts)
+ want := []string{
+ strings.Repeat("abcdefghijklmnopqrstuvwxyABCDEFGHJIKLMNOPQRSTUVWXY", 10),
+ "gophers rule",
+ }
+ if !reflect.DeepEqual(txts, want) {
+ t.Fatalf("LookupTXT golang.rsc.io incorrect\nhave %q\nwant %q", txts, want)
+ }
+}
+
+var lookupGoogleIPTests = []struct {
+ name string
+}{
+ {"google.com"},
+ {"google.com."},
+}
+
+func TestLookupGoogleIP(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ for _, tt := range lookupGoogleIPTests {
+ ips, err := LookupIP(tt.name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(ips) == 0 {
+ t.Error("got no record")
+ }
+ for _, ip := range ips {
+ if ip.To4() == nil && ip.To16() == nil {
+ t.Errorf("got %v; want an IP address", ip)
+ }
+ }
+ }
+}
+
+var revAddrTests = []struct {
+ Addr string
+ Reverse string
+ ErrPrefix string
+}{
+ {"1.2.3.4", "4.3.2.1.in-addr.arpa.", ""},
+ {"245.110.36.114", "114.36.110.245.in-addr.arpa.", ""},
+ {"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa.", ""},
+ {"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.", ""},
+ {"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa.", ""},
+ {"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa.", ""},
+ {"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa.", ""},
+ {"1.2.3", "", "unrecognized address"},
+ {"1.2.3.4.5", "", "unrecognized address"},
+ {"1234:567:bcbca::89a:bcde", "", "unrecognized address"},
+ {"1234:567::bcbc:adad::89a:bcde", "", "unrecognized address"},
+}
+
+func TestReverseAddress(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ for i, tt := range revAddrTests {
+ a, err := reverseaddr(tt.Addr)
+ if len(tt.ErrPrefix) > 0 && err == nil {
+ t.Errorf("#%d: expected %q, got <nil> (error)", i, tt.ErrPrefix)
+ continue
+ }
+ if len(tt.ErrPrefix) == 0 && err != nil {
+ t.Errorf("#%d: expected <nil>, got %q (error)", i, err)
+ }
+ if err != nil && err.(*DNSError).Err != tt.ErrPrefix {
+ t.Errorf("#%d: expected %q, got %q (mismatched error)", i, tt.ErrPrefix, err.(*DNSError).Err)
+ }
+ if a != tt.Reverse {
+ t.Errorf("#%d: expected %q, got %q (reverse address)", i, tt.Reverse, a)
+ }
+ }
+}
+
+func TestDNSFlood(t *testing.T) {
+ if !*testDNSFlood {
+ t.Skip("test disabled; use -dnsflood to enable")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ var N = 5000
+ if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
+ // On Darwin this test consumes kernel threads much more
+ // than other platforms for some reason.
+ // When we monitor the number of allocated Ms by
+ // observing on runtime.newm calls, we can see that it
+ // easily reaches the per process ceiling
+ // kern.num_threads when CGO_ENABLED=1 and
+ // GODEBUG=netdns=go.
+ N = 500
+ }
+
+ const timeout = 3 * time.Second
+ ctxHalfTimeout, cancel := context.WithTimeout(context.Background(), timeout/2)
+ defer cancel()
+ ctxTimeout, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+
+ c := make(chan error, 2*N)
+ for i := 0; i < N; i++ {
+ name := fmt.Sprintf("%d.net-test.golang.org", i)
+ go func() {
+ _, err := DefaultResolver.LookupIPAddr(ctxHalfTimeout, name)
+ c <- err
+ }()
+ go func() {
+ _, err := DefaultResolver.LookupIPAddr(ctxTimeout, name)
+ c <- err
+ }()
+ }
+ qstats := struct {
+ succeeded, failed int
+ timeout, temporary, other int
+ unknown int
+ }{}
+ deadline := time.After(timeout + time.Second)
+ for i := 0; i < 2*N; i++ {
+ select {
+ case <-deadline:
+ t.Fatal("deadline exceeded")
+ case err := <-c:
+ switch err := err.(type) {
+ case nil:
+ qstats.succeeded++
+ case Error:
+ qstats.failed++
+ if err.Timeout() {
+ qstats.timeout++
+ }
+ if err.Temporary() {
+ qstats.temporary++
+ }
+ if !err.Timeout() && !err.Temporary() {
+ qstats.other++
+ }
+ default:
+ qstats.failed++
+ qstats.unknown++
+ }
+ }
+ }
+
+ // A high volume of DNS queries for sub-domains of golang.org
+ // will be rate-limited by the authoritative or recursive server,
+ // or by a stub resolver that implements query-response rate
+ // limiting, so we can expect some query successes and more
+ // failures here, including timeout, temporary and other errors.
+ // As a rule, unknown should never appear, but it might
+ // happen due to issue 4856 for now.
+ t.Logf("%v succeeded, %v failed (%v timeout, %v temporary, %v other, %v unknown)", qstats.succeeded, qstats.failed, qstats.timeout, qstats.temporary, qstats.other, qstats.unknown)
+}
+
+func TestLookupDotsWithLocalSource(t *testing.T) {
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ mustHaveExternalNetwork(t)
+
+ defer dnsWaitGroup.Wait()
+
+ for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} {
+ fixup := fn()
+ if fixup == nil {
+ continue
+ }
+ names, err := LookupAddr("127.0.0.1")
+ fixup()
+ if err != nil {
+ t.Logf("#%d: %v", i, err)
+ continue
+ }
+ mode := "netgo"
+ if i == 1 {
+ mode = "netcgo"
+ }
+ loop:
+ for i, name := range names {
+ if strings.Index(name, ".") == len(name)-1 { // "localhost" not "localhost."
+ for j := range names {
+ if j == i {
+ continue
+ }
+ if names[j] == name[:len(name)-1] {
+ // It's OK if we find the name without the dot,
+ // as some systems say 127.0.0.1 localhost localhost.
+ continue loop
+ }
+ }
+ t.Errorf("%s: got %s; want %s", mode, name, name[:len(name)-1])
+ } else if strings.Contains(name, ".") && !strings.HasSuffix(name, ".") { // "localhost.localdomain." not "localhost.localdomain"
+ t.Errorf("%s: got %s; want name ending with trailing dot", mode, name)
+ }
+ }
+ }
+}
+
+func TestLookupDotsWithRemoteSource(t *testing.T) {
+ if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
+ testenv.SkipFlaky(t, 27992)
+ }
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ if !supportsIPv4() || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
+
+ if runtime.GOOS == "ios" {
+ t.Skip("no resolv.conf on iOS")
+ }
+
+ defer dnsWaitGroup.Wait()
+
+ if fixup := forceGoDNS(); fixup != nil {
+ testDots(t, "go")
+ fixup()
+ }
+ if fixup := forceCgoDNS(); fixup != nil {
+ testDots(t, "cgo")
+ fixup()
+ }
+}
+
+func testDots(t *testing.T, mode string) {
+ names, err := LookupAddr("8.8.8.8") // Google dns server
+ if err != nil {
+ t.Errorf("LookupAddr(8.8.8.8): %v (mode=%v)", err, mode)
+ } else {
+ for _, name := range names {
+ if !hasSuffixFold(name, ".google.com.") && !hasSuffixFold(name, ".google.") {
+ t.Errorf("LookupAddr(8.8.8.8) = %v, want names ending in .google.com or .google with trailing dot (mode=%v)", names, mode)
+ break
+ }
+ }
+ }
+
+ cname, err := LookupCNAME("www.mit.edu")
+ if err != nil {
+ t.Errorf("LookupCNAME(www.mit.edu, mode=%v): %v", mode, err)
+ } else if !strings.HasSuffix(cname, ".") {
+ t.Errorf("LookupCNAME(www.mit.edu) = %v, want cname ending in . with trailing dot (mode=%v)", cname, mode)
+ }
+
+ mxs, err := LookupMX("google.com")
+ if err != nil {
+ t.Errorf("LookupMX(google.com): %v (mode=%v)", err, mode)
+ } else {
+ for _, mx := range mxs {
+ if !hasSuffixFold(mx.Host, ".google.com.") {
+ t.Errorf("LookupMX(google.com) = %v, want names ending in .google.com. with trailing dot (mode=%v)", mxString(mxs), mode)
+ break
+ }
+ }
+ }
+
+ nss, err := LookupNS("google.com")
+ if err != nil {
+ t.Errorf("LookupNS(google.com): %v (mode=%v)", err, mode)
+ } else {
+ for _, ns := range nss {
+ if !hasSuffixFold(ns.Host, ".google.com.") {
+ t.Errorf("LookupNS(google.com) = %v, want names ending in .google.com. with trailing dot (mode=%v)", nsString(nss), mode)
+ break
+ }
+ }
+ }
+
+ cname, srvs, err := LookupSRV("ldap", "tcp", "google.com")
+ if err != nil {
+ t.Errorf("LookupSRV(ldap, tcp, google.com): %v (mode=%v)", err, mode)
+ } else {
+ if !hasSuffixFold(cname, ".google.com.") {
+ t.Errorf("LookupSRV(ldap, tcp, google.com) returned cname=%v, want name ending in .google.com. with trailing dot (mode=%v)", cname, mode)
+ }
+ for _, srv := range srvs {
+ if !hasSuffixFold(srv.Target, ".google.com.") {
+ t.Errorf("LookupSRV(ldap, tcp, google.com) returned addrs=%v, want names ending in .google.com. with trailing dot (mode=%v)", srvString(srvs), mode)
+ break
+ }
+ }
+ }
+}
+
+func mxString(mxs []*MX) string {
+ var buf strings.Builder
+ sep := ""
+ fmt.Fprintf(&buf, "[")
+ for _, mx := range mxs {
+ fmt.Fprintf(&buf, "%s%s:%d", sep, mx.Host, mx.Pref)
+ sep = " "
+ }
+ fmt.Fprintf(&buf, "]")
+ return buf.String()
+}
+
+func nsString(nss []*NS) string {
+ var buf strings.Builder
+ sep := ""
+ fmt.Fprintf(&buf, "[")
+ for _, ns := range nss {
+ fmt.Fprintf(&buf, "%s%s", sep, ns.Host)
+ sep = " "
+ }
+ fmt.Fprintf(&buf, "]")
+ return buf.String()
+}
+
+func srvString(srvs []*SRV) string {
+ var buf strings.Builder
+ sep := ""
+ fmt.Fprintf(&buf, "[")
+ for _, srv := range srvs {
+ fmt.Fprintf(&buf, "%s%s:%d:%d:%d", sep, srv.Target, srv.Port, srv.Priority, srv.Weight)
+ sep = " "
+ }
+ fmt.Fprintf(&buf, "]")
+ return buf.String()
+}
+
+func TestLookupPort(t *testing.T) {
+ // See https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
+ //
+ // Please be careful about adding new test cases.
+ // There are platforms which have incomplete mappings for
+ // restricted resource access and security reasons.
+ type test struct {
+ network string
+ name string
+ port int
+ ok bool
+ }
+ var tests = []test{
+ {"tcp", "0", 0, true},
+ {"udp", "0", 0, true},
+ {"udp", "domain", 53, true},
+
+ {"--badnet--", "zzz", 0, false},
+ {"tcp", "--badport--", 0, false},
+ {"tcp", "-1", 0, false},
+ {"tcp", "65536", 0, false},
+ {"udp", "-1", 0, false},
+ {"udp", "65536", 0, false},
+ {"tcp", "123456789", 0, false},
+
+ // Issue 13610: LookupPort("tcp", "")
+ {"tcp", "", 0, true},
+ {"tcp4", "", 0, true},
+ {"tcp6", "", 0, true},
+ {"udp", "", 0, true},
+ {"udp4", "", 0, true},
+ {"udp6", "", 0, true},
+ }
+
+ switch runtime.GOOS {
+ case "android":
+ if netGoBuildTag {
+ t.Skipf("not supported on %s without cgo; see golang.org/issues/14576", runtime.GOOS)
+ }
+ default:
+ tests = append(tests, test{"tcp", "http", 80, true})
+ }
+
+ for _, tt := range tests {
+ port, err := LookupPort(tt.network, tt.name)
+ if port != tt.port || (err == nil) != tt.ok {
+ t.Errorf("LookupPort(%q, %q) = %d, %v; want %d, error=%t", tt.network, tt.name, port, err, tt.port, !tt.ok)
+ }
+ if err != nil {
+ if perr := parseLookupPortError(err); perr != nil {
+ t.Error(perr)
+ }
+ }
+ }
+}
+
+// Like TestLookupPort but with minimal tests that should always pass
+// because the answers are baked-in to the net package.
+func TestLookupPort_Minimal(t *testing.T) {
+ type test struct {
+ network string
+ name string
+ port int
+ }
+ var tests = []test{
+ {"tcp", "http", 80},
+ {"tcp", "HTTP", 80}, // case shouldn't matter
+ {"tcp", "https", 443},
+ {"tcp", "ssh", 22},
+ {"tcp", "gopher", 70},
+ {"tcp4", "http", 80},
+ {"tcp6", "http", 80},
+ }
+
+ for _, tt := range tests {
+ port, err := LookupPort(tt.network, tt.name)
+ if port != tt.port || err != nil {
+ t.Errorf("LookupPort(%q, %q) = %d, %v; want %d, error=nil", tt.network, tt.name, port, err, tt.port)
+ }
+ }
+}
+
+func TestLookupProtocol_Minimal(t *testing.T) {
+ type test struct {
+ name string
+ want int
+ }
+ var tests = []test{
+ {"tcp", 6},
+ {"TcP", 6}, // case shouldn't matter
+ {"icmp", 1},
+ {"igmp", 2},
+ {"udp", 17},
+ {"ipv6-icmp", 58},
+ }
+
+ for _, tt := range tests {
+ got, err := lookupProtocol(context.Background(), tt.name)
+ if got != tt.want || err != nil {
+ t.Errorf("LookupProtocol(%q) = %d, %v; want %d, error=nil", tt.name, got, err, tt.want)
+ }
+ }
+
+}
+
+func TestLookupNonLDH(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ if fixup := forceGoDNS(); fixup != nil {
+ defer fixup()
+ }
+
+ // "LDH" stands for letters, digits, and hyphens and is the usual
+ // description of standard DNS names.
+ // This test is checking that other kinds of names are reported
+ // as not found, not reported as invalid names.
+ addrs, err := LookupHost("!!!.###.bogus..domain.")
+ if err == nil {
+ t.Fatalf("lookup succeeded: %v", addrs)
+ }
+ if !strings.HasSuffix(err.Error(), errNoSuchHost.Error()) {
+ t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost)
+ }
+ if !err.(*DNSError).IsNotFound {
+ t.Fatalf("lookup error = %v, want true", err.(*DNSError).IsNotFound)
+ }
+}
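
Outside the test harness, callers can detect the same condition with errors.As and the DNSError.IsNotFound flag. A small sketch (it performs a real lookup, so the result depends on the resolver in use):

package main

import (
	"errors"
	"fmt"
	"net"
)

// isNotFound reports whether err represents a DNS "no such host" condition,
// which is how the test above expects non-LDH names to be reported.
func isNotFound(err error) bool {
	var dnsErr *net.DNSError
	return errors.As(err, &dnsErr) && dnsErr.IsNotFound
}

func main() {
	_, err := net.LookupHost("!!!.###.bogus..domain.")
	fmt.Println(isNotFound(err)) // typically true: not found, not "invalid name"
}
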
+
+func TestLookupContextCancel(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ origTestHookLookupIP := testHookLookupIP
+ defer func() {
+ dnsWaitGroup.Wait()
+ testHookLookupIP = origTestHookLookupIP
+ }()
+
+ lookupCtx, cancelLookup := context.WithCancel(context.Background())
+ unblockLookup := make(chan struct{})
+
+ // Set testHookLookupIP to start a new, concurrent call to LookupIPAddr
+ // and cancel the original one, then block until the canceled call has returned
+ // (ensuring that it has performed any synchronous cleanup).
+ testHookLookupIP = func(
+ ctx context.Context,
+ fn func(context.Context, string, string) ([]IPAddr, error),
+ network string,
+ host string,
+ ) ([]IPAddr, error) {
+ select {
+ case <-unblockLookup:
+ default:
+ // Start a concurrent LookupIPAddr for the same host while the caller is
+ // still blocked, and sleep a little to give it time to be deduplicated
+ // before we cancel (and unblock) the caller.
+ // (If the timing doesn't quite work out, we'll end up testing sequential
+ // calls instead of concurrent ones, but the test should still pass.)
+ t.Logf("starting concurrent LookupIPAddr")
+ dnsWaitGroup.Add(1)
+ go func() {
+ defer dnsWaitGroup.Done()
+ _, err := DefaultResolver.LookupIPAddr(context.Background(), host)
+ if err != nil {
+ t.Error(err)
+ }
+ }()
+ time.Sleep(1 * time.Millisecond)
+ }
+
+ cancelLookup()
+ <-unblockLookup
+ // If the concurrent lookup above is deduplicated to this one
+ // (as we expect to happen most of the time), it is important
+ // that the original call does not cancel the shared Context.
+ // (See https://go.dev/issue/22724.) Explicitly check for
+ // cancellation now, just in case fn itself doesn't notice it.
+ if err := ctx.Err(); err != nil {
+ t.Logf("testHookLookupIP canceled")
+ return nil, err
+ }
+ t.Logf("testHookLookupIP performing lookup")
+ return fn(ctx, network, host)
+ }
+
+ _, err := DefaultResolver.LookupIPAddr(lookupCtx, "google.com")
+ if dnsErr, ok := err.(*DNSError); !ok || dnsErr.Err != errCanceled.Error() {
+ t.Errorf("unexpected error from canceled, blocked LookupIPAddr: %v", err)
+ }
+ close(unblockLookup)
+}
+
+// Issue 24330: treat the nil *Resolver like a zero value. Verify nothing
+// crashes if nil is used.
+func TestNilResolverLookup(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ var r *Resolver = nil
+ ctx := context.Background()
+
+ // Don't care about the results, just that nothing panics:
+ r.LookupAddr(ctx, "8.8.8.8")
+ r.LookupCNAME(ctx, "google.com")
+ r.LookupHost(ctx, "google.com")
+ r.LookupIPAddr(ctx, "google.com")
+ r.LookupIP(ctx, "ip", "google.com")
+ r.LookupMX(ctx, "gmail.com")
+ r.LookupNS(ctx, "google.com")
+ r.LookupPort(ctx, "tcp", "smtp")
+ r.LookupSRV(ctx, "service", "proto", "name")
+ r.LookupTXT(ctx, "gmail.com")
+}
+
+// TestLookupHostCancel verifies that lookup works even after many
+// canceled lookups (see golang.org/issue/24178 for details).
+func TestLookupHostCancel(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+ t.Parallel() // Executes 600ms worth of sequential sleeps.
+
+ const (
+ google = "www.google.com"
+ invalidDomain = "invalid.invalid" // RFC 2606 reserves .invalid
+ n = 600 // this needs to be larger than threadLimit size
+ )
+
+ _, err := LookupHost(google)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ for i := 0; i < n; i++ {
+ addr, err := DefaultResolver.LookupHost(ctx, invalidDomain)
+ if err == nil {
+ t.Fatalf("LookupHost(%q): returns %v, but should fail", invalidDomain, addr)
+ }
+
+ // Don't verify what the actual error is.
+ // We know that it must be non-nil because the domain is invalid,
+ // but we don't have any guarantee that LookupHost actually bothers
+ // to check for cancellation on the fast path.
+ // (For example, it could use a local cache to avoid blocking entirely.)
+
+ // The lookup may deduplicate in-flight requests, so give it time to settle
+ // in between.
+ time.Sleep(time.Millisecond * 1)
+ }
+
+ _, err = LookupHost(google)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+type lookupCustomResolver struct {
+ *Resolver
+ mu sync.RWMutex
+ dialed bool
+}
+
+func (lcr *lookupCustomResolver) dial() func(ctx context.Context, network, address string) (Conn, error) {
+ return func(ctx context.Context, network, address string) (Conn, error) {
+ lcr.mu.Lock()
+ lcr.dialed = true
+ lcr.mu.Unlock()
+ return Dial(network, address)
+ }
+}
+
+// TestConcurrentPreferGoResolversDial tests that multiple resolvers with the
+// PreferGo option used concurrently are all dialed properly.
+func TestConcurrentPreferGoResolversDial(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+		// TODO: since https://go.dev/cl/409234 the plan9 resolver implementation
+		// uses the Dial function, so this test could probably be re-enabled.
+ t.Skipf("skip on %v", runtime.GOOS)
+ }
+
+ testenv.MustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ defer dnsWaitGroup.Wait()
+
+ resolvers := make([]*lookupCustomResolver, 2)
+ for i := range resolvers {
+ cs := lookupCustomResolver{Resolver: &Resolver{PreferGo: true}}
+ cs.Dial = cs.dial()
+ resolvers[i] = &cs
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(len(resolvers))
+ for i, resolver := range resolvers {
+ go func(r *Resolver, index int) {
+ defer wg.Done()
+ _, err := r.LookupIPAddr(context.Background(), "google.com")
+ if err != nil {
+ t.Errorf("lookup failed for resolver %d: %q", index, err)
+ }
+ }(resolver.Resolver, i)
+ }
+ wg.Wait()
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ for i, resolver := range resolvers {
+ if !resolver.dialed {
+ t.Errorf("custom resolver %d not dialed during lookup", i)
+ }
+ }
+}
+
+var ipVersionTests = []struct {
+ network string
+ version byte
+}{
+ {"tcp", 0},
+ {"tcp4", '4'},
+ {"tcp6", '6'},
+ {"udp", 0},
+ {"udp4", '4'},
+ {"udp6", '6'},
+ {"ip", 0},
+ {"ip4", '4'},
+ {"ip6", '6'},
+ {"ip7", 0},
+ {"", 0},
+}
+
+func TestIPVersion(t *testing.T) {
+ for _, tt := range ipVersionTests {
+ if version := ipVersion(tt.network); version != tt.version {
+ t.Errorf("Family for: %s. Expected: %s, Got: %s", tt.network,
+ string(tt.version), string(version))
+ }
+ }
+}
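
ipVersion itself is unexported; the helper below is a standalone sketch that is merely consistent with the table above, not necessarily the package's exact implementation.

package main

import "fmt"

// ipVersionOf reports '4' or '6' based on the trailing character of the
// network name, and 0 for everything else, matching ipVersionTests.
func ipVersionOf(network string) byte {
	if network == "" {
		return 0
	}
	switch c := network[len(network)-1]; c {
	case '4', '6':
		return c
	}
	return 0
}

func main() {
	fmt.Printf("%c %d %d\n", ipVersionOf("tcp4"), ipVersionOf("ip7"), ipVersionOf("udp")) // 4 0 0
}
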
+
+// Issue 28600: The context that is used to lookup ips should always
+// preserve the values from the context that was passed into LookupIPAddr.
+func TestLookupIPAddrPreservesContextValues(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+
+ keyValues := []struct {
+ key, value any
+ }{
+ {"key-1", 12},
+ {384, "value2"},
+ {new(float64), 137},
+ }
+ ctx := context.Background()
+ for _, kv := range keyValues {
+ ctx = context.WithValue(ctx, kv.key, kv.value)
+ }
+
+ wantIPs := []IPAddr{
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ }
+
+ checkCtxValues := func(ctx_ context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ for _, kv := range keyValues {
+ g, w := ctx_.Value(kv.key), kv.value
+ if !reflect.DeepEqual(g, w) {
+ t.Errorf("Value lookup:\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ }
+ return wantIPs, nil
+ }
+ testHookLookupIP = checkCtxValues
+
+ resolvers := []*Resolver{
+ nil,
+ new(Resolver),
+ }
+
+ for i, resolver := range resolvers {
+ gotIPs, err := resolver.LookupIPAddr(ctx, "golang.org")
+ if err != nil {
+ t.Errorf("Resolver #%d: unexpected error: %v", i, err)
+ }
+ if !reflect.DeepEqual(gotIPs, wantIPs) {
+ t.Errorf("#%d: mismatched IPAddr results\n\tGot: %v\n\tWant: %v", i, gotIPs, wantIPs)
+ }
+ }
+}
+
+// Issue 30521: The lookup group should call the resolver for each network.
+func TestLookupIPAddrConcurrentCallsForNetworks(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+
+ queries := [][]string{
+ {"udp", "golang.org"},
+ {"udp4", "golang.org"},
+ {"udp6", "golang.org"},
+ {"udp", "golang.org"},
+ {"udp", "golang.org"},
+ }
+ results := map[[2]string][]IPAddr{
+ {"udp", "golang.org"}: {
+ {IP: IPv4(127, 0, 0, 1)},
+ {IP: IPv6loopback},
+ },
+ {"udp4", "golang.org"}: {
+ {IP: IPv4(127, 0, 0, 1)},
+ },
+ {"udp6", "golang.org"}: {
+ {IP: IPv6loopback},
+ },
+ }
+ calls := int32(0)
+ waitCh := make(chan struct{})
+ testHookLookupIP = func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+		// Block until this hook has been called once for each distinct
+		// expected result. If the lookup group were to (incorrectly) reuse
+		// an in-flight call for a different network, that count would never
+		// be reached and the context timeout below would fail the test.
+ if atomic.AddInt32(&calls, 1) == int32(len(results)) {
+ close(waitCh)
+ }
+ select {
+ case <-waitCh:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return results[[2]string{network, host}], nil
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+ wg := sync.WaitGroup{}
+ for _, q := range queries {
+ network := q[0]
+ host := q[1]
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ gotIPs, err := DefaultResolver.lookupIPAddr(ctx, network, host)
+ if err != nil {
+ t.Errorf("lookupIPAddr(%v, %v): unexpected error: %v", network, host, err)
+ }
+ wantIPs := results[[2]string{network, host}]
+ if !reflect.DeepEqual(gotIPs, wantIPs) {
+ t.Errorf("lookupIPAddr(%v, %v): mismatched IPAddr results\n\tGot: %v\n\tWant: %v", network, host, gotIPs, wantIPs)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// Issue 53995: Resolver.LookupIP should return error for empty host name.
+func TestResolverLookupIPWithEmptyHost(t *testing.T) {
+ _, err := DefaultResolver.LookupIP(context.Background(), "ip", "")
+ if err == nil {
+		t.Fatal("DefaultResolver.LookupIP for empty host succeeded, want no-such-host error")
+ }
+ if !strings.HasSuffix(err.Error(), errNoSuchHost.Error()) {
+ t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost)
+ }
+}
+
+func TestWithUnexpiredValuesPreserved(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Insert a value into it.
+ key, value := "key-1", 2
+ ctx = context.WithValue(ctx, key, value)
+
+ // Now use the "values preserving context" like
+ // we would for LookupIPAddr. See Issue 28600.
+ ctx = withUnexpiredValuesPreserved(ctx)
+
+ // Lookup before expiry.
+ if g, w := ctx.Value(key), value; g != w {
+ t.Errorf("Lookup before expiry: Got %v Want %v", g, w)
+ }
+
+ // Cancel the context.
+ cancel()
+
+ // Lookup after expiry should return nil
+ if g := ctx.Value(key); g != nil {
+ t.Errorf("Lookup after expiry: Got %v want nil", g)
+ }
+}
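
withUnexpiredValuesPreserved is internal to the package; the following is a minimal sketch, using a wrapper type assumed here for illustration, of a context with the behavior this test checks.

package main

import (
	"context"
	"fmt"
)

// onlyValues forwards Value lookups to the original context only while that
// context has not expired; Deadline, Done, and Err come from a fresh context.
type onlyValues struct {
	context.Context                 // fresh context: no deadline, never canceled
	lookup          context.Context // original context, consulted only for values
}

func (c onlyValues) Value(key any) any {
	select {
	case <-c.lookup.Done():
		return nil // original context expired; stop exposing its values
	default:
		return c.lookup.Value(key)
	}
}

func main() {
	parent, cancel := context.WithCancel(context.Background())
	valCtx := context.WithValue(parent, "key-1", 2)
	ctx := onlyValues{Context: context.Background(), lookup: valCtx}

	fmt.Println(ctx.Value("key-1")) // 2
	cancel()
	fmt.Println(ctx.Value("key-1")) // <nil>
}
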
+
+// Issue 31597: don't panic on null byte in name
+func TestLookupNullByte(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+ LookupHost("foo\x00bar") // check that it doesn't panic; it used to on Windows
+}
+
+func TestResolverLookupIP(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ v4Ok := supportsIPv4() && *testIPv4
+ v6Ok := supportsIPv6() && *testIPv6
+
+ defer dnsWaitGroup.Wait()
+
+ for _, impl := range []struct {
+ name string
+ fn func() func()
+ }{
+ {"go", forceGoDNS},
+ {"cgo", forceCgoDNS},
+ } {
+ t.Run("implementation: "+impl.name, func(t *testing.T) {
+ fixup := impl.fn()
+ if fixup == nil {
+ t.Skip("not supported")
+ }
+ defer fixup()
+
+ for _, network := range []string{"ip", "ip4", "ip6"} {
+ t.Run("network: "+network, func(t *testing.T) {
+ switch {
+ case network == "ip4" && !v4Ok:
+ t.Skip("IPv4 is not supported")
+ case network == "ip6" && !v6Ok:
+ t.Skip("IPv6 is not supported")
+ }
+
+ // google.com has both A and AAAA records.
+ const host = "google.com"
+ ips, err := DefaultResolver.LookupIP(context.Background(), network, host)
+ if err != nil {
+ testenv.SkipFlakyNet(t)
+ t.Fatalf("DefaultResolver.LookupIP(%q, %q): failed with unexpected error: %v", network, host, err)
+ }
+
+ var v4Addrs []netip.Addr
+ var v6Addrs []netip.Addr
+ for _, ip := range ips {
+ if addr, ok := netip.AddrFromSlice(ip); ok {
+ if addr.Is4() {
+ v4Addrs = append(v4Addrs, addr)
+ } else {
+ v6Addrs = append(v6Addrs, addr)
+ }
+ } else {
+ t.Fatalf("IP=%q is neither IPv4 nor IPv6", ip)
+ }
+ }
+
+ // Check that we got the expected addresses.
+ if network == "ip4" || network == "ip" && v4Ok {
+ if len(v4Addrs) == 0 {
+ t.Errorf("DefaultResolver.LookupIP(%q, %q): no IPv4 addresses", network, host)
+ }
+ }
+ if network == "ip6" || network == "ip" && v6Ok {
+ if len(v6Addrs) == 0 {
+ t.Errorf("DefaultResolver.LookupIP(%q, %q): no IPv6 addresses", network, host)
+ }
+ }
+
+ // Check that we didn't get any unexpected addresses.
+ if network == "ip6" && len(v4Addrs) > 0 {
+ t.Errorf("DefaultResolver.LookupIP(%q, %q): unexpected IPv4 addresses: %v", network, host, v4Addrs)
+ }
+ if network == "ip4" && len(v6Addrs) > 0 {
+ t.Errorf("DefaultResolver.LookupIP(%q, %q): unexpected IPv6 or IPv4-mapped IPv6 addresses: %v", network, host, v6Addrs)
+ }
+ })
+ }
+ })
+ }
+}
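
The public entry point exercised here can be called directly; a short sketch (it needs external network access, like the test):

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	// The network argument selects the address family: "ip" (both),
	// "ip4" (IPv4 only), or "ip6" (IPv6 only).
	ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip4", "google.com")
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	for _, ip := range ips {
		fmt.Println(ip) // IPv4 addresses only
	}
}
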
+
+// A context timeout should still return a DNSError.
+func TestDNSTimeout(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ defer dnsWaitGroup.Wait()
+
+ timeoutHookGo := make(chan bool, 1)
+ timeoutHook := func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) {
+ <-timeoutHookGo
+ return nil, context.DeadlineExceeded
+ }
+ testHookLookupIP = timeoutHook
+
+ checkErr := func(err error) {
+ t.Helper()
+ if err == nil {
+ t.Error("expected an error")
+ } else if dnserr, ok := err.(*DNSError); !ok {
+ t.Errorf("got error type %T, want %T", err, (*DNSError)(nil))
+ } else if !dnserr.IsTimeout {
+ t.Errorf("got error %#v, want IsTimeout == true", dnserr)
+ } else if isTimeout := dnserr.Timeout(); !isTimeout {
+ t.Errorf("got err.Timeout() == %t, want true", isTimeout)
+ }
+ }
+
+ // Single lookup.
+ timeoutHookGo <- true
+ _, err := LookupIP("golang.org")
+ checkErr(err)
+
+ // Double lookup.
+ var err1, err2 error
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ _, err1 = LookupIP("golang1.org")
+ }()
+ go func() {
+ defer wg.Done()
+ _, err2 = LookupIP("golang1.org")
+ }()
+ close(timeoutHookGo)
+ wg.Wait()
+ checkErr(err1)
+ checkErr(err2)
+
+ // Double lookup with context.
+ timeoutHookGo = make(chan bool)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ _, err1 = DefaultResolver.LookupIPAddr(ctx, "golang2.org")
+ }()
+ go func() {
+ defer wg.Done()
+ _, err2 = DefaultResolver.LookupIPAddr(ctx, "golang2.org")
+ }()
+ time.Sleep(10 * time.Nanosecond)
+ close(timeoutHookGo)
+ wg.Wait()
+ checkErr(err1)
+ checkErr(err2)
+ cancel()
+}
+
+func TestLookupNoData(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("not supported on plan9")
+ }
+
+ mustHaveExternalNetwork(t)
+
+ testLookupNoData(t, "default resolver")
+
+ func() {
+ defer forceGoDNS()()
+ testLookupNoData(t, "forced go resolver")
+ }()
+
+ func() {
+ defer forceCgoDNS()()
+ testLookupNoData(t, "forced cgo resolver")
+ }()
+}
+
+func testLookupNoData(t *testing.T, prefix string) {
+ attempts := 0
+ for {
+		// A domain that has no A/AAAA RRs but does have a different record type
+		// (in this case a TXT), so the response is empty without an error code
+		// such as NXDOMAIN.
+ _, err := LookupHost("golang.rsc.io.")
+ if err == nil {
+ t.Errorf("%v: unexpected success", prefix)
+ return
+ }
+
+ var dnsErr *DNSError
+ if errors.As(err, &dnsErr) {
+ succeeded := true
+ if !dnsErr.IsNotFound {
+ succeeded = false
+ t.Logf("%v: IsNotFound is set to false", prefix)
+ }
+
+ if dnsErr.Err != errNoSuchHost.Error() {
+ succeeded = false
+ t.Logf("%v: error message is not equal to: %v", prefix, errNoSuchHost.Error())
+ }
+
+ if succeeded {
+ return
+ }
+ }
+
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("%v: backoff %v after failure %v\n", prefix, dur, err)
+ time.Sleep(dur)
+ attempts++
+ continue
+ }
+
+ t.Errorf("%v: unexpected error: %v", prefix, err)
+ return
+ }
+}
diff --git a/src/net/lookup_unix.go b/src/net/lookup_unix.go
new file mode 100644
index 0000000..56ae11e
--- /dev/null
+++ b/src/net/lookup_unix.go
@@ -0,0 +1,149 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || wasip1
+
+package net
+
+import (
+ "context"
+ "internal/bytealg"
+ "sync"
+ "syscall"
+)
+
+var onceReadProtocols sync.Once
+
+// readProtocols loads contents of /etc/protocols into protocols map
+// for quick access.
+func readProtocols() {
+ file, err := open("/etc/protocols")
+ if err != nil {
+ return
+ }
+ defer file.close()
+
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ // tcp 6 TCP # transmission control protocol
+ if i := bytealg.IndexByteString(line, '#'); i >= 0 {
+ line = line[0:i]
+ }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ if proto, _, ok := dtoi(f[1]); ok {
+ if _, ok := protocols[f[0]]; !ok {
+ protocols[f[0]] = proto
+ }
+ for _, alias := range f[2:] {
+ if _, ok := protocols[alias]; !ok {
+ protocols[alias] = proto
+ }
+ }
+ }
+ }
+}
+
+// lookupProtocol looks up an IP protocol name in /etc/protocols and
+// returns the corresponding protocol number.
+func lookupProtocol(_ context.Context, name string) (int, error) {
+ onceReadProtocols.Do(readProtocols)
+ return lookupProtocolMap(name)
+}
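
A simplified, standalone sketch of the per-line handling in readProtocols, substituting strconv.Atoi for the package's internal dtoi:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseProtocolLine strips the trailing comment, splits the line into fields,
// and reads the protocol number from the second field.
func parseProtocolLine(line string) (name string, proto int, aliases []string, ok bool) {
	if i := strings.IndexByte(line, '#'); i >= 0 {
		line = line[:i]
	}
	f := strings.Fields(line)
	if len(f) < 2 {
		return "", 0, nil, false
	}
	n, err := strconv.Atoi(f[1])
	if err != nil {
		return "", 0, nil, false
	}
	return f[0], n, f[2:], true
}

func main() {
	name, proto, aliases, ok := parseProtocolLine("tcp 6 TCP   # transmission control protocol")
	fmt.Println(name, proto, aliases, ok) // tcp 6 [TCP] true
}
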
+
+func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+ order, conf := systemConf().hostLookupOrder(r, host)
+ if order == hostLookupCgo {
+ return cgoLookupHost(ctx, host)
+ }
+ return r.goLookupHostOrder(ctx, host, order, conf)
+}
+
+func (r *Resolver) lookupIP(ctx context.Context, network, host string) (addrs []IPAddr, err error) {
+ if r.preferGo() {
+ return r.goLookupIP(ctx, network, host)
+ }
+ order, conf := systemConf().hostLookupOrder(r, host)
+ if order == hostLookupCgo {
+ return cgoLookupIP(ctx, network, host)
+ }
+ ips, _, err := r.goLookupIPCNAMEOrder(ctx, network, host, order, conf)
+ return ips, err
+}
+
+func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
+ // Port lookup is not a DNS operation.
+ // Prefer the cgo resolver if possible.
+ if !systemConf().mustUseGoResolver(r) {
+ port, err := cgoLookupPort(ctx, network, service)
+ if err != nil {
+ // Issue 18213: if cgo fails, first check to see whether we
+ // have the answer baked-in to the net package.
+ if port, err := goLookupPort(network, service); err == nil {
+ return port, nil
+ }
+ }
+ return port, err
+ }
+ return goLookupPort(network, service)
+}
+
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
+ order, conf := systemConf().hostLookupOrder(r, name)
+ if order == hostLookupCgo {
+ if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
+ return cname, err
+ }
+ }
+ return r.goLookupCNAME(ctx, name, order, conf)
+}
+
+func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+ return r.goLookupSRV(ctx, service, proto, name)
+}
+
+func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
+ return r.goLookupMX(ctx, name)
+}
+
+func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
+ return r.goLookupNS(ctx, name)
+}
+
+func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
+ return r.goLookupTXT(ctx, name)
+}
+
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
+ order, conf := systemConf().addrLookupOrder(r, addr)
+ if order == hostLookupCgo {
+ return cgoLookupPTR(ctx, addr)
+ }
+ return r.goLookupPTR(ctx, addr, order, conf)
+}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups via cgo. A DNS lookup may use a
+// file descriptor so we limit this to less than the number of
+// permitted open files. On some systems, notably Darwin, if
+// getaddrinfo is unable to open a file descriptor it simply returns
+// EAI_NONAME rather than a useful error. Limiting the number of
+// concurrent getaddrinfo calls to less than the permitted number of
+// file descriptors makes that error less likely. We don't bother to
+// apply the same limit to DNS lookups run directly from Go, because
+// there we will return a meaningful "too many open files" error.
+func concurrentThreadsLimit() int {
+ var rlim syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil {
+ return 500
+ }
+ r := rlim.Cur
+ if r > 500 {
+ r = 500
+ } else if r > 30 {
+ r -= 30
+ }
+ return int(r)
+}
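
To make the arithmetic concrete, a few illustrative RLIMIT_NOFILE values (assumed numbers, not measurements from any real system) map to these limits:

    rlim.Cur = 1024  ->  500  (capped at 500)
    rlim.Cur =  256  ->  226  (256 - 30, leaving headroom for other descriptors)
    rlim.Cur =   20  ->   20  (at or below 30, used as-is)
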
diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go
new file mode 100644
index 0000000..33d5ac5
--- /dev/null
+++ b/src/net/lookup_windows.go
@@ -0,0 +1,455 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/syscall/windows"
+ "os"
+ "runtime"
+ "syscall"
+ "time"
+ "unsafe"
+)
+
+// cgoAvailable set to true to indicate that the cgo resolver
+// is available on Windows. Note that on Windows the cgo resolver
+// does not actually use cgo.
+const cgoAvailable = true
+
+const (
+ _WSAHOST_NOT_FOUND = syscall.Errno(11001)
+ _WSATRY_AGAIN = syscall.Errno(11002)
+)
+
+func winError(call string, err error) error {
+ switch err {
+ case _WSAHOST_NOT_FOUND:
+ return errNoSuchHost
+ }
+ return os.NewSyscallError(call, err)
+}
+
+func getprotobyname(name string) (proto int, err error) {
+ p, err := syscall.GetProtoByName(name)
+ if err != nil {
+ return 0, winError("getprotobyname", err)
+ }
+ return int(p.Proto), nil
+}
+
+// lookupProtocol looks up an IP protocol name and returns the corresponding protocol number.
+func lookupProtocol(ctx context.Context, name string) (int, error) {
+	// GetProtoByName's return value is stored in thread-local storage.
+	// Start a new OS thread before the call to prevent races.
+ type result struct {
+ proto int
+ err error
+ }
+ ch := make(chan result) // unbuffered
+ go func() {
+ acquireThread()
+ defer releaseThread()
+ runtime.LockOSThread()
+ defer runtime.UnlockOSThread()
+ proto, err := getprotobyname(name)
+ select {
+ case ch <- result{proto: proto, err: err}:
+ case <-ctx.Done():
+ }
+ }()
+ select {
+ case r := <-ch:
+ if r.err != nil {
+ if proto, err := lookupProtocolMap(name); err == nil {
+ return proto, nil
+ }
+
+ dnsError := &DNSError{Err: r.err.Error(), Name: name}
+ if r.err == errNoSuchHost {
+ dnsError.IsNotFound = true
+ }
+ r.err = dnsError
+ }
+ return r.proto, r.err
+ case <-ctx.Done():
+ return 0, mapErr(ctx.Err())
+ }
+}
+
+func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
+ ips, err := r.lookupIP(ctx, "ip", name)
+ if err != nil {
+ return nil, err
+ }
+ addrs := make([]string, 0, len(ips))
+ for _, ip := range ips {
+ addrs = append(addrs, ip.String())
+ }
+ return addrs, nil
+}
+
+// preferGoOverWindows reports whether the resolver should use the
+// pure Go implementation rather than making win32 calls to ask the
+// kernel for its answer.
+func (r *Resolver) preferGoOverWindows() bool {
+ conf := systemConf()
+ order, _ := conf.hostLookupOrder(r, "") // name is unused
+ return order != hostLookupCgo
+}
+
+func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr, error) {
+ if r.preferGoOverWindows() {
+ return r.goLookupIP(ctx, network, name)
+ }
+ // TODO(bradfitz,brainman): use ctx more. See TODO below.
+
+ var family int32 = syscall.AF_UNSPEC
+ switch ipVersion(network) {
+ case '4':
+ family = syscall.AF_INET
+ case '6':
+ family = syscall.AF_INET6
+ }
+
+ getaddr := func() ([]IPAddr, error) {
+ acquireThread()
+ defer releaseThread()
+ hints := syscall.AddrinfoW{
+ Family: family,
+ Socktype: syscall.SOCK_STREAM,
+ Protocol: syscall.IPPROTO_IP,
+ }
+ var result *syscall.AddrinfoW
+ name16p, err := syscall.UTF16PtrFromString(name)
+ if err != nil {
+ return nil, &DNSError{Name: name, Err: err.Error()}
+ }
+
+ dnsConf := getSystemDNSConfig()
+ start := time.Now()
+
+ var e error
+ for i := 0; i < dnsConf.attempts; i++ {
+ e = syscall.GetAddrInfoW(name16p, nil, &hints, &result)
+ if e == nil || e != _WSATRY_AGAIN || time.Since(start) > dnsConf.timeout {
+ break
+ }
+ }
+ if e != nil {
+ err := winError("getaddrinfow", e)
+ dnsError := &DNSError{Err: err.Error(), Name: name}
+ if err == errNoSuchHost {
+ dnsError.IsNotFound = true
+ }
+ return nil, dnsError
+ }
+ defer syscall.FreeAddrInfoW(result)
+ addrs := make([]IPAddr, 0, 5)
+ for ; result != nil; result = result.Next {
+ addr := unsafe.Pointer(result.Addr)
+ switch result.Family {
+ case syscall.AF_INET:
+ a := (*syscall.RawSockaddrInet4)(addr).Addr
+ addrs = append(addrs, IPAddr{IP: copyIP(a[:])})
+ case syscall.AF_INET6:
+ a := (*syscall.RawSockaddrInet6)(addr).Addr
+ zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
+ addrs = append(addrs, IPAddr{IP: copyIP(a[:]), Zone: zone})
+ default:
+ return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
+ }
+ }
+ return addrs, nil
+ }
+
+ type ret struct {
+ addrs []IPAddr
+ err error
+ }
+
+ var ch chan ret
+ if ctx.Err() == nil {
+ ch = make(chan ret, 1)
+ go func() {
+ addr, err := getaddr()
+ ch <- ret{addrs: addr, err: err}
+ }()
+ }
+
+ select {
+ case r := <-ch:
+ return r.addrs, r.err
+ case <-ctx.Done():
+ // TODO(bradfitz,brainman): cancel the ongoing
+ // GetAddrInfoW? It would require conditionally using
+ // GetAddrInfoEx with lpOverlapped, which requires
+ // Windows 8 or newer. I guess we'll need oldLookupIP,
+ // newLookupIP, and newerLookUP.
+ //
+ // For now we just let it finish and write to the
+ // buffered channel.
+ return nil, &DNSError{
+ Name: name,
+ Err: ctx.Err().Error(),
+ IsTimeout: ctx.Err() == context.DeadlineExceeded,
+ }
+ }
+}
+
+func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
+ if r.preferGoOverWindows() {
+ return lookupPortMap(network, service)
+ }
+
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var stype int32
+ switch network {
+ case "tcp4", "tcp6":
+ stype = syscall.SOCK_STREAM
+ case "udp4", "udp6":
+ stype = syscall.SOCK_DGRAM
+ }
+ hints := syscall.AddrinfoW{
+ Family: syscall.AF_UNSPEC,
+ Socktype: stype,
+ Protocol: syscall.IPPROTO_IP,
+ }
+ var result *syscall.AddrinfoW
+ e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
+ if e != nil {
+ if port, err := lookupPortMap(network, service); err == nil {
+ return port, nil
+ }
+ err := winError("getaddrinfow", e)
+ dnsError := &DNSError{Err: err.Error(), Name: network + "/" + service}
+ if err == errNoSuchHost {
+ dnsError.IsNotFound = true
+ }
+ return 0, dnsError
+ }
+ defer syscall.FreeAddrInfoW(result)
+ if result == nil {
+ return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
+ }
+ addr := unsafe.Pointer(result.Addr)
+ switch result.Family {
+ case syscall.AF_INET:
+ a := (*syscall.RawSockaddrInet4)(addr)
+ return int(syscall.Ntohs(a.Port)), nil
+ case syscall.AF_INET6:
+ a := (*syscall.RawSockaddrInet6)(addr)
+ return int(syscall.Ntohs(a.Port)), nil
+ }
+ return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
+}
+
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
+ if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo {
+ return r.goLookupCNAME(ctx, name, order, conf)
+ }
+
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil)
+	// Windows returns DNS_INFO_NO_RECORDS if there are no CNAME records.
+ if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
+ // if there are no aliases, the canonical name is the input name
+ return absDomainName(name), nil
+ }
+ if e != nil {
+ return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), rec)
+ cname := windows.UTF16PtrToString(resolved)
+ return absDomainName(cname), nil
+}
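
From the caller's side the behavior above looks like this; a sketch, with the CNAME target taken from the test comments later in this patch:

package main

import (
	"fmt"
	"net"
)

func main() {
	// The result is fully qualified (trailing dot). If the name has no CNAME
	// record, the canonical name is the input name itself, matching the
	// DNS_INFO_NO_RECORDS branch above.
	cname, err := net.LookupCNAME("mail.golang.com")
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(cname) // e.g. "golang.org."
}
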
+
+func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+ if r.preferGoOverWindows() {
+ return r.goLookupSRV(ctx, service, proto, name)
+ }
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var target string
+ if service == "" && proto == "" {
+ target = name
+ } else {
+ target = "_" + service + "._" + proto + "." + name
+ }
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
+ if e != nil {
+ return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ srvs := make([]*SRV, 0, 10)
+ for _, p := range validRecs(rec, syscall.DNS_TYPE_SRV, target) {
+ v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
+ srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
+ }
+ byPriorityWeight(srvs).sort()
+ return absDomainName(target), srvs, nil
+}
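
A caller-side sketch of the same query; as in the code above, the SRV owner name has the _service._proto.name form, so this asks for "_ldap._tcp.google.com":

package main

import (
	"context"
	"fmt"
	"net"
)

func main() {
	cname, srvs, err := net.DefaultResolver.LookupSRV(context.Background(), "ldap", "tcp", "google.com")
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(cname)
	for _, srv := range srvs {
		fmt.Println(srv.Target, srv.Port, srv.Priority, srv.Weight)
	}
}
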
+
+func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
+ if r.preferGoOverWindows() {
+ return r.goLookupMX(ctx, name)
+ }
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
+ if e != nil {
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ mxs := make([]*MX, 0, 10)
+ for _, p := range validRecs(rec, syscall.DNS_TYPE_MX, name) {
+ v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
+ mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference})
+ }
+ byPref(mxs).sort()
+ return mxs, nil
+}
+
+func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
+ if r.preferGoOverWindows() {
+ return r.goLookupNS(ctx, name)
+ }
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
+ if e != nil {
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ nss := make([]*NS, 0, 10)
+ for _, p := range validRecs(rec, syscall.DNS_TYPE_NS, name) {
+ v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
+ nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
+ }
+ return nss, nil
+}
+
+func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
+ if r.preferGoOverWindows() {
+ return r.goLookupTXT(ctx, name)
+ }
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
+ if e != nil {
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ txts := make([]string, 0, 10)
+ for _, p := range validRecs(rec, syscall.DNS_TYPE_TEXT, name) {
+ d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
+ s := ""
+ for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] {
+ s += windows.UTF16PtrToString(v)
+ }
+ txts = append(txts, s)
+ }
+ return txts, nil
+}
+
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
+ if order, conf := systemConf().hostLookupOrder(r, ""); order != hostLookupCgo {
+ return r.goLookupPTR(ctx, addr, order, conf)
+ }
+
+ // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
+ acquireThread()
+ defer releaseThread()
+ arpa, err := reverseaddr(addr)
+ if err != nil {
+ return nil, err
+ }
+ var rec *syscall.DNSRecord
+ e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
+ if e != nil {
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
+ }
+ defer syscall.DnsRecordListFree(rec, 1)
+
+ ptrs := make([]string, 0, 10)
+ for _, p := range validRecs(rec, syscall.DNS_TYPE_PTR, arpa) {
+ v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
+ ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host)))
+ }
+ return ptrs, nil
+}
+
+const dnsSectionMask = 0x0003
+
+// returns only results applicable to name and resolves CNAME entries.
+func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
+ cname := syscall.StringToUTF16Ptr(name)
+ if dnstype != syscall.DNS_TYPE_CNAME {
+ cname = resolveCNAME(cname, r)
+ }
+ rec := make([]*syscall.DNSRecord, 0, 10)
+ for p := r; p != nil; p = p.Next {
+		// When querying the local machine, DNS records are returned with the DNSREC_QUESTION flag instead of DNS_ANSWER.
+ if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer && p.Dw&dnsSectionMask != syscall.DnsSectionQuestion {
+ continue
+ }
+ if p.Type != dnstype {
+ continue
+ }
+ if !syscall.DnsNameCompare(cname, p.Name) {
+ continue
+ }
+ rec = append(rec, p)
+ }
+ return rec
+}
+
+// returns the last CNAME in chain.
+func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
+ // limit cname resolving to 10 in case of an infinite CNAME loop
+Cname:
+ for cnameloop := 0; cnameloop < 10; cnameloop++ {
+ for p := r; p != nil; p = p.Next {
+ if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
+ continue
+ }
+ if p.Type != syscall.DNS_TYPE_CNAME {
+ continue
+ }
+ if !syscall.DnsNameCompare(name, p.Name) {
+ continue
+ }
+ name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
+ continue Cname
+ }
+ break
+ }
+ return name
+}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/src/net/lookup_windows_test.go b/src/net/lookup_windows_test.go
new file mode 100644
index 0000000..c618a05
--- /dev/null
+++ b/src/net/lookup_windows_test.go
@@ -0,0 +1,340 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "os/exec"
+ "reflect"
+ "regexp"
+ "sort"
+ "strings"
+ "syscall"
+ "testing"
+)
+
+var nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
+var lookupTestIPs = []string{"8.8.8.8", "1.1.1.1"}
+
+func toJson(v any) string {
+ data, _ := json.Marshal(v)
+ return string(data)
+}
+
+func testLookup(t *testing.T, fn func(*testing.T, *Resolver, string)) {
+ for _, def := range []bool{true, false} {
+ def := def
+ for _, server := range nslookupTestServers {
+ server := server
+ var name string
+ if def {
+ name = "default/"
+ } else {
+ name = "go/"
+ }
+ t.Run(name+server, func(t *testing.T) {
+ t.Parallel()
+ r := DefaultResolver
+ if !def {
+ r = &Resolver{PreferGo: true}
+ }
+ fn(t, r, server)
+ })
+ }
+ }
+}
+
+func TestNSLookupMX(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ testLookup(t, func(t *testing.T, r *Resolver, server string) {
+ mx, err := r.LookupMX(context.Background(), server)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(mx) == 0 {
+ t.Fatal("no results")
+ }
+ expected, err := nslookupMX(server)
+ if err != nil {
+ t.Skipf("skipping failed nslookup %s test: %s", server, err)
+ }
+ sort.Sort(byPrefAndHost(expected))
+ sort.Sort(byPrefAndHost(mx))
+ if !reflect.DeepEqual(expected, mx) {
+ t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
+ }
+ })
+}
+
+func TestNSLookupCNAME(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ testLookup(t, func(t *testing.T, r *Resolver, server string) {
+ cname, err := r.LookupCNAME(context.Background(), server)
+ if err != nil {
+ t.Fatalf("failed %s: %s", server, err)
+ }
+ if cname == "" {
+ t.Fatalf("no result %s", server)
+ }
+ expected, err := nslookupCNAME(server)
+ if err != nil {
+ t.Skipf("skipping failed nslookup %s test: %s", server, err)
+ }
+ if expected != cname {
+ t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
+ }
+ })
+}
+
+func TestNSLookupNS(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ testLookup(t, func(t *testing.T, r *Resolver, server string) {
+ ns, err := r.LookupNS(context.Background(), server)
+ if err != nil {
+ t.Fatalf("failed %s: %s", server, err)
+ }
+ if len(ns) == 0 {
+ t.Fatal("no results")
+ }
+ expected, err := nslookupNS(server)
+ if err != nil {
+ t.Skipf("skipping failed nslookup %s test: %s", server, err)
+ }
+ sort.Sort(byHost(expected))
+ sort.Sort(byHost(ns))
+ if !reflect.DeepEqual(expected, ns) {
+			t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(ns))
+ }
+ })
+}
+
+func TestNSLookupTXT(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ testLookup(t, func(t *testing.T, r *Resolver, server string) {
+ txt, err := r.LookupTXT(context.Background(), server)
+ if err != nil {
+ t.Fatalf("failed %s: %s", server, err)
+ }
+ if len(txt) == 0 {
+ t.Fatalf("no results")
+ }
+ expected, err := nslookupTXT(server)
+ if err != nil {
+ t.Skipf("skipping failed nslookup %s test: %s", server, err)
+ }
+ sort.Strings(expected)
+ sort.Strings(txt)
+ if !reflect.DeepEqual(expected, txt) {
+ t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
+ }
+ })
+}
+
+func TestLookupLocalPTR(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ addr, err := localIP()
+ if err != nil {
+ t.Errorf("failed to get local ip: %s", err)
+ }
+ names, err := LookupAddr(addr.String())
+ if err != nil {
+ t.Errorf("failed %s: %s", addr, err)
+ }
+ if len(names) == 0 {
+ t.Errorf("no results")
+ }
+ expected, err := lookupPTR(addr.String())
+ if err != nil {
+ t.Skipf("skipping failed lookup %s test: %s", addr.String(), err)
+ }
+ sort.Strings(expected)
+ sort.Strings(names)
+ if !reflect.DeepEqual(expected, names) {
+ t.Errorf("different results %s:\texp:%v\tgot:%v", addr, toJson(expected), toJson(names))
+ }
+}
+
+func TestLookupPTR(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ for _, addr := range lookupTestIPs {
+ names, err := LookupAddr(addr)
+ if err != nil {
+ // The DNSError type stores the error as a string, so it cannot wrap the
+ // original error code and we cannot check for it here. However, we can at
+ // least use its error string to identify the correct localized text for
+ // the error to skip.
+ var DNS_ERROR_RCODE_SERVER_FAILURE syscall.Errno = 9002
+ if strings.HasSuffix(err.Error(), DNS_ERROR_RCODE_SERVER_FAILURE.Error()) {
+ testenv.SkipFlaky(t, 38111)
+ }
+ t.Errorf("failed %s: %s", addr, err)
+ }
+ if len(names) == 0 {
+ t.Errorf("no results")
+ }
+ expected, err := lookupPTR(addr)
+ if err != nil {
+ t.Logf("skipping failed lookup %s test: %s", addr, err)
+ continue
+ }
+ sort.Strings(expected)
+ sort.Strings(names)
+ if !reflect.DeepEqual(expected, names) {
+ t.Errorf("different results %s:\texp:%v\tgot:%v", addr, toJson(expected), toJson(names))
+ }
+ }
+}
+
+type byPrefAndHost []*MX
+
+func (s byPrefAndHost) Len() int { return len(s) }
+func (s byPrefAndHost) Less(i, j int) bool {
+ if s[i].Pref != s[j].Pref {
+ return s[i].Pref < s[j].Pref
+ }
+ return s[i].Host < s[j].Host
+}
+func (s byPrefAndHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+type byHost []*NS
+
+func (s byHost) Len() int { return len(s) }
+func (s byHost) Less(i, j int) bool { return s[i].Host < s[j].Host }
+func (s byHost) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func nslookup(qtype, name string) (string, error) {
+ var out strings.Builder
+ var err strings.Builder
+ cmd := exec.Command("nslookup", "-querytype="+qtype, name)
+ cmd.Stdout = &out
+ cmd.Stderr = &err
+ if err := cmd.Run(); err != nil {
+ return "", err
+ }
+ r := strings.ReplaceAll(out.String(), "\r\n", "\n")
+ // nslookup stderr output contains also debug information such as
+	// nslookup's stderr output also contains debug information such as
+	// "Non-authoritative answer", and it doesn't return a useful error code,
+	// so failures are detected by looking for "can't find" in stderr.
+ return r, errors.New(err.String())
+ }
+ return r, nil
+}
+
+func nslookupMX(name string) (mx []*MX, err error) {
+ var r string
+ if r, err = nslookup("mx", name); err != nil {
+ return
+ }
+ mx = make([]*MX, 0, 10)
+ // linux nslookup syntax
+ // golang.org mail exchanger = 2 alt1.aspmx.l.google.com.
+ rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ pref, _, _ := dtoi(ans[2])
+ mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
+ }
+ // windows nslookup syntax
+ // gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
+ rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ pref, _, _ := dtoi(ans[2])
+ mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
+ }
+ return
+}
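
To see the Linux-style regular expression at work, here is a tiny standalone sketch run against the sample line from the comment above:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	line := "golang.org mail exchanger = 2 alt1.aspmx.l.google.com."
	rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
	m := rx.FindStringSubmatch(line)
	if m == nil {
		fmt.Println("no match")
		return
	}
	fmt.Println(m[2], m[3]) // 2 alt1.aspmx.l.google.com.
}
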
+
+func nslookupNS(name string) (ns []*NS, err error) {
+ var r string
+ if r, err = nslookup("ns", name); err != nil {
+ return
+ }
+ ns = make([]*NS, 0, 10)
+ // golang.org nameserver = ns1.google.com.
+ rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ ns = append(ns, &NS{absDomainName(ans[2])})
+ }
+ return
+}
+
+func nslookupCNAME(name string) (cname string, err error) {
+ var r string
+ if r, err = nslookup("cname", name); err != nil {
+ return
+ }
+ // mail.golang.com canonical name = golang.org.
+ rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`)
+ // assumes the last CNAME is the correct one
+ last := name
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ last = ans[2]
+ }
+ return absDomainName(last), nil
+}
+
+func nslookupTXT(name string) (txt []string, err error) {
+ var r string
+ if r, err = nslookup("txt", name); err != nil {
+ return
+ }
+ txt = make([]string, 0, 10)
+ // linux
+ // golang.org text = "v=spf1 redirect=_spf.google.com"
+
+ // windows
+ // golang.org text =
+ //
+ // "v=spf1 redirect=_spf.google.com"
+ rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`)
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ txt = append(txt, ans[2])
+ }
+ return
+}
+
+func ping(name string) (string, error) {
+ cmd := exec.Command("ping", "-n", "1", "-a", name)
+ stdoutStderr, err := cmd.CombinedOutput()
+ if err != nil {
+ return "", fmt.Errorf("%v: %v", err, string(stdoutStderr))
+ }
+ r := strings.ReplaceAll(string(stdoutStderr), "\r\n", "\n")
+ return r, nil
+}
+
+func lookupPTR(name string) (ptr []string, err error) {
+ var r string
+ if r, err = ping(name); err != nil {
+ return
+ }
+ ptr = make([]string, 0, 10)
+ rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`)
+ for _, ans := range rx.FindAllStringSubmatch(r, -1) {
+ ptr = append(ptr, absDomainName(ans[1]))
+ }
+ return
+}
+
+func localIP() (ip IP, err error) {
+ conn, err := Dial("udp", "golang.org:80")
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ localAddr := conn.LocalAddr().(*UDPAddr)
+
+ return localAddr.IP, nil
+}
diff --git a/src/net/mac.go b/src/net/mac.go
new file mode 100644
index 0000000..53d5b2d
--- /dev/null
+++ b/src/net/mac.go
@@ -0,0 +1,86 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+const hexDigit = "0123456789abcdef"
+
+// A HardwareAddr represents a physical hardware address.
+type HardwareAddr []byte
+
+func (a HardwareAddr) String() string {
+ if len(a) == 0 {
+ return ""
+ }
+ buf := make([]byte, 0, len(a)*3-1)
+ for i, b := range a {
+ if i > 0 {
+ buf = append(buf, ':')
+ }
+ buf = append(buf, hexDigit[b>>4])
+ buf = append(buf, hexDigit[b&0xF])
+ }
+ return string(buf)
+}
+
+// ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, EUI-64, or a 20-octet
+// IP over InfiniBand link-layer address using one of the following formats:
+//
+// 00:00:5e:00:53:01
+// 02:00:5e:10:00:00:00:01
+// 00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01
+// 00-00-5e-00-53-01
+// 02-00-5e-10-00-00-00-01
+// 00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01
+// 0000.5e00.5301
+// 0200.5e10.0000.0001
+// 0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001
+func ParseMAC(s string) (hw HardwareAddr, err error) {
+ if len(s) < 14 {
+ goto error
+ }
+
+ if s[2] == ':' || s[2] == '-' {
+ if (len(s)+1)%3 != 0 {
+ goto error
+ }
+ n := (len(s) + 1) / 3
+ if n != 6 && n != 8 && n != 20 {
+ goto error
+ }
+ hw = make(HardwareAddr, n)
+ for x, i := 0, 0; i < n; i++ {
+ var ok bool
+ if hw[i], ok = xtoi2(s[x:], s[2]); !ok {
+ goto error
+ }
+ x += 3
+ }
+ } else if s[4] == '.' {
+ if (len(s)+1)%5 != 0 {
+ goto error
+ }
+ n := 2 * (len(s) + 1) / 5
+ if n != 6 && n != 8 && n != 20 {
+ goto error
+ }
+ hw = make(HardwareAddr, n)
+ for x, i := 0, 0; i < n; i += 2 {
+ var ok bool
+ if hw[i], ok = xtoi2(s[x:x+2], 0); !ok {
+ goto error
+ }
+ if hw[i+1], ok = xtoi2(s[x+2:], s[4]); !ok {
+ goto error
+ }
+ x += 5
+ }
+ } else {
+ goto error
+ }
+ return hw, nil
+
+error:
+ return nil, &AddrError{Err: "invalid MAC address", Addr: s}
+}
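
A short usage sketch; the input is one of the dotted forms listed in the doc comment, and the round-trip matches the tests that follow:

package main

import (
	"fmt"
	"net"
)

func main() {
	hw, err := net.ParseMAC("0000.5e00.5301")
	if err != nil {
		fmt.Println("parse failed:", err)
		return
	}
	fmt.Println(hw)                 // 00:00:5e:00:53:01 (String always uses colons)
	fmt.Printf("% x\n", []byte(hw)) // 00 00 5e 00 53 01
}
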
diff --git a/src/net/mac_test.go b/src/net/mac_test.go
new file mode 100644
index 0000000..cad884f
--- /dev/null
+++ b/src/net/mac_test.go
@@ -0,0 +1,109 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+var parseMACTests = []struct {
+ in string
+ out HardwareAddr
+ err string
+}{
+ // See RFC 7042, Section 2.1.1.
+ {"00:00:5e:00:53:01", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""},
+ {"00-00-5e-00-53-01", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""},
+ {"0000.5e00.5301", HardwareAddr{0x00, 0x00, 0x5e, 0x00, 0x53, 0x01}, ""},
+
+ // See RFC 7042, Section 2.2.2.
+ {"02:00:5e:10:00:00:00:01", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""},
+ {"02-00-5e-10-00-00-00-01", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""},
+ {"0200.5e10.0000.0001", HardwareAddr{0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01}, ""},
+
+ // See RFC 4391, Section 9.1.1.
+ {
+ "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
+ HardwareAddr{
+ 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01,
+ },
+ "",
+ },
+ {
+ "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
+ HardwareAddr{
+ 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01,
+ },
+ "",
+ },
+ {
+ "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
+ HardwareAddr{
+ 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x5e, 0x10, 0x00, 0x00, 0x00, 0x01,
+ },
+ "",
+ },
+
+ {"ab:cd:ef:AB:CD:EF", HardwareAddr{0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef}, ""},
+ {"ab:cd:ef:AB:CD:EF:ab:cd", HardwareAddr{0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd}, ""},
+ {
+ "ab:cd:ef:AB:CD:EF:ab:cd:ef:AB:CD:EF:ab:cd:ef:AB:CD:EF:ab:cd",
+ HardwareAddr{
+ 0xab, 0xcd, 0xef, 0xab,
+ 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef,
+ 0xab, 0xcd, 0xef, 0xab, 0xcd, 0xef, 0xab, 0xcd,
+ },
+ "",
+ },
+
+ {"01.02.03.04.05.06", nil, "invalid MAC address"},
+ {"01:02:03:04:05:06:", nil, "invalid MAC address"},
+ {"x1:02:03:04:05:06", nil, "invalid MAC address"},
+ {"01002:03:04:05:06", nil, "invalid MAC address"},
+ {"01:02003:04:05:06", nil, "invalid MAC address"},
+ {"01:02:03004:05:06", nil, "invalid MAC address"},
+ {"01:02:03:04005:06", nil, "invalid MAC address"},
+ {"01:02:03:04:05006", nil, "invalid MAC address"},
+ {"01-02:03:04:05:06", nil, "invalid MAC address"},
+ {"01:02-03-04-05-06", nil, "invalid MAC address"},
+ {"0123:4567:89AF", nil, "invalid MAC address"},
+ {"0123-4567-89AF", nil, "invalid MAC address"},
+}
+
+func TestParseMAC(t *testing.T) {
+ match := func(err error, s string) bool {
+ if s == "" {
+ return err == nil
+ }
+ return err != nil && strings.Contains(err.Error(), s)
+ }
+
+ for i, tt := range parseMACTests {
+ out, err := ParseMAC(tt.in)
+ if !reflect.DeepEqual(out, tt.out) || !match(err, tt.err) {
+ t.Errorf("ParseMAC(%q) = %v, %v, want %v, %v", tt.in, out, err, tt.out, tt.err)
+ }
+ if tt.err == "" {
+ // Verify that serialization works too, and that it round-trips.
+ s := out.String()
+ out2, err := ParseMAC(s)
+ if err != nil {
+ t.Errorf("%d. ParseMAC(%q) = %v", i, s, err)
+ continue
+ }
+ if !reflect.DeepEqual(out2, out) {
+ t.Errorf("%d. ParseMAC(%q) = %v, want %v", i, s, out2, out)
+ }
+ }
+ }
+}
diff --git a/src/net/mail/example_test.go b/src/net/mail/example_test.go
new file mode 100644
index 0000000..d325dc7
--- /dev/null
+++ b/src/net/mail/example_test.go
@@ -0,0 +1,77 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mail_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/mail"
+ "strings"
+)
+
+func ExampleParseAddressList() {
+ const list = "Alice <alice@example.com>, Bob <bob@example.com>, Eve <eve@example.com>"
+ emails, err := mail.ParseAddressList(list)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ for _, v := range emails {
+ fmt.Println(v.Name, v.Address)
+ }
+
+ // Output:
+ // Alice alice@example.com
+ // Bob bob@example.com
+ // Eve eve@example.com
+}
+
+func ExampleParseAddress() {
+ e, err := mail.ParseAddress("Alice <alice@example.com>")
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println(e.Name, e.Address)
+
+ // Output:
+ // Alice alice@example.com
+}
+
+func ExampleReadMessage() {
+ msg := `Date: Mon, 23 Jun 2015 11:40:36 -0400
+From: Gopher <from@example.com>
+To: Another Gopher <to@example.com>
+Subject: Gophers at Gophercon
+
+Message body
+`
+
+ r := strings.NewReader(msg)
+ m, err := mail.ReadMessage(r)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ header := m.Header
+ fmt.Println("Date:", header.Get("Date"))
+ fmt.Println("From:", header.Get("From"))
+ fmt.Println("To:", header.Get("To"))
+ fmt.Println("Subject:", header.Get("Subject"))
+
+ body, err := io.ReadAll(m.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s", body)
+
+ // Output:
+ // Date: Mon, 23 Jun 2015 11:40:36 -0400
+ // From: Gopher <from@example.com>
+ // To: Another Gopher <to@example.com>
+ // Subject: Gophers at Gophercon
+ // Message body
+}
diff --git a/src/net/mail/message.go b/src/net/mail/message.go
new file mode 100644
index 0000000..fc2a9e4
--- /dev/null
+++ b/src/net/mail/message.go
@@ -0,0 +1,915 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package mail implements parsing of mail messages.
+
+For the most part, this package follows the syntax as specified by RFC 5322 and
+extended by RFC 6532.
+Notable divergences:
+ - Obsolete address formats are not parsed, including addresses with
+ embedded route information.
+ - The full range of spacing (the CFWS syntax element) is not supported,
+ such as breaking addresses across lines.
+ - No unicode normalization is performed.
+ - The special characters ()[]:;@\, are allowed to appear unquoted in names.
+ - A leading From line is permitted, as in mbox format (RFC 4155).
+*/
+package mail
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "mime"
+ "net/textproto"
+ "strings"
+ "sync"
+ "time"
+ "unicode/utf8"
+)
+
+var debug = debugT(false)
+
+type debugT bool
+
+func (d debugT) Printf(format string, args ...any) {
+ if d {
+ log.Printf(format, args...)
+ }
+}
+
+// A Message represents a parsed mail message.
+type Message struct {
+ Header Header
+ Body io.Reader
+}
+
+// ReadMessage reads a message from r.
+// The headers are parsed, and the body of the message will be available
+// for reading from msg.Body.
+func ReadMessage(r io.Reader) (msg *Message, err error) {
+ tp := textproto.NewReader(bufio.NewReader(r))
+
+ hdr, err := readHeader(tp)
+ if err != nil && (err != io.EOF || len(hdr) == 0) {
+ return nil, err
+ }
+
+ return &Message{
+ Header: Header(hdr),
+ Body: tp.R,
+ }, nil
+}
+
+// readHeader reads the message headers from r.
+// This is like textproto.ReadMIMEHeader, but doesn't validate.
+// The fix for issue #53188 tightened up net/textproto to enforce
+// restrictions of RFC 7230.
+// This package implements RFC 5322, which does not have those restrictions.
+// This function copies the relevant code from net/textproto,
+// simplified for RFC 5322.
+func readHeader(r *textproto.Reader) (map[string][]string, error) {
+ m := make(map[string][]string)
+
+ // The first line cannot start with a leading space.
+ if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
+ line, err := r.ReadLine()
+ if err != nil {
+ return m, err
+ }
+ return m, errors.New("malformed initial line: " + line)
+ }
+
+ for {
+ kv, err := r.ReadContinuedLine()
+ if kv == "" {
+ return m, err
+ }
+
+ // Key ends at first colon.
+ k, v, ok := strings.Cut(kv, ":")
+ if !ok {
+ return m, errors.New("malformed header line: " + kv)
+ }
+ key := textproto.CanonicalMIMEHeaderKey(k)
+
+ // Permit empty key, because that is what we did in the past.
+ if key == "" {
+ continue
+ }
+
+ // Skip initial spaces in value.
+ value := strings.TrimLeft(v, " \t")
+
+ m[key] = append(m[key], value)
+
+ if err != nil {
+ return m, err
+ }
+ }
+}
+
+// Layouts suitable for passing to time.Parse.
+// These are tried in order.
+var (
+ dateLayoutsBuildOnce sync.Once
+ dateLayouts []string
+)
+
+func buildDateLayouts() {
+ // Generate layouts based on RFC 5322, section 3.3.
+
+ dows := [...]string{"", "Mon, "} // day-of-week
+ days := [...]string{"2", "02"} // day = 1*2DIGIT
+ years := [...]string{"2006", "06"} // year = 4*DIGIT / 2*DIGIT
+ seconds := [...]string{":05", ""} // second
+ // "-0700 (MST)" is not in RFC 5322, but is common.
+ zones := [...]string{"-0700", "MST", "UT"} // zone = (("+" / "-") 4DIGIT) / "UT" / "GMT" / ...
+
+ for _, dow := range dows {
+ for _, day := range days {
+ for _, year := range years {
+ for _, second := range seconds {
+ for _, zone := range zones {
+ s := dow + day + " Jan " + year + " 15:04" + second + " " + zone
+ dateLayouts = append(dateLayouts, s)
+ }
+ }
+ }
+ }
+ }
+}
+
+// ParseDate parses an RFC 5322 date string.
+func ParseDate(date string) (time.Time, error) {
+ dateLayoutsBuildOnce.Do(buildDateLayouts)
+ // CR and LF must match and are tolerated anywhere in the date field.
+ date = strings.ReplaceAll(date, "\r\n", "")
+ if strings.Contains(date, "\r") {
+ return time.Time{}, errors.New("mail: header has a CR without LF")
+ }
+ // Re-using some addrParser methods which support obsolete text, i.e. non-printable ASCII
+ p := addrParser{date, nil}
+ p.skipSpace()
+
+ // RFC 5322: zone = (FWS ( "+" / "-" ) 4DIGIT) / obs-zone
+ // zone length is always 5 chars unless obsolete (obs-zone)
+ if ind := strings.IndexAny(p.s, "+-"); ind != -1 && len(p.s) >= ind+5 {
+ date = p.s[:ind+5]
+ p.s = p.s[ind+5:]
+ } else {
+ ind := strings.Index(p.s, "T")
+ if ind == 0 {
+ // In this case we have the following date formats:
+ // * Thu, 20 Nov 1997 09:55:06 MDT
+ // * Thu, 20 Nov 1997 09:55:06 MDT (MDT)
+ // * Thu, 20 Nov 1997 09:55:06 MDT (This comment)
+ ind = strings.Index(p.s[1:], "T")
+ if ind != -1 {
+ ind++
+ }
+ }
+
+ if ind != -1 && len(p.s) >= ind+5 {
+ // The last letter T of the obsolete time zone is checked when no standard time zone is found.
+ // If T is misplaced, the date to parse is garbage.
+ date = p.s[:ind+1]
+ p.s = p.s[ind+1:]
+ }
+ }
+ if !p.skipCFWS() {
+ return time.Time{}, errors.New("mail: misformatted parenthetical comment")
+ }
+ for _, layout := range dateLayouts {
+ t, err := time.Parse(layout, date)
+ if err == nil {
+ return t, nil
+ }
+ }
+ return time.Time{}, errors.New("mail: header could not be parsed")
+}
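+
+// parseDateSketch is an illustrative usage sketch of ParseDate; it is not
+// referenced elsewhere. The date literal is the RFC 5322, Appendix A.1.1 example.
+func parseDateSketch() {
+ t, err := ParseDate("Fri, 21 Nov 1997 09:55:06 -0600")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(t.UTC()) // 1997-11-21 15:55:06 +0000 UTC
+}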
+
+// A Header represents the key-value pairs in a mail message header.
+type Header map[string][]string
+
+// Get gets the first value associated with the given key.
+// It is case insensitive; CanonicalMIMEHeaderKey is used
+// to canonicalize the provided key.
+// If there are no values associated with the key, Get returns "".
+// To access multiple values of a key, or to use non-canonical keys,
+// access the map directly.
+func (h Header) Get(key string) string {
+ return textproto.MIMEHeader(h).Get(key)
+}
+
+var ErrHeaderNotPresent = errors.New("mail: header not in message")
+
+// Date parses the Date header field.
+func (h Header) Date() (time.Time, error) {
+ hdr := h.Get("Date")
+ if hdr == "" {
+ return time.Time{}, ErrHeaderNotPresent
+ }
+ return ParseDate(hdr)
+}
+
+// AddressList parses the named header field as a list of addresses.
+func (h Header) AddressList(key string) ([]*Address, error) {
+ hdr := h.Get(key)
+ if hdr == "" {
+ return nil, ErrHeaderNotPresent
+ }
+ return ParseAddressList(hdr)
+}
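+
+// headerSketch is an illustrative sketch of the Header accessors; it is not
+// referenced elsewhere and the header values are hypothetical.
+func headerSketch() {
+ h := Header{
+ "From": {"Gopher <gopher@example.com>"},
+ "Date": {"Fri, 21 Nov 1997 09:55:06 -0600"},
+ }
+ fmt.Println(h.Get("from")) // the key is canonicalized, so this prints the From value
+ if date, err := h.Date(); err == nil {
+ fmt.Println(date.UTC()) // 1997-11-21 15:55:06 +0000 UTC
+ }
+ if addrs, err := h.AddressList("From"); err == nil {
+ fmt.Println(addrs[0].Name, addrs[0].Address) // Gopher gopher@example.com
+ }
+}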
+
+// Address represents a single mail address.
+// An address such as "Barry Gibbs <bg@example.com>" is represented
+// as Address{Name: "Barry Gibbs", Address: "bg@example.com"}.
+type Address struct {
+ Name string // Proper name; may be empty.
+ Address string // user@domain
+}
+
+// ParseAddress parses a single RFC 5322 address, e.g. "Barry Gibbs <bg@example.com>"
+func ParseAddress(address string) (*Address, error) {
+ return (&addrParser{s: address}).parseSingleAddress()
+}
+
+// ParseAddressList parses the given string as a list of addresses.
+func ParseAddressList(list string) ([]*Address, error) {
+ return (&addrParser{s: list}).parseAddressList()
+}
+
+// An AddressParser is an RFC 5322 address parser.
+type AddressParser struct {
+ // WordDecoder optionally specifies a decoder for RFC 2047 encoded-words.
+ WordDecoder *mime.WordDecoder
+}
+
+// Parse parses a single RFC 5322 address of the
+// form "Gogh Fir <gf@example.com>" or "foo@example.com".
+func (p *AddressParser) Parse(address string) (*Address, error) {
+ return (&addrParser{s: address, dec: p.WordDecoder}).parseSingleAddress()
+}
+
+// ParseList parses the given string as a list of comma-separated addresses
+// of the form "Gogh Fir <gf@example.com>" or "foo@example.com".
+func (p *AddressParser) ParseList(list string) ([]*Address, error) {
+ return (&addrParser{s: list, dec: p.WordDecoder}).parseAddressList()
+}
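+
+// addressParserSketch is an illustrative sketch of AddressParser; it is not
+// referenced elsewhere and the addresses are hypothetical. A custom
+// CharsetReader on the WordDecoder is typically only needed for charsets
+// other than utf-8, us-ascii and iso-8859-1.
+func addressParserSketch() {
+ p := AddressParser{WordDecoder: &mime.WordDecoder{}}
+ addrs, err := p.ParseList("=?utf-8?q?G=C3=B6pher?= <gopher@example.com>, jdoe@example.org")
+ if err != nil {
+ log.Fatal(err)
+ }
+ for _, a := range addrs {
+ fmt.Println(a.Name, a.Address)
+ }
+}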
+
+// String formats the address as a valid RFC 5322 address.
+// If the address's name contains non-ASCII characters,
+// the name will be rendered according to RFC 2047.
+func (a *Address) String() string {
+ // Format address local@domain
+ at := strings.LastIndex(a.Address, "@")
+ var local, domain string
+ if at < 0 {
+ // This is a malformed address ("@" is required in addr-spec);
+ // treat the whole address as local-part.
+ local = a.Address
+ } else {
+ local, domain = a.Address[:at], a.Address[at+1:]
+ }
+
+ // Add quotes if needed
+ quoteLocal := false
+ for i, r := range local {
+ if isAtext(r, false) {
+ continue
+ }
+ if r == '.' {
+ // Dots are okay if they are surrounded by atext.
+ // We only need to check that the dot is not at the start
+ // or end of the local part and that the previous byte is not a dot.
+ if i > 0 && local[i-1] != '.' && i < len(local)-1 {
+ continue
+ }
+ }
+ quoteLocal = true
+ break
+ }
+ if quoteLocal {
+ local = quoteString(local)
+ }
+
+ s := "<" + local + "@" + domain + ">"
+
+ if a.Name == "" {
+ return s
+ }
+
+ // If every character is printable ASCII, quoting is simple.
+ allPrintable := true
+ for _, r := range a.Name {
+ // isWSP here should actually be isFWS,
+ // but we don't support folding yet.
+ if !isVchar(r) && !isWSP(r) || isMultibyte(r) {
+ allPrintable = false
+ break
+ }
+ }
+ if allPrintable {
+ return quoteString(a.Name) + " " + s
+ }
+
+ // Text in an encoded-word in a display-name must not contain certain
+ // characters like quotes or parentheses (see RFC 2047 section 5.3).
+ // When this is the case, encode the name using base64 encoding.
+ if strings.ContainsAny(a.Name, "\"#$%&'(),.:;<>@[]^`{|}~") {
+ return mime.BEncoding.Encode("utf-8", a.Name) + " " + s
+ }
+ return mime.QEncoding.Encode("utf-8", a.Name) + " " + s
+}
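+
+// addressStringSketch is an illustrative sketch of Address.String; it is not
+// referenced elsewhere and the addresses are hypothetical. The expected
+// outputs match cases exercised in TestAddressString.
+func addressStringSketch() {
+ fmt.Println((&Address{Name: "Bob Jane", Address: "bob@example.com"}).String())
+ // "Bob Jane" <bob@example.com>
+ fmt.Println((&Address{Name: "Böb", Address: "bob@example.com"}).String())
+ // =?utf-8?q?B=C3=B6b?= <bob@example.com>
+}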
+
+type addrParser struct {
+ s string
+ dec *mime.WordDecoder // may be nil
+}
+
+func (p *addrParser) parseAddressList() ([]*Address, error) {
+ var list []*Address
+ for {
+ p.skipSpace()
+
+ // allow skipping empty entries (RFC5322 obs-addr-list)
+ if p.consume(',') {
+ continue
+ }
+
+ addrs, err := p.parseAddress(true)
+ if err != nil {
+ return nil, err
+ }
+ list = append(list, addrs...)
+
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
+ if p.empty() {
+ break
+ }
+ if p.peek() != ',' {
+ return nil, errors.New("mail: expected comma")
+ }
+
+ // Skip empty entries for obs-addr-list.
+ for p.consume(',') {
+ p.skipSpace()
+ }
+ if p.empty() {
+ break
+ }
+ }
+ return list, nil
+}
+
+func (p *addrParser) parseSingleAddress() (*Address, error) {
+ addrs, err := p.parseAddress(true)
+ if err != nil {
+ return nil, err
+ }
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
+ if !p.empty() {
+ return nil, fmt.Errorf("mail: expected single address, got %q", p.s)
+ }
+ if len(addrs) == 0 {
+ return nil, errors.New("mail: empty group")
+ }
+ if len(addrs) > 1 {
+ return nil, errors.New("mail: group with multiple addresses")
+ }
+ return addrs[0], nil
+}
+
+// parseAddress parses a single RFC 5322 address at the start of p.
+func (p *addrParser) parseAddress(handleGroup bool) ([]*Address, error) {
+ debug.Printf("parseAddress: %q", p.s)
+ p.skipSpace()
+ if p.empty() {
+ return nil, errors.New("mail: no address")
+ }
+
+ // address = mailbox / group
+ // mailbox = name-addr / addr-spec
+ // group = display-name ":" [group-list] ";" [CFWS]
+
+ // addr-spec has a more restricted grammar than name-addr,
+ // so try parsing it first, and fall back to name-addr.
+ // TODO(dsymonds): Is this really correct?
+ spec, err := p.consumeAddrSpec()
+ if err == nil {
+ var displayName string
+ p.skipSpace()
+ if !p.empty() && p.peek() == '(' {
+ displayName, err = p.consumeDisplayNameComment()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return []*Address{{
+ Name: displayName,
+ Address: spec,
+ }}, err
+ }
+ debug.Printf("parseAddress: not an addr-spec: %v", err)
+ debug.Printf("parseAddress: state is now %q", p.s)
+
+ // display-name
+ var displayName string
+ if p.peek() != '<' {
+ displayName, err = p.consumePhrase()
+ if err != nil {
+ return nil, err
+ }
+ }
+ debug.Printf("parseAddress: displayName=%q", displayName)
+
+ p.skipSpace()
+ if handleGroup {
+ if p.consume(':') {
+ return p.consumeGroupList()
+ }
+ }
+ // angle-addr = "<" addr-spec ">"
+ if !p.consume('<') {
+ atext := true
+ for _, r := range displayName {
+ if !isAtext(r, true) {
+ atext = false
+ break
+ }
+ }
+ if atext {
+ // The input is like "foo.bar"; it's possible the input
+ // meant to be "foo.bar@domain", or "foo.bar <...>".
+ return nil, errors.New("mail: missing '@' or angle-addr")
+ }
+ // The input is like "Full Name", which couldn't possibly be a
+ // valid email address if followed by "@domain"; the input
+ // likely meant to be "Full Name <...>".
+ return nil, errors.New("mail: no angle-addr")
+ }
+ spec, err = p.consumeAddrSpec()
+ if err != nil {
+ return nil, err
+ }
+ if !p.consume('>') {
+ return nil, errors.New("mail: unclosed angle-addr")
+ }
+ debug.Printf("parseAddress: spec=%q", spec)
+
+ return []*Address{{
+ Name: displayName,
+ Address: spec,
+ }}, nil
+}
+
+func (p *addrParser) consumeGroupList() ([]*Address, error) {
+ var group []*Address
+ // handle empty group.
+ p.skipSpace()
+ if p.consume(';') {
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
+ return group, nil
+ }
+
+ for {
+ p.skipSpace()
+ // embedded groups not allowed.
+ addrs, err := p.parseAddress(false)
+ if err != nil {
+ return nil, err
+ }
+ group = append(group, addrs...)
+
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
+ if p.consume(';') {
+ if !p.skipCFWS() {
+ return nil, errors.New("mail: misformatted parenthetical comment")
+ }
+ break
+ }
+ if !p.consume(',') {
+ return nil, errors.New("mail: expected comma")
+ }
+ }
+ return group, nil
+}
+
+// consumeAddrSpec parses a single RFC 5322 addr-spec at the start of p.
+func (p *addrParser) consumeAddrSpec() (spec string, err error) {
+ debug.Printf("consumeAddrSpec: %q", p.s)
+
+ orig := *p
+ defer func() {
+ if err != nil {
+ *p = orig
+ }
+ }()
+
+ // local-part = dot-atom / quoted-string
+ var localPart string
+ p.skipSpace()
+ if p.empty() {
+ return "", errors.New("mail: no addr-spec")
+ }
+ if p.peek() == '"' {
+ // quoted-string
+ debug.Printf("consumeAddrSpec: parsing quoted-string")
+ localPart, err = p.consumeQuotedString()
+ if localPart == "" {
+ err = errors.New("mail: empty quoted string in addr-spec")
+ }
+ } else {
+ // dot-atom
+ debug.Printf("consumeAddrSpec: parsing dot-atom")
+ localPart, err = p.consumeAtom(true, false)
+ }
+ if err != nil {
+ debug.Printf("consumeAddrSpec: failed: %v", err)
+ return "", err
+ }
+
+ if !p.consume('@') {
+ return "", errors.New("mail: missing @ in addr-spec")
+ }
+
+ // domain = dot-atom / domain-literal
+ var domain string
+ p.skipSpace()
+ if p.empty() {
+ return "", errors.New("mail: no domain in addr-spec")
+ }
+ // TODO(dsymonds): Handle domain-literal
+ domain, err = p.consumeAtom(true, false)
+ if err != nil {
+ return "", err
+ }
+
+ return localPart + "@" + domain, nil
+}
+
+// consumePhrase parses the RFC 5322 phrase at the start of p.
+func (p *addrParser) consumePhrase() (phrase string, err error) {
+ debug.Printf("consumePhrase: [%s]", p.s)
+ // phrase = 1*word
+ var words []string
+ var isPrevEncoded bool
+ for {
+ // obs-phrase allows CFWS after one word
+ if len(words) > 0 {
+ if !p.skipCFWS() {
+ return "", errors.New("mail: misformatted parenthetical comment")
+ }
+ }
+ // word = atom / quoted-string
+ var word string
+ p.skipSpace()
+ if p.empty() {
+ break
+ }
+ isEncoded := false
+ if p.peek() == '"' {
+ // quoted-string
+ word, err = p.consumeQuotedString()
+ } else {
+ // atom
+ // We actually parse dot-atom here to be more permissive
+ // than what RFC 5322 specifies.
+ word, err = p.consumeAtom(true, true)
+ if err == nil {
+ word, isEncoded, err = p.decodeRFC2047Word(word)
+ }
+ }
+
+ if err != nil {
+ break
+ }
+ debug.Printf("consumePhrase: consumed %q", word)
+ if isPrevEncoded && isEncoded {
+ words[len(words)-1] += word
+ } else {
+ words = append(words, word)
+ }
+ isPrevEncoded = isEncoded
+ }
+ // Ignore any error if we got at least one word.
+ if err != nil && len(words) == 0 {
+ debug.Printf("consumePhrase: hit err: %v", err)
+ return "", fmt.Errorf("mail: missing word in phrase: %v", err)
+ }
+ phrase = strings.Join(words, " ")
+ return phrase, nil
+}
+
+// consumeQuotedString parses the quoted string at the start of p.
+func (p *addrParser) consumeQuotedString() (qs string, err error) {
+ // Assume first byte is '"'.
+ i := 1
+ qsb := make([]rune, 0, 10)
+
+ escaped := false
+
+Loop:
+ for {
+ r, size := utf8.DecodeRuneInString(p.s[i:])
+
+ switch {
+ case size == 0:
+ return "", errors.New("mail: unclosed quoted-string")
+
+ case size == 1 && r == utf8.RuneError:
+ return "", fmt.Errorf("mail: invalid utf-8 in quoted-string: %q", p.s)
+
+ case escaped:
+ // quoted-pair = ("\" (VCHAR / WSP))
+
+ if !isVchar(r) && !isWSP(r) {
+ return "", fmt.Errorf("mail: bad character in quoted-string: %q", r)
+ }
+
+ qsb = append(qsb, r)
+ escaped = false
+
+ case isQtext(r) || isWSP(r):
+ // qtext (printable US-ASCII excluding " and \), or
+ // FWS (almost; we're ignoring CRLF)
+ qsb = append(qsb, r)
+
+ case r == '"':
+ break Loop
+
+ case r == '\\':
+ escaped = true
+
+ default:
+ return "", fmt.Errorf("mail: bad character in quoted-string: %q", r)
+
+ }
+
+ i += size
+ }
+ p.s = p.s[i+1:]
+ return string(qsb), nil
+}
+
+// consumeAtom parses an RFC 5322 atom at the start of p.
+// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead.
+// If permissive is true, consumeAtom will not fail on:
+// - leading/trailing/double dots in the atom (see golang.org/issue/4938)
+func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) {
+ i := 0
+
+Loop:
+ for {
+ r, size := utf8.DecodeRuneInString(p.s[i:])
+ switch {
+ case size == 1 && r == utf8.RuneError:
+ return "", fmt.Errorf("mail: invalid utf-8 in address: %q", p.s)
+
+ case size == 0 || !isAtext(r, dot):
+ break Loop
+
+ default:
+ i += size
+
+ }
+ }
+
+ if i == 0 {
+ return "", errors.New("mail: invalid string")
+ }
+ atom, p.s = p.s[:i], p.s[i:]
+ if !permissive {
+ if strings.HasPrefix(atom, ".") {
+ return "", errors.New("mail: leading dot in atom")
+ }
+ if strings.Contains(atom, "..") {
+ return "", errors.New("mail: double dot in atom")
+ }
+ if strings.HasSuffix(atom, ".") {
+ return "", errors.New("mail: trailing dot in atom")
+ }
+ }
+ return atom, nil
+}
+
+func (p *addrParser) consumeDisplayNameComment() (string, error) {
+ if !p.consume('(') {
+ return "", errors.New("mail: comment does not start with (")
+ }
+ comment, ok := p.consumeComment()
+ if !ok {
+ return "", errors.New("mail: misformatted parenthetical comment")
+ }
+
+ // TODO(stapelberg): parse quoted-string within comment
+ words := strings.FieldsFunc(comment, func(r rune) bool { return r == ' ' || r == '\t' })
+ for idx, word := range words {
+ decoded, isEncoded, err := p.decodeRFC2047Word(word)
+ if err != nil {
+ return "", err
+ }
+ if isEncoded {
+ words[idx] = decoded
+ }
+ }
+
+ return strings.Join(words, " "), nil
+}
+
+func (p *addrParser) consume(c byte) bool {
+ if p.empty() || p.peek() != c {
+ return false
+ }
+ p.s = p.s[1:]
+ return true
+}
+
+// skipSpace skips the leading space and tab characters.
+func (p *addrParser) skipSpace() {
+ p.s = strings.TrimLeft(p.s, " \t")
+}
+
+func (p *addrParser) peek() byte {
+ return p.s[0]
+}
+
+func (p *addrParser) empty() bool {
+ return p.len() == 0
+}
+
+func (p *addrParser) len() int {
+ return len(p.s)
+}
+
+// skipCFWS skips CFWS as defined in RFC5322.
+func (p *addrParser) skipCFWS() bool {
+ p.skipSpace()
+
+ for {
+ if !p.consume('(') {
+ break
+ }
+
+ if _, ok := p.consumeComment(); !ok {
+ return false
+ }
+
+ p.skipSpace()
+ }
+
+ return true
+}
+
+func (p *addrParser) consumeComment() (string, bool) {
+ // '(' already consumed.
+ depth := 1
+
+ var comment string
+ for {
+ if p.empty() || depth == 0 {
+ break
+ }
+
+ if p.peek() == '\\' && p.len() > 1 {
+ p.s = p.s[1:]
+ } else if p.peek() == '(' {
+ depth++
+ } else if p.peek() == ')' {
+ depth--
+ }
+ if depth > 0 {
+ comment += p.s[:1]
+ }
+ p.s = p.s[1:]
+ }
+
+ return comment, depth == 0
+}
+
+func (p *addrParser) decodeRFC2047Word(s string) (word string, isEncoded bool, err error) {
+ dec := p.dec
+ if dec == nil {
+ dec = &rfc2047Decoder
+ }
+
+ // Substitute our own CharsetReader function so that we can tell
+ // whether an error from the Decode method was due to the
+ // CharsetReader (meaning the charset is invalid).
+ // We used to look for the charsetError type in the error result,
+ // but that behaves badly with CharsetReaders other than the
+ // one in rfc2047Decoder.
+ adec := *dec
+ charsetReaderError := false
+ adec.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
+ if dec.CharsetReader == nil {
+ charsetReaderError = true
+ return nil, charsetError(charset)
+ }
+ r, err := dec.CharsetReader(charset, input)
+ if err != nil {
+ charsetReaderError = true
+ }
+ return r, err
+ }
+ word, err = adec.Decode(s)
+ if err == nil {
+ return word, true, nil
+ }
+
+ // If the error came from the character set reader
+ // (meaning the character set itself is invalid
+ // but the decoding worked fine until then),
+ // return the original text and the error,
+ // with isEncoded=true.
+ if charsetReaderError {
+ return s, true, err
+ }
+
+ // Ignore invalid RFC 2047 encoded-word errors.
+ return s, false, nil
+}
+
+var rfc2047Decoder = mime.WordDecoder{
+ CharsetReader: func(charset string, input io.Reader) (io.Reader, error) {
+ return nil, charsetError(charset)
+ },
+}
+
+type charsetError string
+
+func (e charsetError) Error() string {
+ return fmt.Sprintf("charset not supported: %q", string(e))
+}
+
+// isAtext reports whether r is an RFC 5322 atext character.
+// If dot is true, period is included.
+func isAtext(r rune, dot bool) bool {
+ switch r {
+ case '.':
+ return dot
+
+ // RFC 5322 3.2.3. specials
+ case '(', ')', '<', '>', '[', ']', ':', ';', '@', '\\', ',', '"':
+ return false
+ }
+ return isVchar(r)
+}
+
+// isQtext reports whether r is an RFC 5322 qtext character.
+func isQtext(r rune) bool {
+ // Printable US-ASCII, excluding backslash or quote.
+ if r == '\\' || r == '"' {
+ return false
+ }
+ return isVchar(r)
+}
+
+// quoteString renders a string as an RFC 5322 quoted-string.
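+// For example, quoteString(`dot.and space`) yields `"dot.and space"`,
+// while quoteString(`say "hi"`) yields `"say \"hi\""`.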
+func quoteString(s string) string {
+ var b strings.Builder
+ b.WriteByte('"')
+ for _, r := range s {
+ if isQtext(r) || isWSP(r) {
+ b.WriteRune(r)
+ } else if isVchar(r) {
+ b.WriteByte('\\')
+ b.WriteRune(r)
+ }
+ }
+ b.WriteByte('"')
+ return b.String()
+}
+
+// isVchar reports whether r is an RFC 5322 VCHAR character.
+func isVchar(r rune) bool {
+ // Visible (printing) characters.
+ return '!' <= r && r <= '~' || isMultibyte(r)
+}
+
+// isMultibyte reports whether r is a multi-byte UTF-8 character
+// as supported by RFC 6532.
+func isMultibyte(r rune) bool {
+ return r >= utf8.RuneSelf
+}
+
+// isWSP reports whether r is a WSP (white space).
+// WSP is a space or horizontal tab (RFC 5234 Appendix B).
+func isWSP(r rune) bool {
+ return r == ' ' || r == '\t'
+}
diff --git a/src/net/mail/message_test.go b/src/net/mail/message_test.go
new file mode 100644
index 0000000..1f2f62a
--- /dev/null
+++ b/src/net/mail/message_test.go
@@ -0,0 +1,1219 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mail
+
+import (
+ "bytes"
+ "io"
+ "mime"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+)
+
+var parseTests = []struct {
+ in string
+ header Header
+ body string
+}{
+ {
+ // RFC 5322, Appendix A.1.1
+ in: `From: John Doe <jdoe@machine.example>
+To: Mary Smith <mary@example.net>
+Subject: Saying Hello
+Date: Fri, 21 Nov 1997 09:55:06 -0600
+Message-ID: <1234@local.machine.example>
+
+This is a message just to say hello.
+So, "Hello".
+`,
+ header: Header{
+ "From": []string{"John Doe <jdoe@machine.example>"},
+ "To": []string{"Mary Smith <mary@example.net>"},
+ "Subject": []string{"Saying Hello"},
+ "Date": []string{"Fri, 21 Nov 1997 09:55:06 -0600"},
+ "Message-Id": []string{"<1234@local.machine.example>"},
+ },
+ body: "This is a message just to say hello.\nSo, \"Hello\".\n",
+ },
+ {
+ // RFC 5965, Appendix B.1, a part of the multipart message (a header-only sub message)
+ in: `Feedback-Type: abuse
+User-Agent: SomeGenerator/1.0
+Version: 1
+`,
+ header: Header{
+ "Feedback-Type": []string{"abuse"},
+ "User-Agent": []string{"SomeGenerator/1.0"},
+ "Version": []string{"1"},
+ },
+ body: "",
+ },
+ {
+ // RFC 5322 permits any printable ASCII character,
+ // except colon, in a header key. Issue #58862.
+ in: `From: iant@golang.org
+Custom/Header: v
+
+Body
+`,
+ header: Header{
+ "From": []string{"iant@golang.org"},
+ "Custom/Header": []string{"v"},
+ },
+ body: "Body\n",
+ },
+ {
+ // RFC 4155 mbox format. We've historically permitted this,
+ // so we continue to permit it. Issue #60332.
+ in: `From iant@golang.org Mon Jun 19 00:00:00 2023
+From: iant@golang.org
+
+Hello, gophers!
+`,
+ header: Header{
+ "From": []string{"iant@golang.org"},
+ "From iant@golang.org Mon Jun 19 00": []string{"00:00 2023"},
+ },
+ body: "Hello, gophers!\n",
+ },
+}
+
+func TestParsing(t *testing.T) {
+ for i, test := range parseTests {
+ msg, err := ReadMessage(bytes.NewBuffer([]byte(test.in)))
+ if err != nil {
+ t.Errorf("test #%d: Failed parsing message: %v", i, err)
+ continue
+ }
+ if !headerEq(msg.Header, test.header) {
+ t.Errorf("test #%d: Incorrectly parsed message header.\nGot:\n%+v\nWant:\n%+v",
+ i, msg.Header, test.header)
+ }
+ body, err := io.ReadAll(msg.Body)
+ if err != nil {
+ t.Errorf("test #%d: Failed reading body: %v", i, err)
+ continue
+ }
+ bodyStr := string(body)
+ if bodyStr != test.body {
+ t.Errorf("test #%d: Incorrectly parsed message body.\nGot:\n%+v\nWant:\n%+v",
+ i, bodyStr, test.body)
+ }
+ }
+}
+
+func headerEq(a, b Header) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for k, as := range a {
+ bs, ok := b[k]
+ if !ok {
+ return false
+ }
+ if !reflect.DeepEqual(as, bs) {
+ return false
+ }
+ }
+ return true
+}
+
+func TestDateParsing(t *testing.T) {
+ tests := []struct {
+ dateStr string
+ exp time.Time
+ }{
+ // RFC 5322, Appendix A.1.1
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ },
+ // RFC 5322, Appendix A.6.2
+ // Obsolete date.
+ {
+ "21 Nov 97 09:55:06 GMT",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("GMT", 0)),
+ },
+ // Commonly found format not specified by RFC 5322.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600 (MDT)",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ },
+ {
+ "Thu, 20 Nov 1997 09:55:06 -0600 (MDT)",
+ time.Date(1997, 11, 20, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ },
+ {
+ "Thu, 20 Nov 1997 09:55:06 GMT (GMT)",
+ time.Date(1997, 11, 20, 9, 55, 6, 0, time.UTC),
+ },
+ {
+ "Fri, 21 Nov 1997 09:55:06 +1300 (TOT)",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", +13*60*60)),
+ },
+ }
+ for _, test := range tests {
+ hdr := Header{
+ "Date": []string{test.dateStr},
+ }
+ date, err := hdr.Date()
+ if err != nil {
+ t.Errorf("Header(Date: %s).Date(): %v", test.dateStr, err)
+ } else if !date.Equal(test.exp) {
+ t.Errorf("Header(Date: %s).Date() = %+v, want %+v", test.dateStr, date, test.exp)
+ }
+
+ date, err = ParseDate(test.dateStr)
+ if err != nil {
+ t.Errorf("ParseDate(%s): %v", test.dateStr, err)
+ } else if !date.Equal(test.exp) {
+ t.Errorf("ParseDate(%s) = %+v, want %+v", test.dateStr, date, test.exp)
+ }
+ }
+}
+
+func TestDateParsingCFWS(t *testing.T) {
+ tests := []struct {
+ dateStr string
+ exp time.Time
+ valid bool
+ }{
+ // FWS-only. No date.
+ {
+ " ",
+ // nil is not allowed
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // FWS is allowed before optional day of week.
+ {
+ " Fri, 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ {
+ "21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ {
+ "Fri 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false, // missing ,
+ },
+ // FWS is allowed before day of month but HTAB fails.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ // FWS is allowed before and after year but HTAB fails.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ // FWS is allowed before zone but HTAB is not handled. Obsolete timezone is handled.
+ {
+ "Fri, 21 Nov 1997 09:55:06 CST",
+ time.Time{},
+ true,
+ },
+ // FWS is allowed after date and a CRLF is already replaced.
+ {
+ "Fri, 21 Nov 1997 09:55:06 CST (no leading FWS and a trailing CRLF) \r\n",
+ time.Time{},
+ true,
+ },
+ // A CFWS comment may contain spaces and non-US-ASCII (accented) characters under the obsolete syntax. No error.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600 (MDT and non-US-ASCII signs éèç )",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ // CFWS is allowed after zone including a nested comment.
+ // Trailing FWS is allowed.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -0600 \r\n (thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true,
+ },
+ // CRLF is incomplete and misplaced.
+ {
+ "Fri, 21 Nov 1997 \r 09:55:06 -0600 \r\n (thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // CRLF is complete but misplaced. No error is returned.
+ {
+ "Fri, 21 Nov 199\r\n7 09:55:06 -0600 \r\n (thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ true, // should be false in the strict interpretation of RFC 5322.
+ },
+ // Invalid ASCII in date.
+ {
+ "Fri, 21 Nov 1997 ù 09:55:06 -0600 \r\n (thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // CFWS chars () in date.
+ {
+ "Fri, 21 Nov () 1997 09:55:06 -0600 \r\n (thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // Timezone is invalid but T is found in comment.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -060 \r\n (Thisisa(valid)cfws) \t ",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // Date has no month.
+ {
+ "Fri, 21 1997 09:55:06 -0600",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // Invalid month: OCT instead of Oct.
+ {
+ "Fri, 21 OCT 1997 09:55:06 CST",
+ time.Time{},
+ false,
+ },
+ // A too short time zone.
+ {
+ "Fri, 21 Nov 1997 09:55:06 -060",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // A too short obsolete time zone.
+ {
+ "Fri, 21 1997 09:55:06 GT",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)),
+ false,
+ },
+ // Ensure that the presence of "T" in the date
+ // doesn't trip up ParseDate, as per issue 39260.
+ {
+ "Tue, 26 May 2020 14:04:40 GMT",
+ time.Date(2020, 05, 26, 14, 04, 40, 0, time.UTC),
+ true,
+ },
+ {
+ "Tue, 26 May 2020 14:04:40 UT",
+ time.Date(2020, 05, 26, 14, 04, 40, 0, time.UTC),
+ true,
+ },
+ {
+ "Thu, 21 May 2020 14:04:40 UT",
+ time.Date(2020, 05, 21, 14, 04, 40, 0, time.UTC),
+ true,
+ },
+ {
+ "Tue, 26 May 2020 14:04:40 XT",
+ time.Date(2020, 05, 26, 14, 04, 40, 0, time.UTC),
+ false,
+ },
+ {
+ "Thu, 21 May 2020 14:04:40 XT",
+ time.Date(2020, 05, 21, 14, 04, 40, 0, time.UTC),
+ false,
+ },
+ {
+ "Thu, 21 May 2020 14:04:40 UTC",
+ time.Date(2020, 05, 21, 14, 04, 40, 0, time.UTC),
+ true,
+ },
+ {
+ "Fri, 21 Nov 1997 09:55:06 GMT (GMT)",
+ time.Date(1997, 11, 21, 9, 55, 6, 0, time.UTC),
+ true,
+ },
+ }
+ for _, test := range tests {
+ hdr := Header{
+ "Date": []string{test.dateStr},
+ }
+ date, err := hdr.Date()
+ if err != nil && test.valid {
+ t.Errorf("Header(Date: %s).Date(): %v", test.dateStr, err)
+ } else if err == nil && test.exp.IsZero() {
+ // OK. Used when exact result depends on the
+ // system's local zoneinfo.
+ } else if err == nil && !date.Equal(test.exp) && test.valid {
+ t.Errorf("Header(Date: %s).Date() = %+v, want %+v", test.dateStr, date, test.exp)
+ } else if err == nil && !test.valid { // an invalid expression was tested
+ t.Errorf("Header(Date: %s).Date() did not return an error but %v", test.dateStr, date)
+ }
+
+ date, err = ParseDate(test.dateStr)
+ if err != nil && test.valid {
+ t.Errorf("ParseDate(%s): %v", test.dateStr, err)
+ } else if err == nil && test.exp.IsZero() {
+ // OK. Used when exact result depends on the
+ // system's local zoneinfo.
+ } else if err == nil && !test.valid { // an invalid expression was tested
+ t.Errorf("ParseDate(%s) did not return an error but %v", test.dateStr, date)
+ } else if err == nil && test.valid && !date.Equal(test.exp) {
+ t.Errorf("ParseDate(%s) = %+v, want %+v", test.dateStr, date, test.exp)
+ }
+ }
+}
+
+func TestAddressParsingError(t *testing.T) {
+ mustErrTestCases := [...]struct {
+ text string
+ wantErrText string
+ }{
+ 0: {"=?iso-8859-2?Q?Bogl=E1rka_Tak=E1cs?= <unknown@gmail.com>", "charset not supported"},
+ 1: {"a@gmail.com b@gmail.com", "expected single address"},
+ 2: {string([]byte{0xed, 0xa0, 0x80}) + " <micro@example.net>", "invalid utf-8 in address"},
+ 3: {"\"" + string([]byte{0xed, 0xa0, 0x80}) + "\" <half-surrogate@example.com>", "invalid utf-8 in quoted-string"},
+ 4: {"\"\\" + string([]byte{0x80}) + "\" <escaped-invalid-unicode@example.net>", "invalid utf-8 in quoted-string"},
+ 5: {"\"\x00\" <null@example.net>", "bad character in quoted-string"},
+ 6: {"\"\\\x00\" <escaped-null@example.net>", "bad character in quoted-string"},
+ 7: {"John Doe", "no angle-addr"},
+ 8: {`<jdoe#machine.example>`, "missing @ in addr-spec"},
+ 9: {`John <middle> Doe <jdoe@machine.example>`, "missing @ in addr-spec"},
+ 10: {"cfws@example.com (", "misformatted parenthetical comment"},
+ 11: {"empty group: ;", "empty group"},
+ 12: {"root group: embed group: null@example.com;", "no angle-addr"},
+ 13: {"group not closed: null@example.com", "expected comma"},
+ 14: {"group: first@example.com, second@example.com;", "group with multiple addresses"},
+ 15: {"john.doe", "missing '@' or angle-addr"},
+ 16: {"john.doe@", "missing '@' or angle-addr"},
+ 17: {"John Doe@foo.bar", "no angle-addr"},
+ 18: {" group: null@example.com; (asd", "misformatted parenthetical comment"},
+ 19: {" group: ; (asd", "misformatted parenthetical comment"},
+ 20: {`(John) Doe <jdoe@machine.example>`, "missing word in phrase:"},
+ }
+
+ for i, tc := range mustErrTestCases {
+ _, err := ParseAddress(tc.text)
+ if err == nil || !strings.Contains(err.Error(), tc.wantErrText) {
+ t.Errorf(`mail.ParseAddress(%q) #%d want %q, got %v`, tc.text, i, tc.wantErrText, err)
+ }
+ }
+
+ t.Run("CustomWordDecoder", func(t *testing.T) {
+ p := &AddressParser{WordDecoder: &mime.WordDecoder{}}
+ for i, tc := range mustErrTestCases {
+ _, err := p.Parse(tc.text)
+ if err == nil || !strings.Contains(err.Error(), tc.wantErrText) {
+ t.Errorf(`p.Parse(%q) #%d want %q, got %v`, tc.text, i, tc.wantErrText, err)
+ }
+ }
+ })
+
+}
+
+func TestAddressParsing(t *testing.T) {
+ tests := []struct {
+ addrsStr string
+ exp []*Address
+ }{
+ // Bare address
+ {
+ `jdoe@machine.example`,
+ []*Address{{
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.1
+ {
+ `John Doe <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.2
+ {
+ `"Joe Q. Public" <john.q.public@example.com>`,
+ []*Address{{
+ Name: "Joe Q. Public",
+ Address: "john.q.public@example.com",
+ }},
+ },
+ // Comment in display name
+ {
+ `John (middle) Doe <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // Display name is quoted string, so comment is not a comment
+ {
+ `"John (middle) Doe" <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John (middle) Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ {
+ `"John <middle> Doe" <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John <middle> Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ {
+ `Mary Smith <mary@x.test>, jdoe@example.org, Who? <one@y.test>`,
+ []*Address{
+ {
+ Name: "Mary Smith",
+ Address: "mary@x.test",
+ },
+ {
+ Address: "jdoe@example.org",
+ },
+ {
+ Name: "Who?",
+ Address: "one@y.test",
+ },
+ },
+ },
+ {
+ `<boss@nil.test>, "Giant; \"Big\" Box" <sysservices@example.net>`,
+ []*Address{
+ {
+ Address: "boss@nil.test",
+ },
+ {
+ Name: `Giant; "Big" Box`,
+ Address: "sysservices@example.net",
+ },
+ },
+ },
+ // RFC 5322, Appendix A.6.1
+ {
+ `Joe Q. Public <john.q.public@example.com>`,
+ []*Address{{
+ Name: "Joe Q. Public",
+ Address: "john.q.public@example.com",
+ }},
+ },
+ // RFC 5322, Appendix A.1.3
+ {
+ `group1: groupaddr1@example.com;`,
+ []*Address{
+ {
+ Name: "",
+ Address: "groupaddr1@example.com",
+ },
+ },
+ },
+ {
+ `empty group: ;`,
+ []*Address(nil),
+ },
+ {
+ `A Group:Ed Jones <c@a.test>,joe@where.test,John <jdoe@one.test>;`,
+ []*Address{
+ {
+ Name: "Ed Jones",
+ Address: "c@a.test",
+ },
+ {
+ Name: "",
+ Address: "joe@where.test",
+ },
+ {
+ Name: "John",
+ Address: "jdoe@one.test",
+ },
+ },
+ },
+ // RFC5322 4.4 obs-addr-list
+ {
+ ` , joe@where.test,,John <jdoe@one.test>,`,
+ []*Address{
+ {
+ Name: "",
+ Address: "joe@where.test",
+ },
+ {
+ Name: "John",
+ Address: "jdoe@one.test",
+ },
+ },
+ },
+ {
+ ` , joe@where.test,,John <jdoe@one.test>,,`,
+ []*Address{
+ {
+ Name: "",
+ Address: "joe@where.test",
+ },
+ {
+ Name: "John",
+ Address: "jdoe@one.test",
+ },
+ },
+ },
+ {
+ `Group1: <addr1@example.com>;, Group 2: addr2@example.com;, John <addr3@example.com>`,
+ []*Address{
+ {
+ Name: "",
+ Address: "addr1@example.com",
+ },
+ {
+ Name: "",
+ Address: "addr2@example.com",
+ },
+ {
+ Name: "John",
+ Address: "addr3@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded ISO-8859-1 address.
+ {
+ `=?iso-8859-1?q?J=F6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded US-ASCII address. Dumb but legal.
+ {
+ `=?us-ascii?q?J=6Frg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jorg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded UTF-8 address.
+ {
+ `=?utf-8?q?J=C3=B6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded UTF-8 address with multiple encoded-words.
+ {
+ `=?utf-8?q?J=C3=B6rg?= =?utf-8?q?Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `JörgDoe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047, Section 8.
+ {
+ `=?ISO-8859-1?Q?Andr=E9?= Pirard <PIRARD@vm1.ulg.ac.be>`,
+ []*Address{
+ {
+ Name: `André Pirard`,
+ Address: "PIRARD@vm1.ulg.ac.be",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded ISO-8859-1 address.
+ {
+ `=?ISO-8859-1?B?SvZyZw==?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded UTF-8 address.
+ {
+ `=?UTF-8?B?SsO2cmc=?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // Custom example with "." in name. For issue 4938
+ {
+ `Asem H. <noreply@example.com>`,
+ []*Address{
+ {
+ Name: `Asem H.`,
+ Address: "noreply@example.com",
+ },
+ },
+ },
+ // RFC 6532 3.2.3, qtext /= UTF8-non-ascii
+ {
+ `"Gø Pher" <gopher@example.com>`,
+ []*Address{
+ {
+ Name: `Gø Pher`,
+ Address: "gopher@example.com",
+ },
+ },
+ },
+ // RFC 6532 3.2, atext /= UTF8-non-ascii
+ {
+ `µ <micro@example.com>`,
+ []*Address{
+ {
+ Name: `µ`,
+ Address: "micro@example.com",
+ },
+ },
+ },
+ // RFC 6532 3.2.2, local address parts allow UTF-8
+ {
+ `Micro <µ@example.com>`,
+ []*Address{
+ {
+ Name: `Micro`,
+ Address: "µ@example.com",
+ },
+ },
+ },
+ // RFC 6532 3.2.4, domains parts allow UTF-8
+ {
+ `Micro <micro@µ.example.com>`,
+ []*Address{
+ {
+ Name: `Micro`,
+ Address: "micro@µ.example.com",
+ },
+ },
+ },
+ // Issue 14866
+ {
+ `"" <emptystring@example.com>`,
+ []*Address{
+ {
+ Name: "",
+ Address: "emptystring@example.com",
+ },
+ },
+ },
+ // CFWS
+ {
+ `<cfws@example.com> (CFWS (cfws)) (another comment)`,
+ []*Address{
+ {
+ Name: "",
+ Address: "cfws@example.com",
+ },
+ },
+ },
+ {
+ `<cfws@example.com> () (another comment), <cfws2@example.com> (another)`,
+ []*Address{
+ {
+ Name: "",
+ Address: "cfws@example.com",
+ },
+ {
+ Name: "",
+ Address: "cfws2@example.com",
+ },
+ },
+ },
+ // Comment as display name
+ {
+ `john@example.com (John Doe)`,
+ []*Address{
+ {
+ Name: "John Doe",
+ Address: "john@example.com",
+ },
+ },
+ },
+ // Comment and display name
+ {
+ `John Doe <john@example.com> (Joey)`,
+ []*Address{
+ {
+ Name: "John Doe",
+ Address: "john@example.com",
+ },
+ },
+ },
+ // Comment as display name, no space
+ {
+ `john@example.com(John Doe)`,
+ []*Address{
+ {
+ Name: "John Doe",
+ Address: "john@example.com",
+ },
+ },
+ },
+ // Comment as display name, Q-encoded
+ {
+ `asjo@example.com (Adam =?utf-8?Q?Sj=C3=B8gren?=)`,
+ []*Address{
+ {
+ Name: "Adam Sjøgren",
+ Address: "asjo@example.com",
+ },
+ },
+ },
+ // Comment as display name, Q-encoded and tab-separated
+ {
+ `asjo@example.com (Adam	=?utf-8?Q?Sj=C3=B8gren?=)`,
+ []*Address{
+ {
+ Name: "Adam Sjøgren",
+ Address: "asjo@example.com",
+ },
+ },
+ },
+ // Nested comment as display name, Q-encoded
+ {
+ `asjo@example.com (Adam =?utf-8?Q?Sj=C3=B8gren?= (Debian))`,
+ []*Address{
+ {
+ Name: "Adam Sjøgren (Debian)",
+ Address: "asjo@example.com",
+ },
+ },
+ },
+ // Comment in group display name
+ {
+ `group (comment:): a@example.com, b@example.com;`,
+ []*Address{
+ {
+ Address: "a@example.com",
+ },
+ {
+ Address: "b@example.com",
+ },
+ },
+ },
+ {
+ `x(:"):"@a.example;("@b.example;`,
+ []*Address{
+ {
+ Address: `@a.example;(@b.example`,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ if len(test.exp) == 1 {
+ addr, err := ParseAddress(test.addrsStr)
+ if err != nil {
+ t.Errorf("Failed parsing (single) %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual([]*Address{addr}, test.exp) {
+ t.Errorf("Parse (single) of %q: got %+v, want %+v", test.addrsStr, addr, test.exp)
+ }
+ }
+
+ addrs, err := ParseAddressList(test.addrsStr)
+ if err != nil {
+ t.Errorf("Failed parsing (list) %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual(addrs, test.exp) {
+ t.Errorf("Parse (list) of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp)
+ }
+ }
+}
+
+func TestAddressParser(t *testing.T) {
+ tests := []struct {
+ addrsStr string
+ exp []*Address
+ }{
+ // Bare address
+ {
+ `jdoe@machine.example`,
+ []*Address{{
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.1
+ {
+ `John Doe <jdoe@machine.example>`,
+ []*Address{{
+ Name: "John Doe",
+ Address: "jdoe@machine.example",
+ }},
+ },
+ // RFC 5322, Appendix A.1.2
+ {
+ `"Joe Q. Public" <john.q.public@example.com>`,
+ []*Address{{
+ Name: "Joe Q. Public",
+ Address: "john.q.public@example.com",
+ }},
+ },
+ {
+ `Mary Smith <mary@x.test>, jdoe@example.org, Who? <one@y.test>`,
+ []*Address{
+ {
+ Name: "Mary Smith",
+ Address: "mary@x.test",
+ },
+ {
+ Address: "jdoe@example.org",
+ },
+ {
+ Name: "Who?",
+ Address: "one@y.test",
+ },
+ },
+ },
+ {
+ `<boss@nil.test>, "Giant; \"Big\" Box" <sysservices@example.net>`,
+ []*Address{
+ {
+ Address: "boss@nil.test",
+ },
+ {
+ Name: `Giant; "Big" Box`,
+ Address: "sysservices@example.net",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded ISO-8859-1 address.
+ {
+ `=?iso-8859-1?q?J=F6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded US-ASCII address. Dumb but legal.
+ {
+ `=?us-ascii?q?J=6Frg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jorg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "Q"-encoded ISO-8859-15 address.
+ {
+ `=?ISO-8859-15?Q?J=F6rg_Doe?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg Doe`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // RFC 2047 "B"-encoded windows-1252 address.
+ {
+ `=?windows-1252?q?Andr=E9?= Pirard <PIRARD@vm1.ulg.ac.be>`,
+ []*Address{
+ {
+ Name: `André Pirard`,
+ Address: "PIRARD@vm1.ulg.ac.be",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded ISO-8859-15 address.
+ {
+ `=?ISO-8859-15?B?SvZyZw==?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // Custom example of RFC 2047 "B"-encoded UTF-8 address.
+ {
+ `=?UTF-8?B?SsO2cmc=?= <joerg@example.com>`,
+ []*Address{
+ {
+ Name: `Jörg`,
+ Address: "joerg@example.com",
+ },
+ },
+ },
+ // Custom example with "." in name. For issue 4938
+ {
+ `Asem H. <noreply@example.com>`,
+ []*Address{
+ {
+ Name: `Asem H.`,
+ Address: "noreply@example.com",
+ },
+ },
+ },
+ }
+
+ ap := AddressParser{WordDecoder: &mime.WordDecoder{
+ CharsetReader: func(charset string, input io.Reader) (io.Reader, error) {
+ in, err := io.ReadAll(input)
+ if err != nil {
+ return nil, err
+ }
+
+ switch charset {
+ case "iso-8859-15":
+ in = bytes.ReplaceAll(in, []byte("\xf6"), []byte("ö"))
+ case "windows-1252":
+ in = bytes.ReplaceAll(in, []byte("\xe9"), []byte("é"))
+ }
+
+ return bytes.NewReader(in), nil
+ },
+ }}
+
+ for _, test := range tests {
+ if len(test.exp) == 1 {
+ addr, err := ap.Parse(test.addrsStr)
+ if err != nil {
+ t.Errorf("Failed parsing (single) %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual([]*Address{addr}, test.exp) {
+ t.Errorf("Parse (single) of %q: got %+v, want %+v", test.addrsStr, addr, test.exp)
+ }
+ }
+
+ addrs, err := ap.ParseList(test.addrsStr)
+ if err != nil {
+ t.Errorf("Failed parsing (list) %q: %v", test.addrsStr, err)
+ continue
+ }
+ if !reflect.DeepEqual(addrs, test.exp) {
+ t.Errorf("Parse (list) of %q: got %+v, want %+v", test.addrsStr, addrs, test.exp)
+ }
+ }
+}
+
+func TestAddressString(t *testing.T) {
+ tests := []struct {
+ addr *Address
+ exp string
+ }{
+ {
+ &Address{Address: "bob@example.com"},
+ "<bob@example.com>",
+ },
+ { // quoted local parts: RFC 5322, 3.4.1. and 3.2.4.
+ &Address{Address: `my@idiot@address@example.com`},
+ `<"my@idiot@address"@example.com>`,
+ },
+ { // quoted local parts
+ &Address{Address: ` @example.com`},
+ `<" "@example.com>`,
+ },
+ {
+ &Address{Name: "Bob", Address: "bob@example.com"},
+ `"Bob" <bob@example.com>`,
+ },
+ {
+ // note the ö (o with an umlaut)
+ &Address{Name: "Böb", Address: "bob@example.com"},
+ `=?utf-8?q?B=C3=B6b?= <bob@example.com>`,
+ },
+ {
+ &Address{Name: "Bob Jane", Address: "bob@example.com"},
+ `"Bob Jane" <bob@example.com>`,
+ },
+ {
+ &Address{Name: "Böb Jacöb", Address: "bob@example.com"},
+ `=?utf-8?q?B=C3=B6b_Jac=C3=B6b?= <bob@example.com>`,
+ },
+ { // https://golang.org/issue/12098
+ &Address{Name: "Rob", Address: ""},
+ `"Rob" <@>`,
+ },
+ { // https://golang.org/issue/12098
+ &Address{Name: "Rob", Address: "@"},
+ `"Rob" <@>`,
+ },
+ {
+ &Address{Name: "Böb, Jacöb", Address: "bob@example.com"},
+ `=?utf-8?b?QsO2YiwgSmFjw7Zi?= <bob@example.com>`,
+ },
+ {
+ &Address{Name: "=??Q?x?=", Address: "hello@world.com"},
+ `"=??Q?x?=" <hello@world.com>`,
+ },
+ {
+ &Address{Name: "=?hello", Address: "hello@world.com"},
+ `"=?hello" <hello@world.com>`,
+ },
+ {
+ &Address{Name: "world?=", Address: "hello@world.com"},
+ `"world?=" <hello@world.com>`,
+ },
+ {
+ // should q-encode even for invalid utf-8.
+ &Address{Name: string([]byte{0xed, 0xa0, 0x80}), Address: "invalid-utf8@example.net"},
+ "=?utf-8?q?=ED=A0=80?= <invalid-utf8@example.net>",
+ },
+ }
+ for _, test := range tests {
+ s := test.addr.String()
+ if s != test.exp {
+ t.Errorf("Address%+v.String() = %v, want %v", *test.addr, s, test.exp)
+ continue
+ }
+
+ // Check round-trip.
+ if test.addr.Address != "" && test.addr.Address != "@" {
+ a, err := ParseAddress(test.exp)
+ if err != nil {
+ t.Errorf("ParseAddress(%#q): %v", test.exp, err)
+ continue
+ }
+ if a.Name != test.addr.Name || a.Address != test.addr.Address {
+ t.Errorf("ParseAddress(%#q) = %#v, want %#v", test.exp, a, test.addr)
+ }
+ }
+ }
+}
+
+// Check that all valid addresses can be parsed, formatted, and parsed again.
+func TestAddressParsingAndFormatting(t *testing.T) {
+
+ // Should pass
+ tests := []string{
+ `<Bob@example.com>`,
+ `<bob.bob@example.com>`,
+ `<".bob"@example.com>`,
+ `<" "@example.com>`,
+ `<some.mail-with-dash@example.com>`,
+ `<"dot.and space"@example.com>`,
+ `<"very.unusual.@.unusual.com"@example.com>`,
+ `<admin@mailserver1>`,
+ `<postmaster@localhost>`,
+ "<#!$%&'*+-/=?^_`{}|~@example.org>",
+ `<"very.(),:;<>[]\".VERY.\"very@\\ \"very\".unusual"@strange.example.com>`, // escaped quotes
+ `<"()<>[]:,;@\\\"!#$%&'*+-/=?^_{}| ~.a"@example.org>`, // escaped backslashes
+ `<"Abc\\@def"@example.com>`,
+ `<"Joe\\Blow"@example.com>`,
+ `<test1/test2=test3@example.com>`,
+ `<def!xyz%abc@example.com>`,
+ `<_somename@example.com>`,
+ `<joe@uk>`,
+ `<~@example.com>`,
+ `<"..."@test.com>`,
+ `<"john..doe"@example.com>`,
+ `<"john.doe."@example.com>`,
+ `<".john.doe"@example.com>`,
+ `<"."@example.com>`,
+ `<".."@example.com>`,
+ `<"0:"@0>`,
+ }
+
+ for _, test := range tests {
+ addr, err := ParseAddress(test)
+ if err != nil {
+ t.Errorf("Couldn't parse address %s: %s", test, err.Error())
+ continue
+ }
+ str := addr.String()
+ addr, err = ParseAddress(str)
+ if err != nil {
+ t.Errorf("ParseAddr(%q) error: %v", test, err)
+ continue
+ }
+
+ if addr.String() != test {
+ t.Errorf("String() round-trip = %q; want %q", addr, test)
+ continue
+ }
+
+ }
+
+ // Should fail
+ badTests := []string{
+ `<Abc.example.com>`,
+ `<A@b@c@example.com>`,
+ `<a"b(c)d,e:f;g<h>i[j\k]l@example.com>`,
+ `<just"not"right@example.com>`,
+ `<this is"not\allowed@example.com>`,
+ `<this\ still\"not\\allowed@example.com>`,
+ `<john..doe@example.com>`,
+ `<john.doe@example..com>`,
+ `<john.doe@example..com>`,
+ `<john.doe.@example.com>`,
+ `<john.doe.@.example.com>`,
+ `<.john.doe@example.com>`,
+ `<@example.com>`,
+ `<.@example.com>`,
+ `<test@.>`,
+ `< @example.com>`,
+ `<""test""blah""@example.com>`,
+ `<""@0>`,
+ }
+
+ for _, test := range badTests {
+ _, err := ParseAddress(test)
+ if err == nil {
+ t.Errorf("Should have failed to parse address: %s", test)
+ continue
+ }
+
+ }
+
+}
+
+func TestAddressFormattingAndParsing(t *testing.T) {
+ tests := []*Address{
+ {Name: "@lïce", Address: "alice@example.com"},
+ {Name: "Böb O'Connor", Address: "bob@example.com"},
+ {Name: "???", Address: "bob@example.com"},
+ {Name: "Böb ???", Address: "bob@example.com"},
+ {Name: "Böb (Jacöb)", Address: "bob@example.com"},
+ {Name: "à#$%&'(),.:;<>@[]^`{|}~'", Address: "bob@example.com"},
+ // https://golang.org/issue/11292
+ {Name: "\"\\\x1f,\"", Address: "0@0"},
+ // https://golang.org/issue/12782
+ {Name: "naé, mée", Address: "test.mail@gmail.com"},
+ }
+
+ for i, test := range tests {
+ parsed, err := ParseAddress(test.String())
+ if err != nil {
+ t.Errorf("test #%d: ParseAddr(%q) error: %v", i, test.String(), err)
+ continue
+ }
+ if parsed.Name != test.Name {
+ t.Errorf("test #%d: Parsed name = %q; want %q", i, parsed.Name, test.Name)
+ }
+ if parsed.Address != test.Address {
+ t.Errorf("test #%d: Parsed address = %q; want %q", i, parsed.Address, test.Address)
+ }
+ }
+}
+
+func TestEmptyAddress(t *testing.T) {
+ parsed, err := ParseAddress("")
+ if parsed != nil || err == nil {
+ t.Errorf(`ParseAddress("") = %v, %v, want nil, error`, parsed, err)
+ }
+ list, err := ParseAddressList("")
+ if len(list) > 0 || err == nil {
+ t.Errorf(`ParseAddressList("") = %v, %v, want nil, error`, list, err)
+ }
+ list, err = ParseAddressList(",")
+ if len(list) > 0 || err == nil {
+ t.Errorf(`ParseAddressList(",") = %v, %v, want nil, error`, list, err)
+ }
+ list, err = ParseAddressList("a@b c@d")
+ if len(list) > 0 || err == nil {
+ t.Errorf(`ParseAddressList("a@b c@d") = %v, %v, want nil, error`, list, err)
+ }
+}
diff --git a/src/net/main_cloexec_test.go b/src/net/main_cloexec_test.go
new file mode 100644
index 0000000..6ea99ad
--- /dev/null
+++ b/src/net/main_cloexec_test.go
@@ -0,0 +1,27 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package net
+
+import "internal/poll"
+
+func init() {
+ extraTestHookInstallers = append(extraTestHookInstallers, installAccept4TestHook)
+ extraTestHookUninstallers = append(extraTestHookUninstallers, uninstallAccept4TestHook)
+}
+
+var (
+ // Placeholders for saving original socket system calls.
+ origAccept4 = poll.Accept4Func
+)
+
+func installAccept4TestHook() {
+ poll.Accept4Func = sw.Accept4
+}
+
+func uninstallAccept4TestHook() {
+ poll.Accept4Func = origAccept4
+}
diff --git a/src/net/main_conf_test.go b/src/net/main_conf_test.go
new file mode 100644
index 0000000..28a1cb8
--- /dev/null
+++ b/src/net/main_conf_test.go
@@ -0,0 +1,59 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1
+
+package net
+
+import "testing"
+
+// forceGoDNS forces the resolver configuration to use the pure Go resolver
+// and returns a fixup function to restore the old settings.
+func forceGoDNS() func() {
+ c := systemConf()
+ oldGo := c.netGo
+ oldCgo := c.netCgo
+ fixup := func() {
+ c.netGo = oldGo
+ c.netCgo = oldCgo
+ }
+ c.netGo = true
+ c.netCgo = false
+ return fixup
+}
+
+// forceCgoDNS forces the resolver configuration to use the cgo resolver
+// and returns a fixup function to restore the old settings.
+// (On non-Unix systems forceCgoDNS returns nil.)
+func forceCgoDNS() func() {
+ c := systemConf()
+ oldGo := c.netGo
+ oldCgo := c.netCgo
+ fixup := func() {
+ c.netGo = oldGo
+ c.netCgo = oldCgo
+ }
+ c.netGo = false
+ c.netCgo = true
+ return fixup
+}
+
+func TestForceCgoDNS(t *testing.T) {
+ if !cgoAvailable {
+ t.Skip("cgo resolver not available")
+ }
+ defer forceCgoDNS()()
+ order, _ := systemConf().hostLookupOrder(nil, "go.dev")
+ if order != hostLookupCgo {
+ t.Fatalf("hostLookupOrder returned: %v, want cgo", order)
+ }
+}
+
+func TestForceGoDNS(t *testing.T) {
+ defer forceGoDNS()()
+ order, _ := systemConf().hostLookupOrder(nil, "go.dev")
+ if order == hostLookupCgo {
+ t.Fatalf("hostLookupOrder returned: %v, want go resolver order", order)
+ }
+}
diff --git a/src/net/main_noconf_test.go b/src/net/main_noconf_test.go
new file mode 100644
index 0000000..077a36e
--- /dev/null
+++ b/src/net/main_noconf_test.go
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || plan9 || wasip1
+
+package net
+
+import "runtime"
+
+// See main_conf_test.go for what these (don't) do.
+func forceGoDNS() func() {
+ switch runtime.GOOS {
+ case "plan9":
+ return func() {}
+ default:
+ return nil
+ }
+}
+
+// See main_conf_test.go for what these (don't) do.
+func forceCgoDNS() func() { return nil }
diff --git a/src/net/main_plan9_test.go b/src/net/main_plan9_test.go
new file mode 100644
index 0000000..2bc5be8
--- /dev/null
+++ b/src/net/main_plan9_test.go
@@ -0,0 +1,16 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+func installTestHooks() {}
+
+func uninstallTestHooks() {}
+
+// forceCloseSockets must be called only from TestMain.
+func forceCloseSockets() {}
+
+func enableSocketConnect() {}
+
+func disableSocketConnect(network string) {}
diff --git a/src/net/main_posix_test.go b/src/net/main_posix_test.go
new file mode 100644
index 0000000..a7942ee
--- /dev/null
+++ b/src/net/main_posix_test.go
@@ -0,0 +1,50 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1
+
+package net
+
+import (
+ "net/internal/socktest"
+ "strings"
+ "syscall"
+)
+
+func enableSocketConnect() {
+ sw.Set(socktest.FilterConnect, nil)
+}
+
+func disableSocketConnect(network string) {
+ net, _, _ := strings.Cut(network, ":")
+ sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
+ switch net {
+ case "tcp4":
+ if so.Cookie.Family() == syscall.AF_INET && so.Cookie.Type() == syscall.SOCK_STREAM {
+ return nil, syscall.EHOSTUNREACH
+ }
+ case "udp4":
+ if so.Cookie.Family() == syscall.AF_INET && so.Cookie.Type() == syscall.SOCK_DGRAM {
+ return nil, syscall.EHOSTUNREACH
+ }
+ case "ip4":
+ if so.Cookie.Family() == syscall.AF_INET && so.Cookie.Type() == syscall.SOCK_RAW {
+ return nil, syscall.EHOSTUNREACH
+ }
+ case "tcp6":
+ if so.Cookie.Family() == syscall.AF_INET6 && so.Cookie.Type() == syscall.SOCK_STREAM {
+ return nil, syscall.EHOSTUNREACH
+ }
+ case "udp6":
+ if so.Cookie.Family() == syscall.AF_INET6 && so.Cookie.Type() == syscall.SOCK_DGRAM {
+ return nil, syscall.EHOSTUNREACH
+ }
+ case "ip6":
+ if so.Cookie.Family() == syscall.AF_INET6 && so.Cookie.Type() == syscall.SOCK_RAW {
+ return nil, syscall.EHOSTUNREACH
+ }
+ }
+ return nil, nil
+ })
+}
diff --git a/src/net/main_test.go b/src/net/main_test.go
new file mode 100644
index 0000000..9fd5c88
--- /dev/null
+++ b/src/net/main_test.go
@@ -0,0 +1,209 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "flag"
+ "fmt"
+ "net/internal/socktest"
+ "os"
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "testing"
+)
+
+var (
+ sw socktest.Switch
+
+ // uninstallTestHooks runs just before a run of benchmarks.
+ testHookUninstaller sync.Once
+)
+
+var (
+ testTCPBig = flag.Bool("tcpbig", false, "whether to test massive size of data per read or write call on TCP connection")
+
+ testDNSFlood = flag.Bool("dnsflood", false, "whether to test DNS query flooding")
+
+ // If external IPv4 connectivity exists, we can try dialing
+ // non-node/interface local scope IPv4 addresses.
+ // On Windows, Lookup APIs may not return IPv4-related
+ // resource records when a node has no external IPv4
+ // connectivity.
+ testIPv4 = flag.Bool("ipv4", true, "assume external IPv4 connectivity exists")
+
+ // If external IPv6 connectivity exists, we can try dialing
+ // non-node/interface local scope IPv6 addresses.
+ // On Windows, Lookup APIs may not return IPv6-related
+ // resource records when a node has no external IPv6
+ // connectivity.
+ testIPv6 = flag.Bool("ipv6", false, "assume external IPv6 connectivity exists")
+)
+
+func TestMain(m *testing.M) {
+ setupTestData()
+ installTestHooks()
+
+ st := m.Run()
+
+ testHookUninstaller.Do(uninstallTestHooks)
+ if testing.Verbose() {
+ printRunningGoroutines()
+ printInflightSockets()
+ printSocketStats()
+ }
+ forceCloseSockets()
+ os.Exit(st)
+}
+
+type ipv6LinkLocalUnicastTest struct {
+ network, address string
+ nameLookup bool
+}
+
+var (
+ ipv6LinkLocalUnicastTCPTests []ipv6LinkLocalUnicastTest
+ ipv6LinkLocalUnicastUDPTests []ipv6LinkLocalUnicastTest
+)
+
+func setupTestData() {
+ if supportsIPv4() {
+ resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{
+ {"tcp", "localhost:1", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 1}, nil},
+ {"tcp4", "localhost:2", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 2}, nil},
+ }...)
+ resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{
+ {"udp", "localhost:1", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 1}, nil},
+ {"udp4", "localhost:2", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 2}, nil},
+ }...)
+ resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{
+ {"ip", "localhost", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ {"ip4", "localhost", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil},
+ }...)
+ }
+
+ if supportsIPv6() {
+ resolveTCPAddrTests = append(resolveTCPAddrTests, resolveTCPAddrTest{"tcp6", "localhost:3", &TCPAddr{IP: IPv6loopback, Port: 3}, nil})
+ resolveUDPAddrTests = append(resolveUDPAddrTests, resolveUDPAddrTest{"udp6", "localhost:3", &UDPAddr{IP: IPv6loopback, Port: 3}, nil})
+ resolveIPAddrTests = append(resolveIPAddrTests, resolveIPAddrTest{"ip6", "localhost", &IPAddr{IP: IPv6loopback}, nil})
+
+ // Issue 20911: don't return IPv4 addresses for
+ // Resolve*Addr calls of the IPv6 unspecified address.
+ resolveTCPAddrTests = append(resolveTCPAddrTests, resolveTCPAddrTest{"tcp", "[::]:4", &TCPAddr{IP: IPv6unspecified, Port: 4}, nil})
+ resolveUDPAddrTests = append(resolveUDPAddrTests, resolveUDPAddrTest{"udp", "[::]:4", &UDPAddr{IP: IPv6unspecified, Port: 4}, nil})
+ resolveIPAddrTests = append(resolveIPAddrTests, resolveIPAddrTest{"ip", "::", &IPAddr{IP: IPv6unspecified}, nil})
+ }
+
+ ifi := loopbackInterface()
+ if ifi != nil {
+ index := fmt.Sprintf("%v", ifi.Index)
+ resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{
+ {"tcp6", "[fe80::1%" + ifi.Name + "]:1", &TCPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneCache.name(ifi.Index)}, nil},
+ {"tcp6", "[fe80::1%" + index + "]:2", &TCPAddr{IP: ParseIP("fe80::1"), Port: 2, Zone: index}, nil},
+ }...)
+ resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{
+ {"udp6", "[fe80::1%" + ifi.Name + "]:1", &UDPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneCache.name(ifi.Index)}, nil},
+ {"udp6", "[fe80::1%" + index + "]:2", &UDPAddr{IP: ParseIP("fe80::1"), Port: 2, Zone: index}, nil},
+ }...)
+ resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{
+ {"ip6", "fe80::1%" + ifi.Name, &IPAddr{IP: ParseIP("fe80::1"), Zone: zoneCache.name(ifi.Index)}, nil},
+ {"ip6", "fe80::1%" + index, &IPAddr{IP: ParseIP("fe80::1"), Zone: index}, nil},
+ }...)
+ }
+
+ addr := ipv6LinkLocalUnicastAddr(ifi)
+ if addr != "" {
+ if runtime.GOOS != "dragonfly" {
+ ipv6LinkLocalUnicastTCPTests = append(ipv6LinkLocalUnicastTCPTests, []ipv6LinkLocalUnicastTest{
+ {"tcp", "[" + addr + "%" + ifi.Name + "]:0", false},
+ }...)
+ ipv6LinkLocalUnicastUDPTests = append(ipv6LinkLocalUnicastUDPTests, []ipv6LinkLocalUnicastTest{
+ {"udp", "[" + addr + "%" + ifi.Name + "]:0", false},
+ }...)
+ }
+ ipv6LinkLocalUnicastTCPTests = append(ipv6LinkLocalUnicastTCPTests, []ipv6LinkLocalUnicastTest{
+ {"tcp6", "[" + addr + "%" + ifi.Name + "]:0", false},
+ }...)
+ ipv6LinkLocalUnicastUDPTests = append(ipv6LinkLocalUnicastUDPTests, []ipv6LinkLocalUnicastTest{
+ {"udp6", "[" + addr + "%" + ifi.Name + "]:0", false},
+ }...)
+ switch runtime.GOOS {
+ case "darwin", "ios", "dragonfly", "freebsd", "openbsd", "netbsd":
+ ipv6LinkLocalUnicastTCPTests = append(ipv6LinkLocalUnicastTCPTests, []ipv6LinkLocalUnicastTest{
+ {"tcp", "[localhost%" + ifi.Name + "]:0", true},
+ {"tcp6", "[localhost%" + ifi.Name + "]:0", true},
+ }...)
+ ipv6LinkLocalUnicastUDPTests = append(ipv6LinkLocalUnicastUDPTests, []ipv6LinkLocalUnicastTest{
+ {"udp", "[localhost%" + ifi.Name + "]:0", true},
+ {"udp6", "[localhost%" + ifi.Name + "]:0", true},
+ }...)
+ case "linux":
+ ipv6LinkLocalUnicastTCPTests = append(ipv6LinkLocalUnicastTCPTests, []ipv6LinkLocalUnicastTest{
+ {"tcp", "[ip6-localhost%" + ifi.Name + "]:0", true},
+ {"tcp6", "[ip6-localhost%" + ifi.Name + "]:0", true},
+ }...)
+ ipv6LinkLocalUnicastUDPTests = append(ipv6LinkLocalUnicastUDPTests, []ipv6LinkLocalUnicastTest{
+ {"udp", "[ip6-localhost%" + ifi.Name + "]:0", true},
+ {"udp6", "[ip6-localhost%" + ifi.Name + "]:0", true},
+ }...)
+ }
+ }
+}
+
+func printRunningGoroutines() {
+ gss := runningGoroutines()
+ if len(gss) == 0 {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "Running goroutines:\n")
+ for _, gs := range gss {
+ fmt.Fprintf(os.Stderr, "%v\n", gs)
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+}
+
+// runningGoroutines returns a list of remaining goroutines.
+func runningGoroutines() []string {
+ var gss []string
+ b := make([]byte, 2<<20)
+ b = b[:runtime.Stack(b, true)]
+ for _, s := range strings.Split(string(b), "\n\n") {
+ _, stack, _ := strings.Cut(s, "\n")
+ stack = strings.TrimSpace(stack)
+ if !strings.Contains(stack, "created by net") {
+ continue
+ }
+ gss = append(gss, stack)
+ }
+ sort.Strings(gss)
+ return gss
+}
+
+func printInflightSockets() {
+ sos := sw.Sockets()
+ if len(sos) == 0 {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "Inflight sockets:\n")
+ for s, so := range sos {
+ fmt.Fprintf(os.Stderr, "%v: %v\n", s, so)
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+}
+
+func printSocketStats() {
+ sts := sw.Stats()
+ if len(sts) == 0 {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "Socket statistical information:\n")
+ for _, st := range sts {
+ fmt.Fprintf(os.Stderr, "%v\n", st)
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+}
diff --git a/src/net/main_unix_test.go b/src/net/main_unix_test.go
new file mode 100644
index 0000000..e7a5b4f
--- /dev/null
+++ b/src/net/main_unix_test.go
@@ -0,0 +1,55 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import "internal/poll"
+
+var (
+ // Placeholders for saving original socket system calls.
+ origSocket = socketFunc
+ origClose = poll.CloseFunc
+ origConnect = connectFunc
+ origListen = listenFunc
+ origAccept = poll.AcceptFunc
+ origGetsockoptInt = getsockoptIntFunc
+
+ extraTestHookInstallers []func()
+ extraTestHookUninstallers []func()
+)
+
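+// installTestHooks redirects the socket-related system calls through the
+// socktest switch sw so tests can observe and control them.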
+func installTestHooks() {
+ socketFunc = sw.Socket
+ poll.CloseFunc = sw.Close
+ connectFunc = sw.Connect
+ listenFunc = sw.Listen
+ poll.AcceptFunc = sw.Accept
+ getsockoptIntFunc = sw.GetsockoptInt
+
+ for _, fn := range extraTestHookInstallers {
+ fn()
+ }
+}
+
+func uninstallTestHooks() {
+ socketFunc = origSocket
+ poll.CloseFunc = origClose
+ connectFunc = origConnect
+ listenFunc = origListen
+ poll.AcceptFunc = origAccept
+ getsockoptIntFunc = origGetsockoptInt
+
+ for _, fn := range extraTestHookUninstallers {
+ fn()
+ }
+}
+
+// forceCloseSockets must be called only from TestMain.
+func forceCloseSockets() {
+ for s := range sw.Sockets() {
+ poll.CloseFunc(s)
+ }
+}
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
new file mode 100644
index 0000000..07f21b7
--- /dev/null
+++ b/src/net/main_windows_test.go
@@ -0,0 +1,45 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "internal/poll"
+
+var (
+ // Placeholders for saving original socket system calls.
+ origSocket = socketFunc
+ origWSASocket = wsaSocketFunc
+ origClosesocket = poll.CloseFunc
+ origConnect = connectFunc
+ origConnectEx = poll.ConnectExFunc
+ origListen = listenFunc
+ origAccept = poll.AcceptFunc
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ wsaSocketFunc = sw.WSASocket
+ poll.CloseFunc = sw.Closesocket
+ connectFunc = sw.Connect
+ poll.ConnectExFunc = sw.ConnectEx
+ listenFunc = sw.Listen
+ poll.AcceptFunc = sw.AcceptEx
+}
+
+func uninstallTestHooks() {
+ socketFunc = origSocket
+ wsaSocketFunc = origWSASocket
+ poll.CloseFunc = origClosesocket
+ connectFunc = origConnect
+ poll.ConnectExFunc = origConnectEx
+ listenFunc = origListen
+ poll.AcceptFunc = origAccept
+}
+
+// forceCloseSockets must be called only from TestMain.
+func forceCloseSockets() {
+ for s := range sw.Sockets() {
+ poll.CloseFunc(s)
+ }
+}
diff --git a/src/net/mockserver_test.go b/src/net/mockserver_test.go
new file mode 100644
index 0000000..f86dd66
--- /dev/null
+++ b/src/net/mockserver_test.go
@@ -0,0 +1,510 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+)
+
+// testUnixAddr uses os.MkdirTemp to get a name that is unique.
+func testUnixAddr(t testing.TB) string {
+ // Pass an empty pattern to get a directory name that is as short as possible.
+ // If we end up with a name longer than the sun_path field in the sockaddr_un
+ // struct, we won't be able to make the syscall to open the socket.
+ d, err := os.MkdirTemp("", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ if err := os.RemoveAll(d); err != nil {
+ t.Error(err)
+ }
+ })
+ return filepath.Join(d, "sock")
+}
+
+func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) Listener {
+ var lc *ListenConfig
+ switch len(lcOpt) {
+ case 0:
+ lc = new(ListenConfig)
+ case 1:
+ lc = lcOpt[0]
+ default:
+ t.Helper()
+ t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
+ }
+
+ listen := func(net, addr string) Listener {
+ ln, err := lc.Listen(context.Background(), net, addr)
+ if err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+ return ln
+ }
+
+ switch network {
+ case "tcp":
+ if supportsIPv4() {
+ if !supportsIPv6() {
+ return listen("tcp4", "127.0.0.1:0")
+ }
+ if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil {
+ return ln
+ }
+ }
+ if supportsIPv6() {
+ return listen("tcp6", "[::1]:0")
+ }
+ case "tcp4":
+ if supportsIPv4() {
+ return listen("tcp4", "127.0.0.1:0")
+ }
+ case "tcp6":
+ if supportsIPv6() {
+ return listen("tcp6", "[::1]:0")
+ }
+ case "unix", "unixpacket":
+ return listen(network, testUnixAddr(t))
+ }
+
+ t.Helper()
+ t.Fatalf("%s is not supported", network)
+ return nil
+}
+
+func newDualStackListener() (lns []*TCPListener, err error) {
+ var args = []struct {
+ network string
+ TCPAddr
+ }{
+ {"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"tcp6", TCPAddr{IP: IPv6loopback}},
+ }
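+ // Try up to 64 times to pick a port number that is free on both loopback
+ // addresses: bind tcp4 to an ephemeral port first, then ask tcp6 for the
+ // same port. On any mismatch, close everything and retry.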
+ for i := 0; i < 64; i++ {
+ var port int
+ var lns []*TCPListener
+ for _, arg := range args {
+ arg.TCPAddr.Port = port
+ ln, err := ListenTCP(arg.network, &arg.TCPAddr)
+ if err != nil {
+ continue
+ }
+ port = ln.Addr().(*TCPAddr).Port
+ lns = append(lns, ln)
+ }
+ if len(lns) != len(args) {
+ for _, ln := range lns {
+ ln.Close()
+ }
+ continue
+ }
+ return lns, nil
+ }
+ return nil, errors.New("no dualstack port available")
+}
+
+type localServer struct {
+ lnmu sync.RWMutex
+ Listener
+ done chan bool // signal that indicates server stopped
+ cl []Conn // accepted connection list
+}
+
+func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
+ go func() {
+ handler(ls, ls.Listener)
+ close(ls.done)
+ }()
+ return nil
+}
+
+func (ls *localServer) teardown() error {
+ ls.lnmu.Lock()
+ defer ls.lnmu.Unlock()
+ if ls.Listener != nil {
+ network := ls.Listener.Addr().Network()
+ address := ls.Listener.Addr().String()
+ ls.Listener.Close()
+ for _, c := range ls.cl {
+ if err := c.Close(); err != nil {
+ return err
+ }
+ }
+ <-ls.done
+ ls.Listener = nil
+ switch network {
+ case "unix", "unixpacket":
+ os.Remove(address)
+ }
+ }
+ return nil
+}
+
+func newLocalServer(t testing.TB, network string) *localServer {
+ t.Helper()
+ ln := newLocalListener(t, network)
+ return &localServer{Listener: ln, done: make(chan bool)}
+}
+
+type streamListener struct {
+ network, address string
+ Listener
+ done chan bool // signal that indicates server stopped
+}
+
+func (sl *streamListener) newLocalServer() *localServer {
+ return &localServer{Listener: sl.Listener, done: make(chan bool)}
+}
+
+type dualStackServer struct {
+ lnmu sync.RWMutex
+ lns []streamListener
+ port string
+
+ cmu sync.RWMutex
+ cs []Conn // established connections at the passive open side
+}
+
+func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
+ for i := range dss.lns {
+ go func(i int) {
+ handler(dss, dss.lns[i].Listener)
+ close(dss.lns[i].done)
+ }(i)
+ }
+ return nil
+}
+
+func (dss *dualStackServer) teardownNetwork(network string) error {
+ dss.lnmu.Lock()
+ for i := range dss.lns {
+ if network == dss.lns[i].network && dss.lns[i].Listener != nil {
+ dss.lns[i].Listener.Close()
+ <-dss.lns[i].done
+ dss.lns[i].Listener = nil
+ }
+ }
+ dss.lnmu.Unlock()
+ return nil
+}
+
+func (dss *dualStackServer) teardown() error {
+ dss.lnmu.Lock()
+ for i := range dss.lns {
+ if dss.lns[i].Listener != nil {
+ dss.lns[i].Listener.Close()
+ <-dss.lns[i].done
+ }
+ }
+ dss.lns = dss.lns[:0]
+ dss.lnmu.Unlock()
+ dss.cmu.Lock()
+ for _, c := range dss.cs {
+ c.Close()
+ }
+ dss.cs = dss.cs[:0]
+ dss.cmu.Unlock()
+ return nil
+}
+
+func newDualStackServer() (*dualStackServer, error) {
+ lns, err := newDualStackListener()
+ if err != nil {
+ return nil, err
+ }
+ _, port, err := SplitHostPort(lns[0].Addr().String())
+ if err != nil {
+ lns[0].Close()
+ lns[1].Close()
+ return nil, err
+ }
+ return &dualStackServer{
+ lns: []streamListener{
+ {network: "tcp4", address: lns[0].Addr().String(), Listener: lns[0], done: make(chan bool)},
+ {network: "tcp6", address: lns[1].Addr().String(), Listener: lns[1], done: make(chan bool)},
+ },
+ port: port,
+ }, nil
+}
+
+func (ls *localServer) transponder(ln Listener, ch chan<- error) {
+ defer close(ch)
+
+ switch ln := ln.(type) {
+ case *TCPListener:
+ ln.SetDeadline(time.Now().Add(someTimeout))
+ case *UnixListener:
+ ln.SetDeadline(time.Now().Add(someTimeout))
+ }
+ c, err := ln.Accept()
+ if err != nil {
+ if perr := parseAcceptError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ ls.cl = append(ls.cl, c)
+
+ network := ln.Addr().Network()
+ if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
+ ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
+ return
+ }
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ b := make([]byte, 256)
+ n, err := c.Read(b)
+ if err != nil {
+ if perr := parseReadError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if _, err := c.Write(b[:n]); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+}
+
+func transceiver(c Conn, wb []byte, ch chan<- error) {
+ defer close(ch)
+
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ n, err := c.Write(wb)
+ if err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if n != len(wb) {
+ ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
+ }
+ rb := make([]byte, len(wb))
+ n, err = c.Read(rb)
+ if err != nil {
+ if perr := parseReadError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if n != len(wb) {
+ ch <- fmt.Errorf("read %d; want %d", n, len(wb))
+ }
+}
+
+func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig) PacketConn {
+ var lc *ListenConfig
+ switch len(lcOpt) {
+ case 0:
+ lc = new(ListenConfig)
+ case 1:
+ lc = lcOpt[0]
+ default:
+ t.Helper()
+ t.Fatal("too many ListenConfigs passed to newLocalPacketListener: want 0 or 1")
+ }
+
+ listenPacket := func(net, addr string) PacketConn {
+ c, err := lc.ListenPacket(context.Background(), net, addr)
+ if err != nil {
+ t.Helper()
+ t.Fatal(err)
+ }
+ return c
+ }
+
+ switch network {
+ case "udp":
+ if supportsIPv4() {
+ return listenPacket("udp4", "127.0.0.1:0")
+ }
+ if supportsIPv6() {
+ return listenPacket("udp6", "[::1]:0")
+ }
+ case "udp4":
+ if supportsIPv4() {
+ return listenPacket("udp4", "127.0.0.1:0")
+ }
+ case "udp6":
+ if supportsIPv6() {
+ return listenPacket("udp6", "[::1]:0")
+ }
+ case "unixgram":
+ return listenPacket(network, testUnixAddr(t))
+ }
+
+ t.Helper()
+ t.Fatalf("%s is not supported", network)
+ return nil
+}
+
+func newDualStackPacketListener() (cs []*UDPConn, err error) {
+ var args = []struct {
+ network string
+ UDPAddr
+ }{
+ {"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"udp6", UDPAddr{IP: IPv6loopback}},
+ }
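+ // Same port-pairing retry as newDualStackListener, but for UDP sockets.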
+ for i := 0; i < 64; i++ {
+ var port int
+ var cs []*UDPConn
+ for _, arg := range args {
+ arg.UDPAddr.Port = port
+ c, err := ListenUDP(arg.network, &arg.UDPAddr)
+ if err != nil {
+ continue
+ }
+ port = c.LocalAddr().(*UDPAddr).Port
+ cs = append(cs, c)
+ }
+ if len(cs) != len(args) {
+ for _, c := range cs {
+ c.Close()
+ }
+ continue
+ }
+ return cs, nil
+ }
+ return nil, errors.New("no dualstack port available")
+}
+
+type localPacketServer struct {
+ pcmu sync.RWMutex
+ PacketConn
+ done chan bool // signal that indicates server stopped
+}
+
+func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
+ go func() {
+ handler(ls, ls.PacketConn)
+ close(ls.done)
+ }()
+ return nil
+}
+
+func (ls *localPacketServer) teardown() error {
+ ls.pcmu.Lock()
+ if ls.PacketConn != nil {
+ network := ls.PacketConn.LocalAddr().Network()
+ address := ls.PacketConn.LocalAddr().String()
+ ls.PacketConn.Close()
+ <-ls.done
+ ls.PacketConn = nil
+ switch network {
+ case "unixgram":
+ os.Remove(address)
+ }
+ }
+ ls.pcmu.Unlock()
+ return nil
+}
+
+func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
+ t.Helper()
+ c := newLocalPacketListener(t, network)
+ return &localPacketServer{PacketConn: c, done: make(chan bool)}
+}
+
+type packetListener struct {
+ PacketConn
+}
+
+func (pl *packetListener) newLocalServer() *localPacketServer {
+ return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
+}
+
+func packetTransponder(c PacketConn, ch chan<- error) {
+ defer close(ch)
+
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ b := make([]byte, 256)
+ n, peer, err := c.ReadFrom(b)
+ if err != nil {
+ if perr := parseReadError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if peer == nil { // for connected-mode sockets
+ switch c.LocalAddr().Network() {
+ case "udp":
+ peer, err = ResolveUDPAddr("udp", string(b[:n]))
+ case "unixgram":
+ peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
+ }
+ if err != nil {
+ ch <- err
+ return
+ }
+ }
+ if _, err := c.WriteTo(b[:n], peer); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+}
+
+func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
+ defer close(ch)
+
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ n, err := c.WriteTo(wb, dst)
+ if err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if n != len(wb) {
+ ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
+ }
+ rb := make([]byte, len(wb))
+ n, _, err = c.ReadFrom(rb)
+ if err != nil {
+ if perr := parseReadError(err); perr != nil {
+ ch <- perr
+ }
+ ch <- err
+ return
+ }
+ if n != len(wb) {
+ ch <- fmt.Errorf("read %d; want %d", n, len(wb))
+ }
+}
diff --git a/src/net/mptcpsock_linux.go b/src/net/mptcpsock_linux.go
new file mode 100644
index 0000000..b2ac3ee
--- /dev/null
+++ b/src/net/mptcpsock_linux.go
@@ -0,0 +1,127 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "internal/poll"
+ "internal/syscall/unix"
+ "sync"
+ "syscall"
+)
+
+var (
+ mptcpOnce sync.Once
+ mptcpAvailable bool
+ hasSOLMPTCP bool
+)
+
+// These constants aren't in the syscall package, which is frozen
+const (
+ _IPPROTO_MPTCP = 0x106
+ _SOL_MPTCP = 0x11c
+ _MPTCP_INFO = 0x1
+)
+
+func supportsMultipathTCP() bool {
+ mptcpOnce.Do(initMPTCPavailable)
+ return mptcpAvailable
+}
+
+// Check that MPTCP is supported by attempting to create an MPTCP socket and by
+// looking at the returned error if any.
+func initMPTCPavailable() {
+ s, err := sysSocket(syscall.AF_INET, syscall.SOCK_STREAM, _IPPROTO_MPTCP)
+ switch {
+ case errors.Is(err, syscall.EPROTONOSUPPORT): // Not supported: >= v5.6
+ case errors.Is(err, syscall.EINVAL): // Not supported: < v5.6
+ case err == nil: // Supported and no error
+ poll.CloseFunc(s)
+ fallthrough
+ default:
+ // another error: MPTCP was not available, but it might be later
+ mptcpAvailable = true
+ }
+
+ major, minor := unix.KernelVersion()
+ // SOL_MPTCP only supported from kernel 5.16
+ hasSOLMPTCP = major > 5 || (major == 5 && minor >= 16)
+}
+
+func (sd *sysDialer) dialMPTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ if supportsMultipathTCP() {
+ if conn, err := sd.doDialTCPProto(ctx, laddr, raddr, _IPPROTO_MPTCP); err == nil {
+ return conn, nil
+ }
+ }
+
+ // Fall back to dialTCP if Multipath TCP isn't supported on this operating
+ // system. But also fall back in case of any error with MPTCP.
+ //
+ // Possible MPTCP specific error: ENOPROTOOPT (sysctl net.mptcp.enabled=0)
+ // But just in case MPTCP is blocked differently (SELinux, etc.), just
+ // retry with "plain" TCP.
+ return sd.dialTCP(ctx, laddr, raddr)
+}
+
+func (sl *sysListener) listenMPTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ if supportsMultipathTCP() {
+ if dial, err := sl.listenTCPProto(ctx, laddr, _IPPROTO_MPTCP); err == nil {
+ return dial, nil
+ }
+ }
+
+ // Fall back to listenTCP if Multipath TCP isn't supported on this operating
+ // system. But also fall back in case of any error with MPTCP.
+ //
+ // Possible MPTCP specific error: ENOPROTOOPT (sysctl net.mptcp.enabled=0)
+ // But just in case MPTCP is blocked differently (SELinux, etc.), just
+ // retry with "plain" TCP.
+ return sl.listenTCP(ctx, laddr)
+}
+
+// hasFallenBack reports whether the MPTCP connection has fallen back to "plain"
+// TCP.
+//
+// A connection can fall back to TCP for different reasons, e.g. the other peer
+// doesn't support it, a middlebox "accidentally" drops the option, etc.
+//
+// If the MPTCP protocol has not been requested when creating the socket, this
+// method will return true: MPTCP is not being used.
+//
+// Kernel >= 5.16 returns EOPNOTSUPP/ENOPROTOOPT in case of fallback.
+// Older kernels always return these errors even when MPTCP is in use,
+// so this check is not usable there.
+func hasFallenBack(fd *netFD) bool {
+ _, err := fd.pfd.GetsockoptInt(_SOL_MPTCP, _MPTCP_INFO)
+
+ // 2 expected errors in case of fallback depending on the address family
+ // - AF_INET: EOPNOTSUPP
+ // - AF_INET6: ENOPROTOOPT
+ return err == syscall.EOPNOTSUPP || err == syscall.ENOPROTOOPT
+}
+
+// isUsingMPTCPProto reports whether the socket protocol is MPTCP.
+//
+// Compared to the hasFallenBack method, only the socket protocol is checked
+// here: the protocol can be MPTCP even when MPTCP is not actually used on the
+// wire, for example after a fallback to TCP.
+func isUsingMPTCPProto(fd *netFD) bool {
+ proto, _ := fd.pfd.GetsockoptInt(syscall.SOL_SOCKET, syscall.SO_PROTOCOL)
+
+ return proto == _IPPROTO_MPTCP
+}
+
+// isUsingMultipathTCP reports whether MPTCP is still being used.
+//
+// See the descriptions of the hasFallenBack (kernel >= 5.16) and
+// isUsingMPTCPProto methods for details about what is checked here.
+func isUsingMultipathTCP(fd *netFD) bool {
+ if hasSOLMPTCP {
+ return !hasFallenBack(fd)
+ }
+
+ return isUsingMPTCPProto(fd)
+}
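+
+// For reference, user code reaches dialMPTCP/listenMPTCP only through the
+// exported knobs on Dialer and ListenConfig; a minimal sketch, mirroring the
+// usage exercised in mptcpsock_linux_test.go (the address is a placeholder):
+//
+//	d := &net.Dialer{}
+//	d.SetMultipathTCP(true)
+//	c, err := d.Dial("tcp", "host:443") // uses dialMPTCP when enabled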
diff --git a/src/net/mptcpsock_linux_test.go b/src/net/mptcpsock_linux_test.go
new file mode 100644
index 0000000..5134aba
--- /dev/null
+++ b/src/net/mptcpsock_linux_test.go
@@ -0,0 +1,192 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "syscall"
+ "testing"
+)
+
+func newLocalListenerMPTCP(t *testing.T, envVar bool) Listener {
+ lc := &ListenConfig{}
+
+ if envVar {
+ if !lc.MultipathTCP() {
+ t.Fatal("MultipathTCP Listen is not on despite GODEBUG=multipathtcp=1")
+ }
+ } else {
+ if lc.MultipathTCP() {
+ t.Error("MultipathTCP should be off by default")
+ }
+
+ lc.SetMultipathTCP(true)
+ if !lc.MultipathTCP() {
+ t.Fatal("MultipathTCP is not on after having been forced to on")
+ }
+ }
+
+ ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+func postAcceptMPTCP(ls *localServer, ch chan<- error) {
+ defer close(ch)
+
+ if len(ls.cl) == 0 {
+ ch <- errors.New("no accepted stream")
+ return
+ }
+
+ c := ls.cl[0]
+
+ tcp, ok := c.(*TCPConn)
+ if !ok {
+ ch <- errors.New("struct is not a TCPConn")
+ return
+ }
+
+ mptcp, err := tcp.MultipathTCP()
+ if err != nil {
+ ch <- err
+ return
+ }
+
+ if !mptcp {
+ ch <- errors.New("incoming connection is not with MPTCP")
+ return
+ }
+
+ // Also check the method for the older kernels if not tested before
+ if hasSOLMPTCP && !isUsingMPTCPProto(tcp.fd) {
+ ch <- errors.New("incoming connection is not an MPTCP proto")
+ return
+ }
+}
+
+func dialerMPTCP(t *testing.T, addr string, envVar bool) {
+ d := &Dialer{}
+
+ if envVar {
+ if !d.MultipathTCP() {
+ t.Fatal("MultipathTCP Dialer is not on despite GODEBUG=multipathtcp=1")
+ }
+ } else {
+ if d.MultipathTCP() {
+ t.Error("MultipathTCP should be off by default")
+ }
+
+ d.SetMultipathTCP(true)
+ if !d.MultipathTCP() {
+ t.Fatal("MultipathTCP is not on after having been forced to on")
+ }
+ }
+
+ c, err := d.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ tcp, ok := c.(*TCPConn)
+ if !ok {
+ t.Fatal("struct is not a TCPConn")
+ }
+
+ // Transfer a bit of data to make sure everything is still OK
+ snt := []byte("MPTCP TEST")
+ if _, err := c.Write(snt); err != nil {
+ t.Fatal(err)
+ }
+ b := make([]byte, len(snt))
+ if _, err := c.Read(b); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(snt, b) {
+ t.Errorf("sent bytes (%s) are different from received ones (%s)", snt, b)
+ }
+
+ mptcp, err := tcp.MultipathTCP()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Logf("outgoing connection from %s with mptcp: %t", addr, mptcp)
+
+ if !mptcp {
+ t.Error("outgoing connection is not with MPTCP")
+ }
+
+ // Also check the method for the older kernels if not tested before
+ if hasSOLMPTCP && !isUsingMPTCPProto(tcp.fd) {
+ t.Error("outgoing connection is not an MPTCP proto")
+ }
+}
+
+func canCreateMPTCPSocket() bool {
+ // We want to know if we can create an MPTCP socket, not just if it is
+ // available (mptcpAvailable()): it could be blocked by the admin
+ fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, _IPPROTO_MPTCP)
+ if err != nil {
+ return false
+ }
+
+ syscall.Close(fd)
+ return true
+}
+
+func testMultiPathTCP(t *testing.T, envVar bool) {
+ if envVar {
+ t.Log("Test with GODEBUG=multipathtcp=1")
+ t.Setenv("GODEBUG", "multipathtcp=1")
+ } else {
+ t.Log("Test with GODEBUG=multipathtcp=0")
+ t.Setenv("GODEBUG", "multipathtcp=0")
+ }
+
+ ln := newLocalListenerMPTCP(t, envVar)
+
+ // similar to tcpsock_test:TestIPv6LinkLocalUnicastTCP
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ defer ls.teardown()
+
+ if g, w := ls.Listener.Addr().Network(), "tcp"; g != w {
+ t.Fatalf("Network type mismatch: got %q, want %q", g, w)
+ }
+
+ genericCh := make(chan error)
+ mptcpCh := make(chan error)
+ handler := func(ls *localServer, ln Listener) {
+ ls.transponder(ln, genericCh)
+ postAcceptMPTCP(ls, mptcpCh)
+ }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ dialerMPTCP(t, ln.Addr().String(), envVar)
+
+ if err := <-genericCh; err != nil {
+ t.Error(err)
+ }
+ if err := <-mptcpCh; err != nil {
+ t.Error(err)
+ }
+}
+
+func TestMultiPathTCP(t *testing.T) {
+ if !canCreateMPTCPSocket() {
+ t.Skip("Cannot create MPTCP sockets")
+ }
+
+ for _, envVar := range []bool{false, true} {
+ testMultiPathTCP(t, envVar)
+ }
+}
diff --git a/src/net/mptcpsock_stub.go b/src/net/mptcpsock_stub.go
new file mode 100644
index 0000000..458c153
--- /dev/null
+++ b/src/net/mptcpsock_stub.go
@@ -0,0 +1,23 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !linux
+
+package net
+
+import (
+ "context"
+)
+
+func (sd *sysDialer) dialMPTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ return sd.dialTCP(ctx, laddr, raddr)
+}
+
+func (sl *sysListener) listenMPTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ return sl.listenTCP(ctx, laddr)
+}
+
+func isUsingMultipathTCP(fd *netFD) bool {
+ return false
+}
diff --git a/src/net/net.go b/src/net/net.go
new file mode 100644
index 0000000..5cfc25f
--- /dev/null
+++ b/src/net/net.go
@@ -0,0 +1,767 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package net provides a portable interface for network I/O, including
+TCP/IP, UDP, domain name resolution, and Unix domain sockets.
+
+Although the package provides access to low-level networking
+primitives, most clients will need only the basic interface provided
+by the Dial, Listen, and Accept functions and the associated
+Conn and Listener interfaces. The crypto/tls package uses
+the same interfaces and similar Dial and Listen functions.
+
+The Dial function connects to a server:
+
+ conn, err := net.Dial("tcp", "golang.org:80")
+ if err != nil {
+ // handle error
+ }
+ fmt.Fprintf(conn, "GET / HTTP/1.0\r\n\r\n")
+ status, err := bufio.NewReader(conn).ReadString('\n')
+ // ...
+
+The Listen function creates servers:
+
+ ln, err := net.Listen("tcp", ":8080")
+ if err != nil {
+ // handle error
+ }
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ // handle error
+ }
+ go handleConnection(conn)
+ }
+
+# Name Resolution
+
+The method for resolving domain names, whether indirectly with functions like Dial
+or directly with functions like LookupHost and LookupAddr, varies by operating system.
+
+On Unix systems, the resolver has two options for resolving names.
+It can use a pure Go resolver that sends DNS requests directly to the servers
+listed in /etc/resolv.conf, or it can use a cgo-based resolver that calls C
+library routines such as getaddrinfo and getnameinfo.
+
+By default the pure Go resolver is used, because a blocked DNS request consumes
+only a goroutine, while a blocked C call consumes an operating system thread.
+When cgo is available, the cgo-based resolver is used instead under a variety of
+conditions: on systems that do not let programs make direct DNS requests (OS X),
+when the LOCALDOMAIN environment variable is present (even if empty),
+when the RES_OPTIONS or HOSTALIASES environment variable is non-empty,
+when the ASR_CONFIG environment variable is non-empty (OpenBSD only),
+when /etc/resolv.conf or /etc/nsswitch.conf specify the use of features that the
+Go resolver does not implement, and when the name being looked up ends in .local
+or is an mDNS name.
+
+The resolver decision can be overridden by setting the netdns value of the
+GODEBUG environment variable (see package runtime) to go or cgo, as in:
+
+ export GODEBUG=netdns=go # force pure Go resolver
+ export GODEBUG=netdns=cgo # force native resolver (cgo, win32)
+
+The decision can also be forced while building the Go source tree
+by setting the netgo or netcgo build tag.
+
+A numeric netdns setting, as in GODEBUG=netdns=1, causes the resolver
+to print debugging information about its decisions.
+To force a particular resolver while also printing debugging information,
+join the two settings by a plus sign, as in GODEBUG=netdns=go+1.
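+
+The choice can also be made per Resolver in code; a minimal sketch, assuming
+the Resolver type with its PreferGo field (declared elsewhere in this package):
+
+ r := &Resolver{PreferGo: true}
+ addrs, err := r.LookupHost(context.Background(), "www.example.com")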
+
+On macOS, if Go code that uses the net package is built with
+-buildmode=c-archive, linking the resulting archive into a C program
+requires passing -lresolv when linking the C code.
+
+On Plan 9, the resolver always accesses /net/cs and /net/dns.
+
+On Windows, in Go 1.18.x and earlier, the resolver always used C
+library functions, such as GetAddrInfo and DnsQuery.
+*/
+package net
+
+import (
+ "context"
+ "errors"
+ "internal/poll"
+ "io"
+ "os"
+ "sync"
+ "syscall"
+ "time"
+)
+
+// Addr represents a network end point address.
+//
+// The two methods Network and String conventionally return strings
+// that can be passed as the arguments to Dial, but the exact form
+// and meaning of the strings is up to the implementation.
+type Addr interface {
+ Network() string // name of the network (for example, "tcp", "udp")
+ String() string // string form of address (for example, "192.0.2.1:25", "[2001:db8::1]:80")
+}
+
+// Conn is a generic stream-oriented network connection.
+//
+// Multiple goroutines may invoke methods on a Conn simultaneously.
+type Conn interface {
+ // Read reads data from the connection.
+ // Read can be made to time out and return an error after a fixed
+ // time limit; see SetDeadline and SetReadDeadline.
+ Read(b []byte) (n int, err error)
+
+ // Write writes data to the connection.
+ // Write can be made to time out and return an error after a fixed
+ // time limit; see SetDeadline and SetWriteDeadline.
+ Write(b []byte) (n int, err error)
+
+ // Close closes the connection.
+ // Any blocked Read or Write operations will be unblocked and return errors.
+ Close() error
+
+ // LocalAddr returns the local network address, if known.
+ LocalAddr() Addr
+
+ // RemoteAddr returns the remote network address, if known.
+ RemoteAddr() Addr
+
+ // SetDeadline sets the read and write deadlines associated
+ // with the connection. It is equivalent to calling both
+ // SetReadDeadline and SetWriteDeadline.
+ //
+ // A deadline is an absolute time after which I/O operations
+ // fail instead of blocking. The deadline applies to all future
+ // and pending I/O, not just the immediately following call to
+ // Read or Write. After a deadline has been exceeded, the
+ // connection can be refreshed by setting a deadline in the future.
+ //
+ // If the deadline is exceeded a call to Read or Write or to other
+ // I/O methods will return an error that wraps os.ErrDeadlineExceeded.
+ // This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
+ // The error's Timeout method will return true, but note that there
+ // are other possible errors for which the Timeout method will
+ // return true even if the deadline has not been exceeded.
+ //
+ // An idle timeout can be implemented by repeatedly extending
+ // the deadline after successful Read or Write calls.
+ //
+ // A zero value for t means I/O operations will not time out.
+ SetDeadline(t time.Time) error
+
+ // SetReadDeadline sets the deadline for future Read calls
+ // and any currently-blocked Read call.
+ // A zero value for t means Read will not time out.
+ SetReadDeadline(t time.Time) error
+
+ // SetWriteDeadline sets the deadline for future Write calls
+ // and any currently-blocked Write call.
+ // Even if write times out, it may return n > 0, indicating that
+ // some of the data was successfully written.
+ // A zero value for t means Write will not time out.
+ SetWriteDeadline(t time.Time) error
+}
+
+type conn struct {
+ fd *netFD
+}
+
+func (c *conn) ok() bool { return c != nil && c.fd != nil }
+
+// Implementation of the Conn interface.
+
+// Read implements the Conn Read method.
+func (c *conn) Read(b []byte) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.Read(b)
+ if err != nil && err != io.EOF {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, err
+}
+
+// Write implements the Conn Write method.
+func (c *conn) Write(b []byte) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.Write(b)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, err
+}
+
+// Close closes the connection.
+func (c *conn) Close() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ err := c.fd.Close()
+ if err != nil {
+ err = &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return err
+}
+
+// LocalAddr returns the local network address.
+// The Addr returned is shared by all invocations of LocalAddr, so
+// do not modify it.
+func (c *conn) LocalAddr() Addr {
+ if !c.ok() {
+ return nil
+ }
+ return c.fd.laddr
+}
+
+// RemoteAddr returns the remote network address.
+// The Addr returned is shared by all invocations of RemoteAddr, so
+// do not modify it.
+func (c *conn) RemoteAddr() Addr {
+ if !c.ok() {
+ return nil
+ }
+ return c.fd.raddr
+}
+
+// SetDeadline implements the Conn SetDeadline method.
+func (c *conn) SetDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.SetDeadline(t); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// SetReadDeadline implements the Conn SetReadDeadline method.
+func (c *conn) SetReadDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.SetReadDeadline(t); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// SetWriteDeadline implements the Conn SetWriteDeadline method.
+func (c *conn) SetWriteDeadline(t time.Time) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.SetWriteDeadline(t); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// SetReadBuffer sets the size of the operating system's
+// receive buffer associated with the connection.
+func (c *conn) SetReadBuffer(bytes int) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setReadBuffer(c.fd, bytes); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// SetWriteBuffer sets the size of the operating system's
+// transmit buffer associated with the connection.
+func (c *conn) SetWriteBuffer(bytes int) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setWriteBuffer(c.fd, bytes); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// File returns a copy of the underlying os.File.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+//
+// The returned os.File's file descriptor is different from the connection's.
+// Attempting to change properties of the original using this duplicate
+// may or may not have the desired effect.
+func (c *conn) File() (f *os.File, err error) {
+ f, err = c.fd.dup()
+ if err != nil {
+ err = &OpError{Op: "file", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return
+}
+
+// PacketConn is a generic packet-oriented network connection.
+//
+// Multiple goroutines may invoke methods on a PacketConn simultaneously.
+type PacketConn interface {
+ // ReadFrom reads a packet from the connection,
+ // copying the payload into p. It returns the number of
+ // bytes copied into p and the return address that
+ // was on the packet.
+ // It returns the number of bytes read (0 <= n <= len(p))
+ // and any error encountered. Callers should always process
+ // the n > 0 bytes returned before considering the error err.
+ // ReadFrom can be made to time out and return an error after a
+ // fixed time limit; see SetDeadline and SetReadDeadline.
+ ReadFrom(p []byte) (n int, addr Addr, err error)
+
+ // WriteTo writes a packet with payload p to addr.
+ // WriteTo can be made to time out and return an Error after a
+ // fixed time limit; see SetDeadline and SetWriteDeadline.
+ // On packet-oriented connections, write timeouts are rare.
+ WriteTo(p []byte, addr Addr) (n int, err error)
+
+ // Close closes the connection.
+ // Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
+ Close() error
+
+ // LocalAddr returns the local network address, if known.
+ LocalAddr() Addr
+
+ // SetDeadline sets the read and write deadlines associated
+ // with the connection. It is equivalent to calling both
+ // SetReadDeadline and SetWriteDeadline.
+ //
+ // A deadline is an absolute time after which I/O operations
+ // fail instead of blocking. The deadline applies to all future
+ // and pending I/O, not just the immediately following call to
+ // Read or Write. After a deadline has been exceeded, the
+ // connection can be refreshed by setting a deadline in the future.
+ //
+ // If the deadline is exceeded a call to Read or Write or to other
+ // I/O methods will return an error that wraps os.ErrDeadlineExceeded.
+ // This can be tested using errors.Is(err, os.ErrDeadlineExceeded).
+ // The error's Timeout method will return true, but note that there
+ // are other possible errors for which the Timeout method will
+ // return true even if the deadline has not been exceeded.
+ //
+ // An idle timeout can be implemented by repeatedly extending
+ // the deadline after successful ReadFrom or WriteTo calls.
+ //
+ // A zero value for t means I/O operations will not time out.
+ SetDeadline(t time.Time) error
+
+ // SetReadDeadline sets the deadline for future ReadFrom calls
+ // and any currently-blocked ReadFrom call.
+ // A zero value for t means ReadFrom will not time out.
+ SetReadDeadline(t time.Time) error
+
+ // SetWriteDeadline sets the deadline for future WriteTo calls
+ // and any currently-blocked WriteTo call.
+ // Even if write times out, it may return n > 0, indicating that
+ // some of the data was successfully written.
+ // A zero value for t means WriteTo will not time out.
+ SetWriteDeadline(t time.Time) error
+}
+
+var listenerBacklogCache struct {
+ sync.Once
+ val int
+}
+
+// listenerBacklog is a caching wrapper around maxListenerBacklog.
+func listenerBacklog() int {
+ listenerBacklogCache.Do(func() { listenerBacklogCache.val = maxListenerBacklog() })
+ return listenerBacklogCache.val
+}
+
+// A Listener is a generic network listener for stream-oriented protocols.
+//
+// Multiple goroutines may invoke methods on a Listener simultaneously.
+type Listener interface {
+ // Accept waits for and returns the next connection to the listener.
+ Accept() (Conn, error)
+
+ // Close closes the listener.
+ // Any blocked Accept operations will be unblocked and return errors.
+ Close() error
+
+ // Addr returns the listener's network address.
+ Addr() Addr
+}
+
+// An Error represents a network error.
+type Error interface {
+ error
+ Timeout() bool // Is the error a timeout?
+
+ // Deprecated: Temporary errors are not well-defined.
+ // Most "temporary" errors are timeouts, and the few exceptions are surprising.
+ // Do not use this method.
+ Temporary() bool
+}
+
+// Various errors contained in OpError.
+var (
+ // For connection setup operations.
+ errNoSuitableAddress = errors.New("no suitable address found")
+
+ // For connection setup and write operations.
+ errMissingAddress = errors.New("missing address")
+
+ // For both read and write operations.
+ errCanceled = canceledError{}
+ ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection")
+)
+
+// canceledError lets us return the same error string we have always
+// returned, while still being Is context.Canceled.
+type canceledError struct{}
+
+func (canceledError) Error() string { return "operation was canceled" }
+
+func (canceledError) Is(err error) bool { return err == context.Canceled }
+
+// mapErr maps from the context errors to the historical internal net
+// error values.
+func mapErr(err error) error {
+ switch err {
+ case context.Canceled:
+ return errCanceled
+ case context.DeadlineExceeded:
+ return errTimeout
+ default:
+ return err
+ }
+}
+
+// OpError is the error type usually returned by functions in the net
+// package. It describes the operation, network type, and address of
+// an error.
+type OpError struct {
+ // Op is the operation which caused the error, such as
+ // "read" or "write".
+ Op string
+
+ // Net is the network type on which this error occurred,
+ // such as "tcp" or "udp6".
+ Net string
+
+ // For operations involving a remote network connection, like
+ // Dial, Read, or Write, Source is the corresponding local
+ // network address.
+ Source Addr
+
+ // Addr is the network address for which this error occurred.
+ // For local operations, like Listen or SetDeadline, Addr is
+ // the address of the local endpoint being manipulated.
+ // For operations involving a remote network connection, like
+ // Dial, Read, or Write, Addr is the remote address of that
+ // connection.
+ Addr Addr
+
+ // Err is the error that occurred during the operation.
+ // The Error method panics if the error is nil.
+ Err error
+}
+
+func (e *OpError) Unwrap() error { return e.Err }
+
+func (e *OpError) Error() string {
+ if e == nil {
+ return "<nil>"
+ }
+ s := e.Op
+ if e.Net != "" {
+ s += " " + e.Net
+ }
+ if e.Source != nil {
+ s += " " + e.Source.String()
+ }
+ if e.Addr != nil {
+ if e.Source != nil {
+ s += "->"
+ } else {
+ s += " "
+ }
+ s += e.Addr.String()
+ }
+ s += ": " + e.Err.Error()
+ return s
+}
+
+var (
+ // aLongTimeAgo is a non-zero time, far in the past, used for
+ // immediate cancellation of dials.
+ aLongTimeAgo = time.Unix(1, 0)
+
+ // noDeadline and noCancel are just zero values for
+ // readability with functions taking too many parameters.
+ noDeadline = time.Time{}
+ noCancel = (chan struct{})(nil)
+)
+
+type timeout interface {
+ Timeout() bool
+}
+
+func (e *OpError) Timeout() bool {
+ if ne, ok := e.Err.(*os.SyscallError); ok {
+ t, ok := ne.Err.(timeout)
+ return ok && t.Timeout()
+ }
+ t, ok := e.Err.(timeout)
+ return ok && t.Timeout()
+}
+
+type temporary interface {
+ Temporary() bool
+}
+
+func (e *OpError) Temporary() bool {
+ // Treat ECONNRESET and ECONNABORTED as temporary errors when
+ // they come from calling accept. See issue 6163.
+ if e.Op == "accept" && isConnError(e.Err) {
+ return true
+ }
+
+ if ne, ok := e.Err.(*os.SyscallError); ok {
+ t, ok := ne.Err.(temporary)
+ return ok && t.Temporary()
+ }
+ t, ok := e.Err.(temporary)
+ return ok && t.Temporary()
+}
+
+// A ParseError is the error type of literal network address parsers.
+type ParseError struct {
+ // Type is the type of string that was expected, such as
+ // "IP address", "CIDR address".
+ Type string
+
+ // Text is the malformed text string.
+ Text string
+}
+
+func (e *ParseError) Error() string { return "invalid " + e.Type + ": " + e.Text }
+
+func (e *ParseError) Timeout() bool { return false }
+func (e *ParseError) Temporary() bool { return false }
+
+type AddrError struct {
+ Err string
+ Addr string
+}
+
+func (e *AddrError) Error() string {
+ if e == nil {
+ return "<nil>"
+ }
+ s := e.Err
+ if e.Addr != "" {
+ s = "address " + e.Addr + ": " + s
+ }
+ return s
+}
+
+func (e *AddrError) Timeout() bool { return false }
+func (e *AddrError) Temporary() bool { return false }
+
+type UnknownNetworkError string
+
+func (e UnknownNetworkError) Error() string { return "unknown network " + string(e) }
+func (e UnknownNetworkError) Timeout() bool { return false }
+func (e UnknownNetworkError) Temporary() bool { return false }
+
+type InvalidAddrError string
+
+func (e InvalidAddrError) Error() string { return string(e) }
+func (e InvalidAddrError) Timeout() bool { return false }
+func (e InvalidAddrError) Temporary() bool { return false }
+
+// errTimeout exists to return the historical "i/o timeout" string
+// for context.DeadlineExceeded. See mapErr.
+// It is also used when Dialer.Deadline is exceeded.
+// errors.Is(errTimeout, context.DeadlineExceeded) returns true.
+//
+// TODO(iant): We could consider changing this to os.ErrDeadlineExceeded
+// in the future, if we make
+//
+// errors.Is(os.ErrDeadlineExceeded, context.DeadlineExceeded)
+//
+// return true.
+var errTimeout error = &timeoutError{}
+
+type timeoutError struct{}
+
+func (e *timeoutError) Error() string { return "i/o timeout" }
+func (e *timeoutError) Timeout() bool { return true }
+func (e *timeoutError) Temporary() bool { return true }
+
+func (e *timeoutError) Is(err error) bool {
+ return err == context.DeadlineExceeded
+}
+
+// DNSConfigError represents an error reading the machine's DNS configuration.
+// (No longer used; kept for compatibility.)
+type DNSConfigError struct {
+ Err error
+}
+
+func (e *DNSConfigError) Unwrap() error { return e.Err }
+func (e *DNSConfigError) Error() string { return "error reading DNS config: " + e.Err.Error() }
+func (e *DNSConfigError) Timeout() bool { return false }
+func (e *DNSConfigError) Temporary() bool { return false }
+
+// Various errors contained in DNSError.
+var (
+ errNoSuchHost = errors.New("no such host")
+)
+
+// DNSError represents a DNS lookup error.
+type DNSError struct {
+ Err string // description of the error
+ Name string // name looked for
+ Server string // server used
+ IsTimeout bool // if true, timed out; not all timeouts set this
+ IsTemporary bool // if true, error is temporary; not all errors set this
+ IsNotFound bool // if true, host could not be found
+}
+
+func (e *DNSError) Error() string {
+ if e == nil {
+ return "<nil>"
+ }
+ s := "lookup " + e.Name
+ if e.Server != "" {
+ s += " on " + e.Server
+ }
+ s += ": " + e.Err
+ return s
+}
+
+// Timeout reports whether the DNS lookup is known to have timed out.
+// This is not always known; a DNS lookup may fail due to a timeout
+// and return a DNSError for which Timeout returns false.
+func (e *DNSError) Timeout() bool { return e.IsTimeout }
+
+// Temporary reports whether the DNS error is known to be temporary.
+// This is not always known; a DNS lookup may fail due to a temporary
+// error and return a DNSError for which Temporary returns false.
+func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary }
+
+// errClosed exists just so that the docs for ErrClosed don't mention
+// the internal package poll.
+var errClosed = poll.ErrNetClosing
+
+// ErrClosed is the error returned by an I/O call on a network
+// connection that has already been closed, or that is closed by
+// another goroutine before the I/O is completed. This may be wrapped
+// in another error, and should normally be tested using
+// errors.Is(err, net.ErrClosed).
+var ErrClosed error = errClosed
+
+type writerOnly struct {
+ io.Writer
+}
+
+// Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
+// applicable.
+func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
+ // Wrap w in writerOnly to hide its ReadFrom method from io.Copy,
+ // which would otherwise call it and recurse back into this function.
+ return io.Copy(writerOnly{w}, r)
+}
+
+// Limit the number of concurrent cgo-using goroutines, because
+// each will block an entire operating system thread. The usual culprit
+// is resolving many DNS names in separate goroutines while the DNS
+// server is not responding. Then the many lookups each use a different
+// thread, and the system or the program runs out of threads.
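+// Callers are expected to bracket each potentially blocking cgo call with
+// acquireThread and a deferred releaseThread.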
+
+var threadLimit chan struct{}
+
+var threadOnce sync.Once
+
+func acquireThread() {
+ threadOnce.Do(func() {
+ threadLimit = make(chan struct{}, concurrentThreadsLimit())
+ })
+ threadLimit <- struct{}{}
+}
+
+func releaseThread() {
+ <-threadLimit
+}
+
+// buffersWriter is the interface implemented by Conns that support a
+// "writev"-like batch write optimization.
+// writeBuffers should fully consume and write all chunks from the
+// provided Buffers, else it should report a non-nil error.
+type buffersWriter interface {
+ writeBuffers(*Buffers) (int64, error)
+}
+
+// Buffers contains zero or more runs of bytes to write.
+//
+// On certain machines, for certain types of connections, this is
+// optimized into an OS-specific batch write operation (such as
+// "writev").
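+//
+// A typical use hands several slices to one write call, for example
+// (a minimal sketch; hdr, payload, and conn are placeholders):
+//
+// bufs := Buffers{hdr, payload}
+// n, err := bufs.WriteTo(conn)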
+type Buffers [][]byte
+
+var (
+ _ io.WriterTo = (*Buffers)(nil)
+ _ io.Reader = (*Buffers)(nil)
+)
+
+// WriteTo writes contents of the buffers to w.
+//
+// WriteTo implements io.WriterTo for Buffers.
+//
+// WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v),
+// but does not modify v[i][j] for any i, j.
+func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
+ if wv, ok := w.(buffersWriter); ok {
+ return wv.writeBuffers(v)
+ }
+ for _, b := range *v {
+ nb, err := w.Write(b)
+ n += int64(nb)
+ if err != nil {
+ v.consume(n)
+ return n, err
+ }
+ }
+ v.consume(n)
+ return n, nil
+}
+
+// Read from the buffers.
+//
+// Read implements io.Reader for Buffers.
+//
+// Read modifies the slice v as well as v[i] for 0 <= i < len(v),
+// but does not modify v[i][j] for any i, j.
+func (v *Buffers) Read(p []byte) (n int, err error) {
+ for len(p) > 0 && len(*v) > 0 {
+ n0 := copy(p, (*v)[0])
+ v.consume(int64(n0))
+ p = p[n0:]
+ n += n0
+ }
+ if len(*v) == 0 {
+ err = io.EOF
+ }
+ return
+}
+
+func (v *Buffers) consume(n int64) {
+ for len(*v) > 0 {
+ ln0 := int64(len((*v)[0]))
+ if ln0 > n {
+ (*v)[0] = (*v)[0][n:]
+ return
+ }
+ n -= ln0
+ (*v)[0] = nil
+ *v = (*v)[1:]
+ }
+}
diff --git a/src/net/net_fake.go b/src/net/net_fake.go
new file mode 100644
index 0000000..908767a
--- /dev/null
+++ b/src/net/net_fake.go
@@ -0,0 +1,406 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Fake networking for js/wasm and wasip1/wasm. It is intended to allow tests of other packages to pass.
+
+//go:build (js && wasm) || wasip1
+
+package net
+
+import (
+ "context"
+ "io"
+ "os"
+ "sync"
+ "syscall"
+ "time"
+)
+
+var listenersMu sync.Mutex
+var listeners = make(map[fakeNetAddr]*netFD)
+
+var portCounterMu sync.Mutex
+var portCounter = 0
+
+func nextPort() int {
+ portCounterMu.Lock()
+ defer portCounterMu.Unlock()
+ portCounter++
+ return portCounter
+}
+
+type fakeNetAddr struct {
+ network string
+ address string
+}
+
+type fakeNetFD struct {
+ listener fakeNetAddr
+ r *bufferedPipe
+ w *bufferedPipe
+ incoming chan *netFD
+ closedMu sync.Mutex
+ closed bool
+}
+
+// socket returns a network file descriptor that is ready for
+// asynchronous I/O using the network poller.
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
+ fd := &netFD{family: family, sotype: sotype, net: net}
+ if laddr != nil && raddr == nil {
+ return fakelistener(fd, laddr)
+ }
+ fd2 := &netFD{family: family, sotype: sotype, net: net}
+ return fakeconn(fd, fd2, laddr, raddr)
+}
+
+func fakeIPAndPort(ip IP, port int) (IP, int) {
+ if ip == nil {
+ ip = IPv4(127, 0, 0, 1)
+ }
+ if port == 0 {
+ port = nextPort()
+ }
+ return ip, port
+}
+
+func fakeTCPAddr(addr *TCPAddr) *TCPAddr {
+ var ip IP
+ var port int
+ var zone string
+ if addr != nil {
+ ip, port, zone = addr.IP, addr.Port, addr.Zone
+ }
+ ip, port = fakeIPAndPort(ip, port)
+ return &TCPAddr{IP: ip, Port: port, Zone: zone}
+}
+
+func fakeUDPAddr(addr *UDPAddr) *UDPAddr {
+ var ip IP
+ var port int
+ var zone string
+ if addr != nil {
+ ip, port, zone = addr.IP, addr.Port, addr.Zone
+ }
+ ip, port = fakeIPAndPort(ip, port)
+ return &UDPAddr{IP: ip, Port: port, Zone: zone}
+}
+
+func fakeUnixAddr(sotype int, addr *UnixAddr) *UnixAddr {
+ var net, name string
+ if addr != nil {
+ name = addr.Name
+ }
+ switch sotype {
+ case syscall.SOCK_DGRAM:
+ net = "unixgram"
+ case syscall.SOCK_SEQPACKET:
+ net = "unixpacket"
+ default:
+ net = "unix"
+ }
+ return &UnixAddr{Net: net, Name: name}
+}
+
+func fakelistener(fd *netFD, laddr sockaddr) (*netFD, error) {
+ switch l := laddr.(type) {
+ case *TCPAddr:
+ laddr = fakeTCPAddr(l)
+ case *UDPAddr:
+ laddr = fakeUDPAddr(l)
+ case *UnixAddr:
+ if l.Name == "" {
+ return nil, syscall.ENOENT
+ }
+ laddr = fakeUnixAddr(fd.sotype, l)
+ default:
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ listener := fakeNetAddr{
+ network: laddr.Network(),
+ address: laddr.String(),
+ }
+
+ fd.fakeNetFD = &fakeNetFD{
+ listener: listener,
+ incoming: make(chan *netFD, 1024),
+ }
+
+ fd.laddr = laddr
+ listenersMu.Lock()
+ defer listenersMu.Unlock()
+ if _, exists := listeners[listener]; exists {
+ return nil, syscall.EADDRINUSE
+ }
+ listeners[listener] = fd
+ return fd, nil
+}
+
+func fakeconn(fd *netFD, fd2 *netFD, laddr, raddr sockaddr) (*netFD, error) {
+ switch r := raddr.(type) {
+ case *TCPAddr:
+ r = fakeTCPAddr(r)
+ raddr = r
+ laddr = fakeTCPAddr(laddr.(*TCPAddr))
+ case *UDPAddr:
+ r = fakeUDPAddr(r)
+ raddr = r
+ laddr = fakeUDPAddr(laddr.(*UDPAddr))
+ case *UnixAddr:
+ r = fakeUnixAddr(fd.sotype, r)
+ raddr = r
+ laddr = &UnixAddr{Net: r.Net, Name: r.Name}
+ default:
+ return nil, syscall.EAFNOSUPPORT
+ }
+ fd.laddr = laddr
+ fd.raddr = raddr
+
+ fd.fakeNetFD = &fakeNetFD{
+ r: newBufferedPipe(65536),
+ w: newBufferedPipe(65536),
+ }
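+ // The peer endpoint shares the same two pipes, crossed so that writes
+ // on one side become reads on the other.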
+ fd2.fakeNetFD = &fakeNetFD{
+ r: fd.fakeNetFD.w,
+ w: fd.fakeNetFD.r,
+ }
+
+ fd2.laddr = fd.raddr
+ fd2.raddr = fd.laddr
+
+ listener := fakeNetAddr{
+ network: fd.raddr.Network(),
+ address: fd.raddr.String(),
+ }
+ listenersMu.Lock()
+ defer listenersMu.Unlock()
+ l, ok := listeners[listener]
+ if !ok {
+ return nil, syscall.ECONNREFUSED
+ }
+ l.incoming <- fd2
+ return fd, nil
+}
+
+func (fd *fakeNetFD) Read(p []byte) (n int, err error) {
+ return fd.r.Read(p)
+}
+
+func (fd *fakeNetFD) Write(p []byte) (nn int, err error) {
+ return fd.w.Write(p)
+}
+
+func (fd *fakeNetFD) Close() error {
+ fd.closedMu.Lock()
+ if fd.closed {
+ fd.closedMu.Unlock()
+ return nil
+ }
+ fd.closed = true
+ fd.closedMu.Unlock()
+
+ if fd.listener != (fakeNetAddr{}) {
+ listenersMu.Lock()
+ delete(listeners, fd.listener)
+ close(fd.incoming)
+ fd.listener = fakeNetAddr{}
+ listenersMu.Unlock()
+ return nil
+ }
+
+ fd.r.Close()
+ fd.w.Close()
+ return nil
+}
+
+func (fd *fakeNetFD) closeRead() error {
+ fd.r.Close()
+ return nil
+}
+
+func (fd *fakeNetFD) closeWrite() error {
+ fd.w.Close()
+ return nil
+}
+
+func (fd *fakeNetFD) accept() (*netFD, error) {
+ c, ok := <-fd.incoming
+ if !ok {
+ return nil, syscall.EINVAL
+ }
+ return c, nil
+}
+
+func (fd *fakeNetFD) SetDeadline(t time.Time) error {
+ fd.r.SetReadDeadline(t)
+ fd.w.SetWriteDeadline(t)
+ return nil
+}
+
+func (fd *fakeNetFD) SetReadDeadline(t time.Time) error {
+ fd.r.SetReadDeadline(t)
+ return nil
+}
+
+func (fd *fakeNetFD) SetWriteDeadline(t time.Time) error {
+ fd.w.SetWriteDeadline(t)
+ return nil
+}
+
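+// newBufferedPipe returns an empty bufferedPipe whose writers begin to block
+// once more than softLimit bytes are buffered and unread.
+//
+// Illustrative sketch (not part of the fake network API): bytes written on
+// one end become readable on the other, for example:
+//
+//	p := newBufferedPipe(65536)
+//	p.Write([]byte("hi")) // 2 bytes, well under softLimit, returns at once
+//	buf := make([]byte, 2)
+//	p.Read(buf) // buf now holds "hi"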
+func newBufferedPipe(softLimit int) *bufferedPipe {
+ p := &bufferedPipe{softLimit: softLimit}
+ p.rCond.L = &p.mu
+ p.wCond.L = &p.mu
+ return p
+}
+
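+// bufferedPipe is an in-memory byte stream guarded by a mutex, with separate
+// condition variables for blocked readers and writers and optional read and
+// write deadlines.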
+type bufferedPipe struct {
+ softLimit int
+ mu sync.Mutex
+ buf []byte
+ closed bool
+ rCond sync.Cond
+ wCond sync.Cond
+ rDeadline time.Time
+ wDeadline time.Time
+}
+
+func (p *bufferedPipe) Read(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for {
+ if p.closed && len(p.buf) == 0 {
+ return 0, io.EOF
+ }
+ if !p.rDeadline.IsZero() {
+ d := time.Until(p.rDeadline)
+ if d <= 0 {
+ return 0, os.ErrDeadlineExceeded
+ }
+ time.AfterFunc(d, p.rCond.Broadcast)
+ }
+ if len(p.buf) > 0 {
+ break
+ }
+ p.rCond.Wait()
+ }
+
+ n := copy(b, p.buf)
+ p.buf = p.buf[n:]
+ p.wCond.Broadcast()
+ return n, nil
+}
+
+func (p *bufferedPipe) Write(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for {
+ if p.closed {
+ return 0, syscall.ENOTCONN
+ }
+ if !p.wDeadline.IsZero() {
+ d := time.Until(p.wDeadline)
+ if d <= 0 {
+ return 0, os.ErrDeadlineExceeded
+ }
+ time.AfterFunc(d, p.wCond.Broadcast)
+ }
+ if len(p.buf) <= p.softLimit {
+ break
+ }
+ p.wCond.Wait()
+ }
+
+ p.buf = append(p.buf, b...)
+ p.rCond.Broadcast()
+ return len(b), nil
+}
+
+func (p *bufferedPipe) Close() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.closed = true
+ p.rCond.Broadcast()
+ p.wCond.Broadcast()
+}
+
+func (p *bufferedPipe) SetReadDeadline(t time.Time) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.rDeadline = t
+ p.rCond.Broadcast()
+}
+
+func (p *bufferedPipe) SetWriteDeadline(t time.Time) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.wDeadline = t
+ p.wCond.Broadcast()
+}
+
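+// The low-level socket, datagram, and message-oriented entry points below are
+// not implemented by the fake network; they all report ENOSYS.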
+func sysSocket(family, sotype, proto int) (int, error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) {
+ return nil, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
+ return 0, nil, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
+ return 0, 0, 0, nil, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
+ return 0, 0, 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
+ return 0, 0, 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
+func (fd *fakeNetFD) dup() (f *os.File, err error) {
+ return nil, syscall.ENOSYS
+}
diff --git a/src/net/net_fake_js.go b/src/net/net_fake_js.go
new file mode 100644
index 0000000..7ba108b
--- /dev/null
+++ b/src/net/net_fake_js.go
@@ -0,0 +1,36 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Fake networking for js/wasm. It is intended to allow tests of other packages to pass.
+
+//go:build js && wasm
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+// Network file descriptor.
+type netFD struct {
+ *fakeNetFD
+
+ // immutable until Close
+ family int
+ sotype int
+ net string
+ laddr Addr
+ raddr Addr
+
+ // unused
+ pfd poll.FD
+ isConnected bool // handshake completed or use of association with peer
+}
+
+func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type, conf *dnsConfig) (dnsmessage.Parser, string, error) {
+ panic("unreachable")
+}
diff --git a/src/net/net_fake_test.go b/src/net/net_fake_test.go
new file mode 100644
index 0000000..783304d
--- /dev/null
+++ b/src/net/net_fake_test.go
@@ -0,0 +1,203 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js || wasip1
+
+package net
+
+// GOOS=js and GOOS=wasip1 do not have typical socket networking capabilities
+// found on other platforms. To help run test suites of the stdlib packages,
+// an in-memory "fake network" facility is implemented.
+//
+// The tests in this file are intended to validate the behavior of the fake
+// network stack on these platforms.
+
+import "testing"
+
+func TestFakeConn(t *testing.T) {
+ tests := []struct {
+ name string
+ listen func() (Listener, error)
+ dial func(Addr) (Conn, error)
+ addr func(*testing.T, Addr)
+ }{
+ {
+ name: "Listener:tcp",
+ listen: func() (Listener, error) {
+ return Listen("tcp", ":0")
+ },
+ dial: func(addr Addr) (Conn, error) {
+ return Dial(addr.Network(), addr.String())
+ },
+ addr: testFakeTCPAddr,
+ },
+
+ {
+ name: "ListenTCP:tcp",
+ listen: func() (Listener, error) {
+ // Creating a listening TCP connection with a nil address must
+ // select an IP address on localhost with a random port.
+ // This test verifies that the fake network facility does that.
+ return ListenTCP("tcp", nil)
+ },
+ dial: func(addr Addr) (Conn, error) {
+ // Dialing with a nil local address selects a local address
+ // on the local network and connects to the destination
+ // address.
+ return DialTCP("tcp", nil, addr.(*TCPAddr))
+ },
+ addr: testFakeTCPAddr,
+ },
+
+ {
+ name: "ListenUnix:unix",
+ listen: func() (Listener, error) {
+ return ListenUnix("unix", &UnixAddr{Name: "test"})
+ },
+ dial: func(addr Addr) (Conn, error) {
+ return DialUnix("unix", nil, addr.(*UnixAddr))
+ },
+ addr: testFakeUnixAddr("unix", "test"),
+ },
+
+ {
+ name: "ListenUnix:unixpacket",
+ listen: func() (Listener, error) {
+ return ListenUnix("unixpacket", &UnixAddr{Name: "test"})
+ },
+ dial: func(addr Addr) (Conn, error) {
+ return DialUnix("unixpacket", nil, addr.(*UnixAddr))
+ },
+ addr: testFakeUnixAddr("unixpacket", "test"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ l, err := test.listen()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ test.addr(t, l.Addr())
+
+ c, err := test.dial(l.Addr())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ test.addr(t, c.LocalAddr())
+ test.addr(t, c.RemoteAddr())
+ })
+ }
+}
+
+func TestFakePacketConn(t *testing.T) {
+ tests := []struct {
+ name string
+ listen func() (PacketConn, error)
+ dial func(Addr) (Conn, error)
+ addr func(*testing.T, Addr)
+ }{
+ {
+ name: "ListenPacket:udp",
+ listen: func() (PacketConn, error) {
+ return ListenPacket("udp", ":0")
+ },
+ dial: func(addr Addr) (Conn, error) {
+ return Dial(addr.Network(), addr.String())
+ },
+ addr: testFakeUDPAddr,
+ },
+
+ {
+ name: "ListenUDP:udp",
+ listen: func() (PacketConn, error) {
+ // Creating a listening UDP connection with a nil address must
+ // select an IP address on localhost with a random port.
+ // This test verifies that the fake network facility does that.
+ return ListenUDP("udp", nil)
+ },
+ dial: func(addr Addr) (Conn, error) {
+ // Dialing with a nil local address selects a local address
+ // on the local network and connects to the destination
+ // address.
+ return DialUDP("udp", nil, addr.(*UDPAddr))
+ },
+ addr: testFakeUDPAddr,
+ },
+
+ {
+ name: "ListenUnixgram:unixgram",
+ listen: func() (PacketConn, error) {
+ return ListenUnixgram("unixgram", &UnixAddr{Name: "test"})
+ },
+ dial: func(addr Addr) (Conn, error) {
+ return DialUnix("unixgram", nil, addr.(*UnixAddr))
+ },
+ addr: testFakeUnixAddr("unixgram", "test"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ l, err := test.listen()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ test.addr(t, l.LocalAddr())
+
+ c, err := test.dial(l.LocalAddr())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ test.addr(t, c.LocalAddr())
+ test.addr(t, c.RemoteAddr())
+ })
+ }
+}
+
+func testFakeTCPAddr(t *testing.T, addr Addr) {
+ t.Helper()
+ if a, ok := addr.(*TCPAddr); !ok {
+ t.Errorf("Addr is not *TCPAddr: %T", addr)
+ } else {
+ testFakeNetAddr(t, a.IP, a.Port)
+ }
+}
+
+func testFakeUDPAddr(t *testing.T, addr Addr) {
+ t.Helper()
+ if a, ok := addr.(*UDPAddr); !ok {
+ t.Errorf("Addr is not *UDPAddr: %T", addr)
+ } else {
+ testFakeNetAddr(t, a.IP, a.Port)
+ }
+}
+
+func testFakeNetAddr(t *testing.T, ip IP, port int) {
+ t.Helper()
+ if port == 0 {
+ t.Error("network address is missing port")
+ } else if len(ip) == 0 {
+ t.Error("network address is missing IP")
+ } else if !ip.Equal(IPv4(127, 0, 0, 1)) {
+ t.Errorf("network address has wrong IP: %s", ip)
+ }
+}
+
+func testFakeUnixAddr(net, name string) func(*testing.T, Addr) {
+ return func(t *testing.T, addr Addr) {
+ t.Helper()
+ if a, ok := addr.(*UnixAddr); !ok {
+ t.Errorf("Addr is not *UnixAddr: %T", addr)
+ } else if a.Net != net {
+ t.Errorf("unix address has wrong net: want=%q got=%q", net, a.Net)
+ } else if a.Name != name {
+ t.Errorf("unix address has wrong name: want=%q got=%q", name, a.Name)
+ }
+ }
+}
diff --git a/src/net/net_test.go b/src/net/net_test.go
new file mode 100644
index 0000000..a0ac85f
--- /dev/null
+++ b/src/net/net_test.go
@@ -0,0 +1,593 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/internal/socktest"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func TestCloseRead(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ t.Parallel()
+
+ for _, network := range []string{"tcp", "unix", "unixpacket"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ ln := newLocalListener(t, network)
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(ln.Addr().String())
+ }
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(c.LocalAddr().String())
+ }
+ defer c.Close()
+
+ switch c := c.(type) {
+ case *TCPConn:
+ err = c.CloseRead()
+ case *UnixConn:
+ err = c.CloseRead()
+ }
+ if err != nil {
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ var b [1]byte
+ n, err := c.Read(b[:])
+ if n != 0 || err == nil {
+ t.Fatalf("got (%d, %v); want (0, error)", n, err)
+ }
+ })
+ }
+}
+
+func TestCloseWrite(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Parallel()
+ deadline, _ := t.Deadline()
+ if !deadline.IsZero() {
+ // Leave 10% headroom on the deadline to report errors and clean up.
+ deadline = deadline.Add(-time.Until(deadline) / 10)
+ }
+
+ for _, network := range []string{"tcp", "unix", "unixpacket"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ // Workaround for https://go.dev/issue/49352.
+ // On arm64 macOS (current as of macOS 12.4),
+ // reading from a socket at the same time as the client
+ // is closing it occasionally hangs for 60 seconds before
+ // returning ECONNRESET. Sleep for a bit to give the
+ // socket time to close before trying to read from it.
+ if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
+ time.Sleep(10 * time.Millisecond)
+ }
+
+ if !deadline.IsZero() {
+ c.SetDeadline(deadline)
+ }
+ defer c.Close()
+
+ var b [1]byte
+ n, err := c.Read(b[:])
+ if n != 0 || err != io.EOF {
+ t.Errorf("got (%d, %v); want (0, io.EOF)", n, err)
+ return
+ }
+ switch c := c.(type) {
+ case *TCPConn:
+ err = c.CloseWrite()
+ case *UnixConn:
+ err = c.CloseWrite()
+ }
+ if err != nil {
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ return
+ }
+ n, err = c.Write(b[:])
+ if err == nil {
+ t.Errorf("got (%d, %v); want (any, error)", n, err)
+ return
+ }
+ }
+
+ ls := newLocalServer(t, network)
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !deadline.IsZero() {
+ c.SetDeadline(deadline)
+ }
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(c.LocalAddr().String())
+ }
+ defer c.Close()
+
+ switch c := c.(type) {
+ case *TCPConn:
+ err = c.CloseWrite()
+ case *UnixConn:
+ err = c.CloseWrite()
+ }
+ if err != nil {
+ if perr := parseCloseError(err, true); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ var b [1]byte
+ n, err := c.Read(b[:])
+ if n != 0 || err != io.EOF {
+ t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err)
+ }
+ n, err = c.Write(b[:])
+ if err == nil {
+ t.Fatalf("got (%d, %v); want (any, error)", n, err)
+ }
+ })
+ }
+}
+
+func TestConnClose(t *testing.T) {
+ t.Parallel()
+ for _, network := range []string{"tcp", "unix", "unixpacket"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ ln := newLocalListener(t, network)
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(ln.Addr().String())
+ }
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(c.LocalAddr().String())
+ }
+ defer c.Close()
+
+ if err := c.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ var b [1]byte
+ n, err := c.Read(b[:])
+ if n != 0 || err == nil {
+ t.Fatalf("got (%d, %v); want (0, error)", n, err)
+ }
+ })
+ }
+}
+
+func TestListenerClose(t *testing.T) {
+ t.Parallel()
+ for _, network := range []string{"tcp", "unix", "unixpacket"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ ln := newLocalListener(t, network)
+ switch network {
+ case "unix", "unixpacket":
+ defer os.Remove(ln.Addr().String())
+ }
+
+ if err := ln.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ t.Fatal("should fail")
+ }
+
+ // Note: we cannot ensure that a subsequent Dial does not succeed, because
+ // we do not in general have any guarantee that ln.Addr is not immediately
+ // reused. (TCP sockets enter a TIME_WAIT state when closed, but that only
+ // applies to existing connections for the port — it does not prevent the
+ // port itself from being used for entirely new connections in the
+ // meantime.)
+ })
+ }
+}
+
+func TestPacketConnClose(t *testing.T) {
+ t.Parallel()
+ for _, network := range []string{"udp", "unixgram"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ c := newLocalPacketListener(t, network)
+ switch network {
+ case "unixgram":
+ defer os.Remove(c.LocalAddr().String())
+ }
+ defer c.Close()
+
+ if err := c.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ var b [1]byte
+ n, _, err := c.ReadFrom(b[:])
+ if n != 0 || err == nil {
+ t.Fatalf("got (%d, %v); want (0, error)", n, err)
+ }
+ })
+ }
+}
+
+func TestListenCloseListen(t *testing.T) {
+ const maxTries = 10
+ for tries := 0; tries < maxTries; tries++ {
+ ln := newLocalListener(t, "tcp")
+ addr := ln.Addr().String()
+ // TODO: This is racy. The selected address could be reused in between this
+ // Close and the subsequent Listen.
+ if err := ln.Close(); err != nil {
+ if perr := parseCloseError(err, false); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ ln, err := Listen("tcp", addr)
+ if err == nil {
+ // Success. (This test didn't always make it here earlier.)
+ ln.Close()
+ return
+ }
+ t.Errorf("failed on try %d/%d: %v", tries+1, maxTries, err)
+ }
+ t.Fatalf("failed to listen/close/listen on same address after %d tries", maxTries)
+}
+
+// See golang.org/issue/6163, golang.org/issue/6987.
+func TestAcceptIgnoreAbortedConnRequest(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("%s does not have full support of socktest", runtime.GOOS)
+ }
+
+ syserr := make(chan error)
+ go func() {
+ defer close(syserr)
+ for _, err := range abortedConnRequestErrors {
+ syserr <- err
+ }
+ }()
+ sw.Set(socktest.FilterAccept, func(so *socktest.Status) (socktest.AfterFilter, error) {
+ if err, ok := <-syserr; ok {
+ return nil, err
+ }
+ return nil, nil
+ })
+ defer sw.Set(socktest.FilterAccept, nil)
+
+ operr := make(chan error, 1)
+ handler := func(ls *localServer, ln Listener) {
+ defer close(operr)
+ c, err := ln.Accept()
+ if err != nil {
+ if perr := parseAcceptError(err); perr != nil {
+ operr <- perr
+ }
+ operr <- err
+ return
+ }
+ c.Close()
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+
+ for err := range operr {
+ t.Error(err)
+ }
+}
+
+func TestZeroByteRead(t *testing.T) {
+ t.Parallel()
+ for _, network := range []string{"tcp", "unix", "unixpacket"} {
+ network := network
+ t.Run(network, func(t *testing.T) {
+ if !testableNetwork(network) {
+ t.Skipf("network %s is not testable on the current platform", network)
+ }
+ t.Parallel()
+
+ ln := newLocalListener(t, network)
+ connc := make(chan Conn, 1)
+ go func() {
+ defer ln.Close()
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ }
+ connc <- c // might be nil
+ }()
+ c, err := Dial(network, ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ sc := <-connc
+ if sc == nil {
+ return
+ }
+ defer sc.Close()
+
+ if runtime.GOOS == "windows" {
+ // A zero byte read on Windows caused a wait for readability first.
+ // Rather than change that behavior, satisfy it in this test.
+ // See Issue 15735.
+ go io.WriteString(sc, "a")
+ }
+
+ n, err := c.Read(nil)
+ if n != 0 || err != nil {
+ t.Errorf("%s: zero byte client read = %v, %v; want 0, nil", network, n, err)
+ }
+
+ if runtime.GOOS == "windows" {
+ // Same as comment above.
+ go io.WriteString(c, "a")
+ }
+ n, err = sc.Read(nil)
+ if n != 0 || err != nil {
+ t.Errorf("%s: zero byte server read = %v, %v; want 0, nil", network, n, err)
+ }
+ })
+ }
+}
+
+// withTCPConnPair sets up a TCP connection between two peers, then
+// runs peer1 and peer2 concurrently. withTCPConnPair returns when
+// both have completed.
+func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
+ t.Helper()
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+ errc := make(chan error, 2)
+ go func() {
+ c1, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer c1.Close()
+ errc <- peer1(c1.(*TCPConn))
+ }()
+ go func() {
+ c2, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer c2.Close()
+ errc <- peer2(c2.(*TCPConn))
+ }()
+ for i := 0; i < 2; i++ {
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+// Tests that a blocked Read is interrupted by a concurrent SetReadDeadline
+// modifying that Conn's read deadline to the past.
+// See golang.org/cl/30164 which documented this. The net/http package
+// depends on this.
+func TestReadTimeoutUnblocksRead(t *testing.T) {
+ serverDone := make(chan struct{})
+ server := func(cs *TCPConn) error {
+ defer close(serverDone)
+ errc := make(chan error, 1)
+ go func() {
+ defer close(errc)
+ go func() {
+ // TODO: find a better way to wait
+ // until we're blocked in the cs.Read
+ // call below. Sleep is lame.
+ time.Sleep(100 * time.Millisecond)
+
+ // Interrupt the upcoming Read, unblocking it:
+ cs.SetReadDeadline(time.Unix(123, 0)) // time in the past
+ }()
+ var buf [1]byte
+ n, err := cs.Read(buf[:1])
+ if n != 0 || err == nil {
+ errc <- fmt.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+ }
+ }()
+ select {
+ case err := <-errc:
+ return err
+ case <-time.After(5 * time.Second):
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ println("Stacks at timeout:\n", string(buf))
+ return errors.New("timeout waiting for Read to finish")
+ }
+
+ }
+ // Do nothing in the client. Never write. Just wait for the
+ // server's half to be done.
+ client := func(*TCPConn) error {
+ <-serverDone
+ return nil
+ }
+ withTCPConnPair(t, client, server)
+}
+
+// Issue 17695: verify that a blocked Read is woken up by a Close.
+func TestCloseUnblocksRead(t *testing.T) {
+ t.Parallel()
+ server := func(cs *TCPConn) error {
+ // Give the client time to get stuck in a Read:
+ time.Sleep(20 * time.Millisecond)
+ cs.Close()
+ return nil
+ }
+ client := func(ss *TCPConn) error {
+ n, err := ss.Read([]byte{0})
+ if n != 0 || err != io.EOF {
+ return fmt.Errorf("Read = %v, %v; want 0, EOF", n, err)
+ }
+ return nil
+ }
+ withTCPConnPair(t, client, server)
+}
+
+// Issue 24808: verify that ECONNRESET is not temporary for read.
+func TestNotTemporaryRead(t *testing.T) {
+ t.Parallel()
+
+ ln := newLocalListener(t, "tcp")
+ serverDone := make(chan struct{})
+ dialed := make(chan struct{})
+ go func() {
+ defer close(serverDone)
+
+ cs, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ <-dialed
+ cs.(*TCPConn).SetLinger(0)
+ cs.Close()
+ }()
+ defer func() {
+ ln.Close()
+ <-serverDone
+ }()
+
+ ss, err := Dial("tcp", ln.Addr().String())
+ close(dialed)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ss.Close()
+
+ _, err = ss.Read([]byte{0})
+ if err == nil {
+ t.Fatal("Read succeeded unexpectedly")
+ } else if err == io.EOF {
+ // This happens on Plan 9, but for some reason (prior to CL 385314) it was
+ // accepted everywhere else too.
+ if runtime.GOOS == "plan9" {
+ return
+ }
+ t.Fatal("Read unexpectedly returned io.EOF after socket was abruptly closed")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Errorf("Read error does not implement net.Error: %v", err)
+ } else if ne.Temporary() {
+ t.Errorf("Read error is unexpectedly temporary: %v", err)
+ }
+}
+
+// The various errors should implement the Error interface.
+func TestErrors(t *testing.T) {
+ var (
+ _ Error = &OpError{}
+ _ Error = &ParseError{}
+ _ Error = &AddrError{}
+ _ Error = UnknownNetworkError("")
+ _ Error = InvalidAddrError("")
+ _ Error = &timeoutError{}
+ _ Error = &DNSConfigError{}
+ _ Error = &DNSError{}
+ )
+
+ // ErrClosed was introduced as type error, so we can't check
+ // it using a declaration.
+ if _, ok := ErrClosed.(Error); !ok {
+ t.Fatal("ErrClosed does not implement Error")
+ }
+}
diff --git a/src/net/net_windows_test.go b/src/net/net_windows_test.go
new file mode 100644
index 0000000..947dda5
--- /dev/null
+++ b/src/net/net_windows_test.go
@@ -0,0 +1,631 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "os"
+ "os/exec"
+ "regexp"
+ "sort"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+)
+
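+// toErrno unwraps err as an *OpError around an *os.SyscallError and returns
+// the underlying syscall.Errno, reporting whether the unwrapping succeeded.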
+func toErrno(err error) (syscall.Errno, bool) {
+ operr, ok := err.(*OpError)
+ if !ok {
+ return 0, false
+ }
+ syserr, ok := operr.Err.(*os.SyscallError)
+ if !ok {
+ return 0, false
+ }
+ errno, ok := syserr.Err.(syscall.Errno)
+ if !ok {
+ return 0, false
+ }
+ return errno, true
+}
+
+// TestAcceptIgnoreSomeErrors tests that Windows TCPListener.AcceptTCP
+// handles broken connections. It verifies that broken connections do
+// not affect future connections.
+func TestAcceptIgnoreSomeErrors(t *testing.T) {
+ recv := func(ln Listener, ignoreSomeReadErrors bool) (string, error) {
+ c, err := ln.Accept()
+ if err != nil {
+ // Display Windows errno in error message.
+ errno, ok := toErrno(err)
+ if !ok {
+ return "", err
+ }
+ return "", fmt.Errorf("%v (windows errno=%d)", err, errno)
+ }
+ defer c.Close()
+
+ b := make([]byte, 100)
+ n, err := c.Read(b)
+ if err == nil || err == io.EOF {
+ return string(b[:n]), nil
+ }
+ errno, ok := toErrno(err)
+ if ok && ignoreSomeReadErrors && (errno == syscall.ERROR_NETNAME_DELETED || errno == syscall.WSAECONNRESET) {
+ return "", nil
+ }
+ return "", err
+ }
+
+ send := func(addr string, data string) error {
+ c, err := Dial("tcp", addr)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+
+ b := []byte(data)
+ n, err := c.Write(b)
+ if err != nil {
+ return err
+ }
+ if n != len(b) {
+ return fmt.Errorf(`Only %d chars of string "%s" sent`, n, data)
+ }
+ return nil
+ }
+
+ if envaddr := os.Getenv("GOTEST_DIAL_ADDR"); envaddr != "" {
+ // In child process.
+ c, err := Dial("tcp", envaddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ fmt.Printf("sleeping\n")
+ time.Sleep(time.Minute) // process will be killed here
+ c.Close()
+ }
+
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ // Start child process that connects to our listener.
+ cmd := exec.Command(os.Args[0], "-test.run=TestAcceptIgnoreSomeErrors")
+ cmd.Env = append(os.Environ(), "GOTEST_DIAL_ADDR="+ln.Addr().String())
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ t.Fatalf("cmd.StdoutPipe failed: %v", err)
+ }
+ err = cmd.Start()
+ if err != nil {
+ t.Fatalf("cmd.Start failed: %v\n", err)
+ }
+ outReader := bufio.NewReader(stdout)
+ for {
+ s, err := outReader.ReadString('\n')
+ if err != nil {
+ t.Fatalf("reading stdout failed: %v", err)
+ }
+ if s == "sleeping\n" {
+ break
+ }
+ }
+ defer cmd.Wait() // ignore error - we know it is getting killed
+
+ const alittle = 100 * time.Millisecond
+ time.Sleep(alittle)
+ cmd.Process.Kill() // the only way to trigger the errors
+ time.Sleep(alittle)
+
+ // Send second connection data (with delay in a separate goroutine).
+ result := make(chan error)
+ go func() {
+ time.Sleep(alittle)
+ err := send(ln.Addr().String(), "abc")
+ if err != nil {
+ result <- err
+ }
+ result <- nil
+ }()
+ defer func() {
+ err := <-result
+ if err != nil {
+ t.Fatalf("send failed: %v", err)
+ }
+ }()
+
+ // Receive first or second connection.
+ s, err := recv(ln, true)
+ if err != nil {
+ t.Fatalf("recv failed: %v", err)
+ }
+ switch s {
+ case "":
+ // First connection data is received, let's get second connection data.
+ case "abc":
+ // First connection is lost forever, but that is ok.
+ return
+ default:
+ t.Fatalf(`"%s" received from recv, but "" or "abc" expected`, s)
+ }
+
+ // Get second connection data.
+ s, err = recv(ln, false)
+ if err != nil {
+ t.Fatalf("recv failed: %v", err)
+ }
+ if s != "abc" {
+ t.Fatalf(`"%s" received from recv, but "abc" expected`, s)
+ }
+}
+
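+// runCmd runs the given command line under powershell, redirecting its output
+// to a temporary UTF-8 file, and returns that output with any byte-order mark
+// removed.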
+func runCmd(args ...string) ([]byte, error) {
+ removeUTF8BOM := func(b []byte) []byte {
+ if len(b) >= 3 && b[0] == 0xEF && b[1] == 0xBB && b[2] == 0xBF {
+ return b[3:]
+ }
+ return b
+ }
+ f, err := os.CreateTemp("", "netcmd")
+ if err != nil {
+ return nil, err
+ }
+ f.Close()
+ defer os.Remove(f.Name())
+ cmd := fmt.Sprintf(`%s | Out-File "%s" -encoding UTF8`, strings.Join(args, " "), f.Name())
+ out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
+ if err != nil {
+ if len(out) != 0 {
+ return nil, fmt.Errorf("%s failed: %v: %q", args[0], err, string(removeUTF8BOM(out)))
+ }
+ var err2 error
+ out, err2 = os.ReadFile(f.Name())
+ if err2 != nil {
+ return nil, err2
+ }
+ if len(out) != 0 {
+ return nil, fmt.Errorf("%s failed: %v: %q", args[0], err, string(removeUTF8BOM(out)))
+ }
+ return nil, fmt.Errorf("%s failed: %v", args[0], err)
+ }
+ out, err = os.ReadFile(f.Name())
+ if err != nil {
+ return nil, err
+ }
+ return removeUTF8BOM(out), nil
+}
+
+func checkNetsh(t *testing.T) {
+ if testenv.Builder() == "windows-arm64-10" {
+ // netsh was observed to sometimes hang on this builder.
+ // We have not observed failures on windows-arm64-11, so for the
+ // moment we are leaving the test enabled elsewhere on the theory
+ // that it may have been a platform bug fixed in Windows 11.
+ testenv.SkipFlaky(t, 52082)
+ }
+ out, err := runCmd("netsh", "help")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Contains(out, []byte("The following helper DLL cannot be loaded")) {
+ t.Skipf("powershell failure:\n%s", err)
+ }
+ if !bytes.Contains(out, []byte("The following commands are available:")) {
+ t.Skipf("powershell does not speak English:\n%s", out)
+ }
+}
+
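+// netshInterfaceIPShowInterface parses the output of
+// "netsh interface <ipver> show interface level=verbose" and records, for
+// each interface name, whether it is connected.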
+func netshInterfaceIPShowInterface(ipver string, ifaces map[string]bool) error {
+ out, err := runCmd("netsh", "interface", ipver, "show", "interface", "level=verbose")
+ if err != nil {
+ return err
+ }
+ // interface information is listed like:
+ //
+ //Interface Local Area Connection Parameters
+ //----------------------------------------------
+ //IfLuid : ethernet_6
+ //IfIndex : 11
+ //State : connected
+ //Metric : 10
+ //...
+ var name string
+ lines := bytes.Split(out, []byte{'\r', '\n'})
+ for _, line := range lines {
+ if bytes.HasPrefix(line, []byte("Interface ")) && bytes.HasSuffix(line, []byte(" Parameters")) {
+ f := line[len("Interface "):]
+ f = f[:len(f)-len(" Parameters")]
+ name = string(f)
+ continue
+ }
+ var isup bool
+ switch string(line) {
+ case "State : connected":
+ isup = true
+ case "State : disconnected":
+ isup = false
+ default:
+ continue
+ }
+ if name != "" {
+ if v, ok := ifaces[name]; ok && v != isup {
+ return fmt.Errorf("%s:%s isup=%v: ipv4 and ipv6 report different interface state", ipver, name, isup)
+ }
+ ifaces[name] = isup
+ name = ""
+ }
+ }
+ return nil
+}
+
+func TestInterfacesWithNetsh(t *testing.T) {
+ checkNetsh(t)
+
+ toString := func(name string, isup bool) string {
+ if isup {
+ return name + ":up"
+ }
+ return name + ":down"
+ }
+
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ have := make([]string, 0)
+ for _, ifi := range ift {
+ have = append(have, toString(ifi.Name, ifi.Flags&FlagUp != 0))
+ }
+ sort.Strings(have)
+
+ ifaces := make(map[string]bool)
+ err = netshInterfaceIPShowInterface("ipv6", ifaces)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = netshInterfaceIPShowInterface("ipv4", ifaces)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := make([]string, 0)
+ for name, isup := range ifaces {
+ want = append(want, toString(name, isup))
+ }
+ sort.Strings(want)
+
+ if strings.Join(want, "/") != strings.Join(have, "/") {
+ t.Fatalf("unexpected interface list %q, want %q", have, want)
+ }
+}
+
+func netshInterfaceIPv4ShowAddress(name string, netshOutput []byte) []string {
+ // Address information is listed like:
+ //
+ //Configuration for interface "Local Area Connection"
+ // DHCP enabled: Yes
+ // IP Address: 10.0.0.2
+ // Subnet Prefix: 10.0.0.0/24 (mask 255.255.255.0)
+ // IP Address: 10.0.0.3
+ // Subnet Prefix: 10.0.0.0/24 (mask 255.255.255.0)
+ // Default Gateway: 10.0.0.254
+ // Gateway Metric: 0
+ // InterfaceMetric: 10
+ //
+ //Configuration for interface "Loopback Pseudo-Interface 1"
+ // DHCP enabled: No
+ // IP Address: 127.0.0.1
+ // Subnet Prefix: 127.0.0.0/8 (mask 255.0.0.0)
+ // InterfaceMetric: 50
+ //
+ addrs := make([]string, 0)
+ var addr, subnetprefix string
+ var processingOurInterface bool
+ lines := bytes.Split(netshOutput, []byte{'\r', '\n'})
+ for _, line := range lines {
+ if !processingOurInterface {
+ if !bytes.HasPrefix(line, []byte("Configuration for interface")) {
+ continue
+ }
+ if !bytes.Contains(line, []byte(`"`+name+`"`)) {
+ continue
+ }
+ processingOurInterface = true
+ continue
+ }
+ if len(line) == 0 {
+ break
+ }
+ if bytes.Contains(line, []byte("Subnet Prefix:")) {
+ f := bytes.Split(line, []byte{':'})
+ if len(f) == 2 {
+ f = bytes.Split(f[1], []byte{'('})
+ if len(f) == 2 {
+ f = bytes.Split(f[0], []byte{'/'})
+ if len(f) == 2 {
+ subnetprefix = string(bytes.TrimSpace(f[1]))
+ if addr != "" && subnetprefix != "" {
+ addrs = append(addrs, addr+"/"+subnetprefix)
+ }
+ }
+ }
+ }
+ }
+ addr = ""
+ if bytes.Contains(line, []byte("IP Address:")) {
+ f := bytes.Split(line, []byte{':'})
+ if len(f) == 2 {
+ addr = string(bytes.TrimSpace(f[1]))
+ }
+ }
+ }
+ return addrs
+}
+
+func netshInterfaceIPv6ShowAddress(name string, netshOutput []byte) []string {
+ // Address information is listed like:
+ //
+ //Address ::1 Parameters
+ //---------------------------------------------------------
+ //Interface Luid : Loopback Pseudo-Interface 1
+ //Scope Id : 0.0
+ //Valid Lifetime : infinite
+ //Preferred Lifetime : infinite
+ //DAD State : Preferred
+ //Address Type : Other
+ //Skip as Source : false
+ //
+ //Address XXXX::XXXX:XXXX:XXXX:XXXX%11 Parameters
+ //---------------------------------------------------------
+ //Interface Luid : Local Area Connection
+ //Scope Id : 0.11
+ //Valid Lifetime : infinite
+ //Preferred Lifetime : infinite
+ //DAD State : Preferred
+ //Address Type : Other
+ //Skip as Source : false
+ //
+
+ // TODO: need to test ipv6 netmask too, but netsh does not output it
+ var addr string
+ addrs := make([]string, 0)
+ lines := bytes.Split(netshOutput, []byte{'\r', '\n'})
+ for _, line := range lines {
+ if addr != "" {
+ if len(line) == 0 {
+ addr = ""
+ continue
+ }
+ if string(line) != "Interface Luid : "+name {
+ continue
+ }
+ addrs = append(addrs, addr)
+ addr = ""
+ continue
+ }
+ if !bytes.HasPrefix(line, []byte("Address")) {
+ continue
+ }
+ if !bytes.HasSuffix(line, []byte("Parameters")) {
+ continue
+ }
+ f := bytes.Split(line, []byte{' '})
+ if len(f) != 3 {
+ continue
+ }
+ // remove scope ID if present
+ f = bytes.Split(f[1], []byte{'%'})
+
+ // netsh can create IPv4-embedded IPv6 addresses, like fe80::5efe:192.168.140.1.
+ // Convert these to all hexadecimal fe80::5efe:c0a8:8c01 for later string comparisons.
+ ipv4Tail := regexp.MustCompile(`:\d+\.\d+\.\d+\.\d+$`)
+ if ipv4Tail.Match(f[0]) {
+ f[0] = []byte(ParseIP(string(f[0])).String())
+ }
+
+ addr = string(bytes.ToLower(bytes.TrimSpace(f[0])))
+ }
+ return addrs
+}
+
+func TestInterfaceAddrsWithNetsh(t *testing.T) {
+ checkNetsh(t)
+
+ outIPV4, err := runCmd("netsh", "interface", "ipv4", "show", "address")
+ if err != nil {
+ t.Fatal(err)
+ }
+ outIPV6, err := runCmd("netsh", "interface", "ipv6", "show", "address", "level=verbose")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, ifi := range ift {
+ // Skip the interface if it's down.
+ if (ifi.Flags & FlagUp) == 0 {
+ continue
+ }
+ have := make([]string, 0)
+ addrs, err := ifi.Addrs()
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, addr := range addrs {
+ switch addr := addr.(type) {
+ case *IPNet:
+ if addr.IP.To4() != nil {
+ have = append(have, addr.String())
+ }
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ // netsh does not output netmask for ipv6, so ignore ipv6 mask
+ have = append(have, addr.IP.String())
+ }
+ case *IPAddr:
+ if addr.IP.To4() != nil {
+ have = append(have, addr.String())
+ }
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ // netsh does not output netmask for ipv6, so ignore ipv6 mask
+ have = append(have, addr.IP.String())
+ }
+ }
+ }
+ sort.Strings(have)
+
+ want := netshInterfaceIPv4ShowAddress(ifi.Name, outIPV4)
+ wantIPv6 := netshInterfaceIPv6ShowAddress(ifi.Name, outIPV6)
+ want = append(want, wantIPv6...)
+ sort.Strings(want)
+
+ if strings.Join(want, "/") != strings.Join(have, "/") {
+ t.Errorf("%s: unexpected addresses list %q, want %q", ifi.Name, have, want)
+ }
+ }
+}
+
+// checkGetmac checks that getmac exists as a powershell command, and that
+// it speaks English.
+func checkGetmac(t *testing.T) {
+ out, err := runCmd("getmac", "/?")
+ if err != nil {
+ if strings.Contains(err.Error(), "term 'getmac' is not recognized as the name of a cmdlet") {
+ t.Skipf("getmac not available")
+ }
+ t.Fatal(err)
+ }
+ if !bytes.Contains(out, []byte("network adapters on a system")) {
+ t.Skipf("skipping test on non-English system")
+ }
+}
+
+func TestInterfaceHardwareAddrWithGetmac(t *testing.T) {
+ checkGetmac(t)
+
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ have := make(map[string]string)
+ for _, ifi := range ift {
+ if ifi.Flags&FlagLoopback != 0 {
+ // no MAC address for loopback interfaces
+ continue
+ }
+ have[ifi.Name] = ifi.HardwareAddr.String()
+ }
+
+ out, err := runCmd("getmac", "/fo", "list", "/v")
+ if err != nil {
+ t.Fatal(err)
+ }
+ // getmac output looks like:
+ //
+ //Connection Name: Local Area Connection
+ //Network Adapter: Intel Gigabit Network Connection
+ //Physical Address: XX-XX-XX-XX-XX-XX
+ //Transport Name: \Device\Tcpip_{XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}
+ //
+ //Connection Name: Wireless Network Connection
+//Network Adapter: Wireless WLAN Card
+ //Physical Address: XX-XX-XX-XX-XX-XX
+ //Transport Name: Media disconnected
+ //
+ //Connection Name: Bluetooth Network Connection
+ //Network Adapter: Bluetooth Device (Personal Area Network)
+ //Physical Address: N/A
+ //Transport Name: Hardware not present
+ //
+ //Connection Name: VMware Network Adapter VMnet8
+ //Network Adapter: VMware Virtual Ethernet Adapter for VMnet8
+ //Physical Address: Disabled
+ //Transport Name: Disconnected
+ //
+ want := make(map[string]string)
+ group := make(map[string]string) // name / values for single adapter
+ getValue := func(name string) string {
+ value, found := group[name]
+ if !found {
+ t.Fatalf("%q has no %q line in it", group, name)
+ }
+ if value == "" {
+ t.Fatalf("%q has empty %q value", group, name)
+ }
+ return value
+ }
+ processGroup := func() {
+ if len(group) == 0 {
+ return
+ }
+ tname := strings.ToLower(getValue("Transport Name"))
+ if tname == "n/a" {
+ // skip these
+ return
+ }
+ addr := strings.ToLower(getValue("Physical Address"))
+ if addr == "disabled" || addr == "n/a" {
+ // skip these
+ return
+ }
+ addr = strings.ReplaceAll(addr, "-", ":")
+ cname := getValue("Connection Name")
+ want[cname] = addr
+ group = make(map[string]string)
+ }
+ lines := bytes.Split(out, []byte{'\r', '\n'})
+ for _, line := range lines {
+ if len(line) == 0 {
+ processGroup()
+ continue
+ }
+ i := bytes.IndexByte(line, ':')
+ if i == -1 {
+ t.Fatalf("line %q has no : in it", line)
+ }
+ group[string(line[:i])] = string(bytes.TrimSpace(line[i+1:]))
+ }
+ processGroup()
+
+ dups := make(map[string][]string)
+ for name, addr := range want {
+ if _, ok := dups[addr]; !ok {
+ dups[addr] = make([]string, 0)
+ }
+ dups[addr] = append(dups[addr], name)
+ }
+
+nextWant:
+ for name, wantAddr := range want {
+ if haveAddr, ok := have[name]; ok {
+ if haveAddr != wantAddr {
+ t.Errorf("unexpected MAC address for %q - %v, want %v", name, haveAddr, wantAddr)
+ }
+ continue
+ }
+ // We could not find the interface in getmac output by name.
+ // But sometimes getmac lists many interface names
+ // for the same MAC address. If that is the case here,
+ // and we can match at least one of those names,
+ // let's ignore the other names.
+ if dupNames, ok := dups[wantAddr]; ok && len(dupNames) > 1 {
+ for _, dupName := range dupNames {
+ if haveAddr, ok := have[dupName]; ok && haveAddr == wantAddr {
+ continue nextWant
+ }
+ }
+ }
+ t.Errorf("getmac lists %q, but it could not be found among Go interfaces %v", name, have)
+ }
+}
diff --git a/src/net/netcgo_off.go b/src/net/netcgo_off.go
new file mode 100644
index 0000000..54677dc
--- /dev/null
+++ b/src/net/netcgo_off.go
@@ -0,0 +1,9 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !netcgo
+
+package net
+
+const netCgoBuildTag = false
diff --git a/src/net/netcgo_on.go b/src/net/netcgo_on.go
new file mode 100644
index 0000000..25d4bdc
--- /dev/null
+++ b/src/net/netcgo_on.go
@@ -0,0 +1,9 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build netcgo
+
+package net
+
+const netCgoBuildTag = true
diff --git a/src/net/netgo_netcgo.go b/src/net/netgo_netcgo.go
new file mode 100644
index 0000000..7f3a5fd
--- /dev/null
+++ b/src/net/netgo_netcgo.go
@@ -0,0 +1,14 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build netgo && netcgo
+
+package net
+
+func init() {
+ // This will give a compile time error about the unused constant.
+ // The advantage of this approach is that the gc compiler
+ // actually prints the constant, making the problem obvious.
+ "Do not use both netgo and netcgo build tags."
+}
diff --git a/src/net/netgo_off.go b/src/net/netgo_off.go
new file mode 100644
index 0000000..e6bc2d7
--- /dev/null
+++ b/src/net/netgo_off.go
@@ -0,0 +1,9 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !netgo
+
+package net
+
+const netGoBuildTag = false
diff --git a/src/net/netgo_on.go b/src/net/netgo_on.go
new file mode 100644
index 0000000..4f088de
--- /dev/null
+++ b/src/net/netgo_on.go
@@ -0,0 +1,9 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build netgo
+
+package net
+
+const netGoBuildTag = true
diff --git a/src/net/netip/export_test.go b/src/net/netip/export_test.go
new file mode 100644
index 0000000..59971fa
--- /dev/null
+++ b/src/net/netip/export_test.go
@@ -0,0 +1,30 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import "internal/intern"
+
+var (
+ Z0 = z0
+ Z4 = z4
+ Z6noz = z6noz
+)
+
+type Uint128 = uint128
+
+func Mk128(hi, lo uint64) Uint128 {
+ return uint128{hi, lo}
+}
+
+func MkAddr(u Uint128, z *intern.Value) Addr {
+ return Addr{u, z}
+}
+
+func IPv4(a, b, c, d uint8) Addr { return AddrFrom4([4]byte{a, b, c, d}) }
+
+var TestAppendToMarshal = testAppendToMarshal
+
+func (a Addr) IsZero() bool { return a.isZero() }
+func (p Prefix) IsZero() bool { return p.isZero() }
diff --git a/src/net/netip/fuzz_test.go b/src/net/netip/fuzz_test.go
new file mode 100644
index 0000000..c9fc6c9
--- /dev/null
+++ b/src/net/netip/fuzz_test.go
@@ -0,0 +1,351 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip_test
+
+import (
+ "bytes"
+ "encoding"
+ "fmt"
+ "net"
+ . "net/netip"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+var corpus = []string{
+ // Basic zero IPv4 address.
+ "0.0.0.0",
+ // Basic non-zero IPv4 address.
+ "192.168.140.255",
+ // IPv4 address in windows-style "print all the digits" form.
+ "010.000.015.001",
+ // IPv4 address with a silly amount of leading zeros.
+ "000001.00000002.00000003.000000004",
+ // 4-in-6 with octet with leading zero
+ "::ffff:1.2.03.4",
+ // Basic zero IPv6 address.
+ "::",
+ // Localhost IPv6.
+ "::1",
+ // Fully expanded IPv6 address.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b",
+ // IPv6 with elided fields in the middle.
+ "fd7a:115c::626b:430b",
+ // IPv6 with elided fields at the end.
+ "fd7a:115c:a1e0:ab12:4843:cd96::",
+ // IPv6 with single elided field at the end.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b::",
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:0",
+ // IPv6 with single elided field in the middle.
+ "fd7a:115c:a1e0::4843:cd96:626b:430b",
+ "fd7a:115c:a1e0:0:4843:cd96:626b:430b",
+ // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6)
+ "::ffff:192.168.140.255",
+ "::ffff:192.168.140.255",
+ // IPv6 with a zone specifier.
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0",
+ // IPv6 with dotted decimal and zone specifier.
+ "1:2::ffff:192.168.140.255%eth1",
+ "1:2::ffff:c0a8:8cff%eth1",
+ // IPv6 with capital letters.
+ "FD9E:1A04:F01D::1",
+ "fd9e:1a04:f01d::1",
+ // Empty string.
+ "",
+ // Garbage non-IP.
+ "bad",
+ // Single number. Some parsers accept this as an IPv4 address in
+ // big-endian uint32 form, but we don't.
+ "1234",
+ // IPv4 with a zone specifier.
+ "1.2.3.4%eth0",
+ // IPv4 field must have at least one digit.
+ ".1.2.3",
+ "1.2.3.",
+ "1..2.3",
+ // IPv4 address too long.
+ "1.2.3.4.5",
+ // IPv4 in dotted octal form.
+ "0300.0250.0214.0377",
+ // IPv4 in dotted hex form.
+ "0xc0.0xa8.0x8c.0xff",
+ // IPv4 in class B form.
+ "192.168.12345",
+ // IPv4 in class B form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.0.1",
+ // IPv4 in class A form.
+ "192.1234567",
+ // IPv4 in class A form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.1",
+ // IPv4 field has value >255.
+ "192.168.300.1",
+ // IPv4 with too many fields.
+ "192.168.0.1.5.6",
+ // IPv6 with not enough fields.
+ "1:2:3:4:5:6:7",
+ // IPv6 with too many fields.
+ "1:2:3:4:5:6:7:8:9",
+ // IPv6 with 8 fields and a :: expander.
+ "1:2:3:4::5:6:7:8",
+ // IPv6 with a field bigger than 2 bytes.
+ "fe801::1",
+ // IPv6 with non-hex values in field.
+ "fe80:tail:scal:e::",
+ // IPv6 with a zone delimiter but no zone.
+ "fe80::1%",
+ // IPv6 with a zone specifier of zero.
+ "::ffff:0:0%0",
+ // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 with invalid embedded IPv4.
+ "::ffff:192.168.140.bad",
+ // IPv6 with multiple ellipsis ::.
+ "fe80::1::1",
+ // IPv6 with invalid non hex/colon character.
+ "fe80:1?:1",
+ // IPv6 with truncated bytes after single colon.
+ "fe80:",
+ // AddrPort strings.
+ "1.2.3.4:51820",
+ "[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80",
+ "[::ffff:c000:0280]:65535",
+ "[::ffff:c000:0280%eth0]:1",
+ // Prefix strings.
+ "1.2.3.4/24",
+ "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118",
+ "::ffff:c000:0280/96",
+ "::ffff:c000:0280%eth0/37",
+}
+
+func FuzzParse(f *testing.F) {
+ for _, seed := range corpus {
+ f.Add(seed)
+ }
+
+ f.Fuzz(func(t *testing.T, s string) {
+ ip, _ := ParseAddr(s)
+ checkStringParseRoundTrip(t, ip, ParseAddr)
+ checkEncoding(t, ip)
+
+ // Check that we match the net's IP parser, modulo zones.
+ if !strings.Contains(s, "%") {
+ stdip := net.ParseIP(s)
+ if !ip.IsValid() != (stdip == nil) {
+ t.Errorf("ParseAddr zero != net.ParseIP nil: ip=%q stdip=%q", ip, stdip)
+ }
+
+ if ip.IsValid() && !ip.Is4In6() {
+ buf, err := ip.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf2, err := stdip.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Errorf("Addr.MarshalText() != net.IP.MarshalText(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.String() != stdip.String() {
+ t.Errorf("Addr.String() != net.IP.String(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsGlobalUnicast() != stdip.IsGlobalUnicast() {
+ t.Errorf("Addr.IsGlobalUnicast() != net.IP.IsGlobalUnicast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsInterfaceLocalMulticast() != stdip.IsInterfaceLocalMulticast() {
+ t.Errorf("Addr.IsInterfaceLocalMulticast() != net.IP.IsInterfaceLocalMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLinkLocalMulticast() != stdip.IsLinkLocalMulticast() {
+ t.Errorf("Addr.IsLinkLocalMulticast() != net.IP.IsLinkLocalMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLinkLocalUnicast() != stdip.IsLinkLocalUnicast() {
+ t.Errorf("Addr.IsLinkLocalUnicast() != net.IP.IsLinkLocalUnicast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsLoopback() != stdip.IsLoopback() {
+ t.Errorf("Addr.IsLoopback() != net.IP.IsLoopback(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsMulticast() != stdip.IsMulticast() {
+ t.Errorf("Addr.IsMulticast() != net.IP.IsMulticast(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsPrivate() != stdip.IsPrivate() {
+ t.Errorf("Addr.IsPrivate() != net.IP.IsPrivate(): ip=%q stdip=%q", ip, stdip)
+ }
+ if ip.IsUnspecified() != stdip.IsUnspecified() {
+ t.Errorf("Addr.IsUnspecified() != net.IP.IsUnspecified(): ip=%q stdip=%q", ip, stdip)
+ }
+ }
+ }
+
+ // Check that .Next().Prev() and .Prev().Next() preserve the IP.
+ if ip.IsValid() && ip.Next().IsValid() && ip.Next().Prev() != ip {
+ t.Errorf(".Next.Prev did not round trip: ip=%q .next=%q .next.prev=%q", ip, ip.Next(), ip.Next().Prev())
+ }
+ if ip.IsValid() && ip.Prev().IsValid() && ip.Prev().Next() != ip {
+ t.Errorf(".Prev.Next did not round trip: ip=%q .prev=%q .prev.next=%q", ip, ip.Prev(), ip.Prev().Next())
+ }
+
+ port, err := ParseAddrPort(s)
+ if err == nil {
+ checkStringParseRoundTrip(t, port, ParseAddrPort)
+ checkEncoding(t, port)
+ }
+ port = AddrPortFrom(ip, 80)
+ checkStringParseRoundTrip(t, port, ParseAddrPort)
+ checkEncoding(t, port)
+
+ ipp, err := ParsePrefix(s)
+ if err == nil {
+ checkStringParseRoundTrip(t, ipp, ParsePrefix)
+ checkEncoding(t, ipp)
+ }
+ ipp = PrefixFrom(ip, 8)
+ checkStringParseRoundTrip(t, ipp, ParsePrefix)
+ checkEncoding(t, ipp)
+ })
+}
+
+// checkTextMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly.
+func checkTextMarshaler(t *testing.T, x encoding.TextMarshaler) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.TextUnmarshaler)
+ err = y.UnmarshalText(buf)
+ if err != nil {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Fatalf("(%T).UnmarshalText(%q) = %v", y, buf, err)
+ }
+ e := reflect.ValueOf(y).Elem().Interface()
+ if !reflect.DeepEqual(x, e) {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Fatalf("MarshalText/UnmarshalText failed to round trip: %#v != %#v", x, e)
+ }
+ buf2, err := y.(encoding.TextMarshaler).MarshalText()
+ if err != nil {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Fatalf("failed to MarshalText a second time: %v", err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Logf("(%v).MarshalText() = %q", x, buf)
+ t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y)
+ t.Logf("(%v).MarshalText() = %q", y, buf2)
+ t.Fatalf("second MarshalText differs from first: %q != %q", buf, buf2)
+ }
+}
+
+// checkBinaryMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly.
+func checkBinaryMarshaler(t *testing.T, x encoding.BinaryMarshaler) {
+ buf, err := x.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.BinaryUnmarshaler)
+ err = y.UnmarshalBinary(buf)
+ if err != nil {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Fatalf("(%T).UnmarshalBinary(%q) = %v", y, buf, err)
+ }
+ e := reflect.ValueOf(y).Elem().Interface()
+ if !reflect.DeepEqual(x, e) {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Fatalf("MarshalBinary/UnmarshalBinary failed to round trip: %#v != %#v", x, e)
+ }
+ buf2, err := y.(encoding.BinaryMarshaler).MarshalBinary()
+ if err != nil {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Fatalf("failed to MarshalBinary a second time: %v", err)
+ }
+ if !bytes.Equal(buf, buf2) {
+ t.Logf("(%v).MarshalBinary() = %q", x, buf)
+ t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y)
+ t.Logf("(%v).MarshalBinary() = %q", y, buf2)
+ t.Fatalf("second MarshalBinary differs from first: %q != %q", buf, buf2)
+ }
+}
+
+func checkTextMarshalMatchesString(t *testing.T, x netipType) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+ str := x.String()
+ if string(buf) != str {
+ t.Fatalf("%v: MarshalText = %q, String = %q", x, buf, str)
+ }
+}
+
+type appendMarshaler interface {
+ encoding.TextMarshaler
+ AppendTo([]byte) []byte
+}
+
+// checkTextMarshalMatchesAppendTo checks that x's MarshalText matches x's AppendTo.
+func checkTextMarshalMatchesAppendTo(t *testing.T, x appendMarshaler) {
+ buf, err := x.MarshalText()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ buf2 := make([]byte, 0, len(buf))
+ buf2 = x.AppendTo(buf2)
+ if !bytes.Equal(buf, buf2) {
+ t.Fatalf("%v: MarshalText = %q, AppendTo = %q", x, buf, buf2)
+ }
+}
+
+type netipType interface {
+ encoding.BinaryMarshaler
+ encoding.TextMarshaler
+ fmt.Stringer
+ IsValid() bool
+}
+
+type netipTypeCmp interface {
+ comparable
+ netipType
+}
+
+// checkStringParseRoundTrip checks that x's String method and the provided parse function can round trip correctly.
+func checkStringParseRoundTrip[P netipTypeCmp](t *testing.T, x P, parse func(string) (P, error)) {
+ if !x.IsValid() {
+ // Ignore invalid values.
+ return
+ }
+
+ s := x.String()
+ y, err := parse(s)
+ if err != nil {
+ t.Fatalf("s=%q err=%v", s, err)
+ }
+ if x != y {
+ t.Fatalf("%T round trip identity failure: s=%q x=%#v y=%#v", x, s, x, y)
+ }
+ s2 := y.String()
+ if s != s2 {
+ t.Fatalf("%T String round trip identity failure: s=%#v s2=%#v", x, s, s2)
+ }
+}
+
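+// checkEncoding exercises the text and binary marshaling invariants for x
+// when it is valid, and the AppendTo invariant when x implements
+// appendMarshaler.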
+func checkEncoding(t *testing.T, x netipType) {
+ if x.IsValid() {
+ checkTextMarshaler(t, x)
+ checkBinaryMarshaler(t, x)
+ checkTextMarshalMatchesString(t, x)
+ }
+
+ if am, ok := x.(appendMarshaler); ok {
+ checkTextMarshalMatchesAppendTo(t, am)
+ }
+}
diff --git a/src/net/netip/inlining_test.go b/src/net/netip/inlining_test.go
new file mode 100644
index 0000000..b521eee
--- /dev/null
+++ b/src/net/netip/inlining_test.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "internal/testenv"
+ "os/exec"
+ "regexp"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func TestInlining(t *testing.T) {
+ testenv.MustHaveGoBuild(t)
+ t.Parallel()
+ out, err := exec.Command(
+ testenv.GoToolPath(t),
+ "build",
+ "--gcflags=-m",
+ "net/netip").CombinedOutput()
+ if err != nil {
+ t.Fatalf("go build: %v, %s", err, out)
+ }
+ got := map[string]bool{}
+ regexp.MustCompile(` can inline (\S+)`).ReplaceAllFunc(out, func(match []byte) []byte {
+ got[strings.TrimPrefix(string(match), " can inline ")] = true
+ return nil
+ })
+ wantInlinable := []string{
+ "(*uint128).halves",
+ "Addr.BitLen",
+ "Addr.hasZone",
+ "Addr.Is4",
+ "Addr.Is4In6",
+ "Addr.Is6",
+ "Addr.IsLoopback",
+ "Addr.IsMulticast",
+ "Addr.IsInterfaceLocalMulticast",
+ "Addr.IsValid",
+ "Addr.IsUnspecified",
+ "Addr.Less",
+ "Addr.Unmap",
+ "Addr.Zone",
+ "Addr.v4",
+ "Addr.v6",
+ "Addr.v6u16",
+ "Addr.withoutZone",
+ "AddrPortFrom",
+ "AddrPort.Addr",
+ "AddrPort.Port",
+ "AddrPort.IsValid",
+ "Prefix.IsSingleIP",
+ "Prefix.Masked",
+ "Prefix.IsValid",
+ "PrefixFrom",
+ "Prefix.Addr",
+ "Prefix.Bits",
+ "AddrFrom4",
+ "IPv6LinkLocalAllNodes",
+ "IPv6Unspecified",
+ "MustParseAddr",
+ "MustParseAddrPort",
+ "MustParsePrefix",
+ "appendDecimal",
+ "appendHex",
+ "uint128.addOne",
+ "uint128.and",
+ "uint128.bitsClearedFrom",
+ "uint128.bitsSetFrom",
+ "uint128.isZero",
+ "uint128.not",
+ "uint128.or",
+ "uint128.subOne",
+ "uint128.xor",
+ }
+ switch runtime.GOARCH {
+ case "amd64", "arm64":
+ // These don't inline on 32-bit.
+ wantInlinable = append(wantInlinable,
+ "Addr.Next",
+ "Addr.Prev",
+ )
+ }
+
+ for _, want := range wantInlinable {
+ if !got[want] {
+ t.Errorf("%q is no longer inlinable", want)
+ continue
+ }
+ delete(got, want)
+ }
+ for sym := range got {
+ if strings.Contains(sym, ".func") {
+ continue
+ }
+ t.Logf("not in expected set, but also inlinable: %q", sym)
+
+ }
+}
diff --git a/src/net/netip/leaf_alts.go b/src/net/netip/leaf_alts.go
new file mode 100644
index 0000000..70513ab
--- /dev/null
+++ b/src/net/netip/leaf_alts.go
@@ -0,0 +1,54 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Helpers that also exist elsewhere in std, but which we can't import here
+// because this package is a dependency of net, per the go/build deps_test policy.
+
+package netip
+
+func stringsLastIndexByte(s string, b byte) int {
+ for i := len(s) - 1; i >= 0; i-- {
+ if s[i] == b {
+ return i
+ }
+ }
+ return -1
+}
+
+func beUint64(b []byte) uint64 {
+ _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808
+ return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
+ uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
+}
+
+func bePutUint64(b []byte, v uint64) {
+ _ = b[7] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v >> 56)
+ b[1] = byte(v >> 48)
+ b[2] = byte(v >> 40)
+ b[3] = byte(v >> 32)
+ b[4] = byte(v >> 24)
+ b[5] = byte(v >> 16)
+ b[6] = byte(v >> 8)
+ b[7] = byte(v)
+}
+
+func bePutUint32(b []byte, v uint32) {
+ _ = b[3] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v >> 24)
+ b[1] = byte(v >> 16)
+ b[2] = byte(v >> 8)
+ b[3] = byte(v)
+}
+
+func leUint16(b []byte) uint16 {
+ _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808
+ return uint16(b[0]) | uint16(b[1])<<8
+}
+
+func lePutUint16(b []byte, v uint16) {
+ _ = b[1] // early bounds check to guarantee safety of writes below
+ b[0] = byte(v)
+ b[1] = byte(v >> 8)
+}
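+
+// Round-trip sketch (illustrative, not part of this change): the helpers above
+// mirror the big-endian and little-endian routines of encoding/binary, which
+// this package avoids importing.
+//
+//    var b [8]byte
+//    bePutUint64(b[:], 0x0011223344556677)
+//    _ = beUint64(b[:]) // 0x0011223344556677 again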
diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go
new file mode 100644
index 0000000..a44b094
--- /dev/null
+++ b/src/net/netip/netip.go
@@ -0,0 +1,1482 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package netip defines an IP address type that's a small value type.
+// Building on that [Addr] type, the package also defines [AddrPort] (an
+// IP address and a port) and [Prefix] (an IP address and a bit length
+// prefix).
+//
+// Compared to the [net.IP] type, the [Addr] type takes less memory, is
+// immutable, and is comparable (supports == and being a map key).
+package netip
+
+import (
+ "errors"
+ "math"
+ "strconv"
+
+ "internal/bytealg"
+ "internal/intern"
+ "internal/itoa"
+)
+
+// Sizes: (64-bit)
+// net.IP: 24 byte slice header + {4, 16} = 28 to 40 bytes
+// net.IPAddr: 40 byte slice header + {4, 16} = 44 to 56 bytes + zone length
+// netip.Addr: 24 bytes (zone is per-name singleton, shared across all users)
+
+// Addr represents an IPv4 or IPv6 address (with or without a scoped
+// addressing zone), similar to [net.IP] or [net.IPAddr].
+//
+// Unlike [net.IP] or [net.IPAddr], Addr is a comparable value
+// type (it supports == and can be a map key) and is immutable.
+//
+// The zero Addr is not a valid IP address.
+// Addr{} is distinct from both 0.0.0.0 and ::.
+type Addr struct {
+ // addr is the hi and lo bits of an IPv6 address. If z==z4,
+ // hi and lo contain the IPv4-mapped IPv6 address.
+ //
+ // hi and lo are constructed by interpreting a 16-byte IPv6
+ // address as a big-endian 128-bit number. The most significant
+ // bits of that number go into hi, the rest into lo.
+ //
+ // For example, 0011:2233:4455:6677:8899:aabb:ccdd:eeff is stored as:
+ // addr.hi = 0x0011223344556677
+ // addr.lo = 0x8899aabbccddeeff
+ //
+ // We store IPs like this, rather than as [16]byte, because it
+ // turns most operations on IPs into arithmetic and bit-twiddling
+ // operations on 64-bit registers, which is much faster than
+ // bytewise processing.
+ addr uint128
+
+ // z is a combination of the address family and the IPv6 zone.
+ //
+ // nil means invalid IP address (for a zero Addr).
+ // z4 means an IPv4 address.
+ // z6noz means an IPv6 address without a zone.
+ //
+ // Otherwise it's the interned zone name string.
+ z *intern.Value
+}
+
+// z0, z4, and z6noz are sentinel Addr.z values.
+// See the Addr type's field docs.
+var (
+ z0 = (*intern.Value)(nil)
+ z4 = new(intern.Value)
+ z6noz = new(intern.Value)
+)
+
+// IPv6LinkLocalAllNodes returns the IPv6 link-local all nodes multicast
+// address ff02::1.
+func IPv6LinkLocalAllNodes() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x01}) }
+
+// IPv6LinkLocalAllRouters returns the IPv6 link-local all routers multicast
+// address ff02::2.
+func IPv6LinkLocalAllRouters() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x02}) }
+
+// IPv6Loopback returns the IPv6 loopback address ::1.
+func IPv6Loopback() Addr { return AddrFrom16([16]byte{15: 0x01}) }
+
+// IPv6Unspecified returns the IPv6 unspecified address "::".
+func IPv6Unspecified() Addr { return Addr{z: z6noz} }
+
+// IPv4Unspecified returns the IPv4 unspecified address "0.0.0.0".
+func IPv4Unspecified() Addr { return AddrFrom4([4]byte{}) }
+
+// AddrFrom4 returns the address of the IPv4 address given by the bytes in addr.
+func AddrFrom4(addr [4]byte) Addr {
+ return Addr{
+ addr: uint128{0, 0xffff00000000 | uint64(addr[0])<<24 | uint64(addr[1])<<16 | uint64(addr[2])<<8 | uint64(addr[3])},
+ z: z4,
+ }
+}
+
+// AddrFrom16 returns the IPv6 address given by the bytes in addr.
+// An IPv4-mapped IPv6 address is left as an IPv6 address.
+// (Use Unmap to convert them if needed.)
+func AddrFrom16(addr [16]byte) Addr {
+ return Addr{
+ addr: uint128{
+ beUint64(addr[:8]),
+ beUint64(addr[8:]),
+ },
+ z: z6noz,
+ }
+}
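+
+// Construction sketch (illustrative, not part of this change): building Addr
+// values directly from raw bytes. AddrFrom16 deliberately keeps IPv4-mapped
+// addresses in IPv6 form; Unmap converts them.
+//
+//    v4 := AddrFrom4([4]byte{192, 0, 2, 1}) // 192.0.2.1
+//    v6 := AddrFrom16([16]byte{15: 0x01})   // ::1
+//    m := AddrFrom16([16]byte{10: 0xff, 11: 0xff, 12: 192, 13: 0, 14: 2, 15: 1})
+//    m.Is4()   // false: still an IPv6 (IPv4-mapped) address
+//    m.Unmap() // 192.0.2.1 as an IPv4 Addr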
+
+// ParseAddr parses s as an IP address, returning the result. The string
+// s can be in dotted decimal ("192.0.2.1"), IPv6 ("2001:db8::68"),
+// or IPv6 with a scoped addressing zone ("fe80::1cc0:3e8c:119f:c2e1%ens18").
+func ParseAddr(s string) (Addr, error) {
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '.':
+ return parseIPv4(s)
+ case ':':
+ return parseIPv6(s)
+ case '%':
+ // Assume that this was trying to be an IPv6 address with
+ // a zone specifier, but the address is missing.
+ return Addr{}, parseAddrError{in: s, msg: "missing IPv6 address"}
+ }
+ }
+ return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"}
+}
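+
+// Illustrative usage (a sketch, not part of this change): ParseAddr accepts
+// exactly the three forms described above; it does no name resolution and
+// rejects anything else.
+//
+//    a, _ := ParseAddr("192.0.2.1")                       // IPv4
+//    b, _ := ParseAddr("2001:db8::68")                    // IPv6
+//    c, _ := ParseAddr("fe80::1cc0:3e8c:119f:c2e1%ens18") // IPv6 with zone
+//    _, err := ParseAddr("example.com")                   // err != nil: not an IP literal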
+
+// MustParseAddr calls ParseAddr(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParseAddr(s string) Addr {
+ ip, err := ParseAddr(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+type parseAddrError struct {
+ in string // the string given to ParseAddr
+ msg string // an explanation of the parse failure
+ at string // optionally, the unparsed portion of in at which the error occurred.
+}
+
+func (err parseAddrError) Error() string {
+ q := strconv.Quote
+ if err.at != "" {
+ return "ParseAddr(" + q(err.in) + "): " + err.msg + " (at " + q(err.at) + ")"
+ }
+ return "ParseAddr(" + q(err.in) + "): " + err.msg
+}
+
+// parseIPv4 parses s as an IPv4 address (in form "192.168.0.1").
+func parseIPv4(s string) (ip Addr, err error) {
+ var fields [4]uint8
+ var val, pos int
+ var digLen int // number of digits in current octet
+ for i := 0; i < len(s); i++ {
+ if s[i] >= '0' && s[i] <= '9' {
+ if digLen == 1 && val == 0 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field has octet with leading zero"}
+ }
+ val = val*10 + int(s[i]) - '0'
+ digLen++
+ if val > 255 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field has value >255"}
+ }
+ } else if s[i] == '.' {
+ // .1.2.3
+ // 1.2.3.
+ // 1..2.3
+ if i == 0 || i == len(s)-1 || s[i-1] == '.' {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 field must have at least one digit", at: s[i:]}
+ }
+ // 1.2.3.4.5
+ if pos == 3 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 address too long"}
+ }
+ fields[pos] = uint8(val)
+ pos++
+ val = 0
+ digLen = 0
+ } else {
+ return Addr{}, parseAddrError{in: s, msg: "unexpected character", at: s[i:]}
+ }
+ }
+ if pos < 3 {
+ return Addr{}, parseAddrError{in: s, msg: "IPv4 address too short"}
+ }
+ fields[3] = uint8(val)
+ return AddrFrom4(fields), nil
+}
+
+// parseIPv6 parses s as an IPv6 address (in form "2001:db8::68").
+func parseIPv6(in string) (Addr, error) {
+ s := in
+
+ // Split off the zone right from the start. Yes it's a second scan
+ // of the string, but trying to handle it inline makes a bunch of
+ // other inner loop conditionals more expensive, and it ends up
+ // being slower.
+ zone := ""
+ i := bytealg.IndexByteString(s, '%')
+ if i != -1 {
+ s, zone = s[:i], s[i+1:]
+ if zone == "" {
+ // Not allowed to have an empty zone if explicitly specified.
+ return Addr{}, parseAddrError{in: in, msg: "zone must be a non-empty string"}
+ }
+ }
+
+ var ip [16]byte
+ ellipsis := -1 // position of ellipsis in ip
+
+ // Might have leading ellipsis
+ if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
+ ellipsis = 0
+ s = s[2:]
+ // Might be only ellipsis
+ if len(s) == 0 {
+ return IPv6Unspecified().WithZone(zone), nil
+ }
+ }
+
+ // Loop, parsing hex numbers followed by colon.
+ i = 0
+ for i < 16 {
+ // Hex number. Similar to parseIPv4, inlining the hex number
+ // parsing yields a significant performance increase.
+ off := 0
+ acc := uint32(0)
+ for ; off < len(s); off++ {
+ c := s[off]
+ if c >= '0' && c <= '9' {
+ acc = (acc << 4) + uint32(c-'0')
+ } else if c >= 'a' && c <= 'f' {
+ acc = (acc << 4) + uint32(c-'a'+10)
+ } else if c >= 'A' && c <= 'F' {
+ acc = (acc << 4) + uint32(c-'A'+10)
+ } else {
+ break
+ }
+ if acc > math.MaxUint16 {
+ // Overflow, fail.
+ return Addr{}, parseAddrError{in: in, msg: "IPv6 field has value >=2^16", at: s}
+ }
+ }
+ if off == 0 {
+ // No digits found, fail.
+ return Addr{}, parseAddrError{in: in, msg: "each colon-separated field must have at least one digit", at: s}
+ }
+
+ // If followed by dot, might be in trailing IPv4.
+ if off < len(s) && s[off] == '.' {
+ if ellipsis < 0 && i != 12 {
+ // Not the right place.
+ return Addr{}, parseAddrError{in: in, msg: "embedded IPv4 address must replace the final 2 fields of the address", at: s}
+ }
+ if i+4 > 16 {
+ // Not enough room.
+ return Addr{}, parseAddrError{in: in, msg: "too many hex fields to fit an embedded IPv4 at the end of the address", at: s}
+ }
+ // TODO: could make this a bit faster by having a helper
+ // that parses to a [4]byte, and have both parseIPv4 and
+ // parseIPv6 use it.
+ ip4, err := parseIPv4(s)
+ if err != nil {
+ return Addr{}, parseAddrError{in: in, msg: err.Error(), at: s}
+ }
+ ip[i] = ip4.v4(0)
+ ip[i+1] = ip4.v4(1)
+ ip[i+2] = ip4.v4(2)
+ ip[i+3] = ip4.v4(3)
+ s = ""
+ i += 4
+ break
+ }
+
+ // Save this 16-bit chunk.
+ ip[i] = byte(acc >> 8)
+ ip[i+1] = byte(acc)
+ i += 2
+
+ // Stop at end of string.
+ s = s[off:]
+ if len(s) == 0 {
+ break
+ }
+
+ // Otherwise must be followed by colon and more.
+ if s[0] != ':' {
+ return Addr{}, parseAddrError{in: in, msg: "unexpected character, want colon", at: s}
+ } else if len(s) == 1 {
+ return Addr{}, parseAddrError{in: in, msg: "colon must be followed by more characters", at: s}
+ }
+ s = s[1:]
+
+ // Look for ellipsis.
+ if s[0] == ':' {
+ if ellipsis >= 0 { // already have one
+ return Addr{}, parseAddrError{in: in, msg: "multiple :: in address", at: s}
+ }
+ ellipsis = i
+ s = s[1:]
+ if len(s) == 0 { // can be at end
+ break
+ }
+ }
+ }
+
+ // Must have used entire string.
+ if len(s) != 0 {
+ return Addr{}, parseAddrError{in: in, msg: "trailing garbage after address", at: s}
+ }
+
+ // If didn't parse enough, expand ellipsis.
+ if i < 16 {
+ if ellipsis < 0 {
+ return Addr{}, parseAddrError{in: in, msg: "address string too short"}
+ }
+ n := 16 - i
+ for j := i - 1; j >= ellipsis; j-- {
+ ip[j+n] = ip[j]
+ }
+ for j := ellipsis + n - 1; j >= ellipsis; j-- {
+ ip[j] = 0
+ }
+ } else if ellipsis >= 0 {
+ // Ellipsis must represent at least one 0 group.
+ return Addr{}, parseAddrError{in: in, msg: "the :: must expand to at least one field of zeros"}
+ }
+ return AddrFrom16(ip).WithZone(zone), nil
+}
+
+// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address.
+// Note that a net.IP can be passed directly as the []byte argument.
+// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false.
+func AddrFromSlice(slice []byte) (ip Addr, ok bool) {
+ switch len(slice) {
+ case 4:
+ return AddrFrom4([4]byte(slice)), true
+ case 16:
+ return AddrFrom16([16]byte(slice)), true
+ }
+ return Addr{}, false
+}
+
+// v4 returns the i'th byte of ip. If ip is not an IPv4 address, v4 returns
+// unspecified garbage.
+func (ip Addr) v4(i uint8) uint8 {
+ return uint8(ip.addr.lo >> ((3 - i) * 8))
+}
+
+// v6 returns the i'th byte of ip. If ip is an IPv4 address, this
+// accesses the IPv4-mapped IPv6 address form of the IP.
+func (ip Addr) v6(i uint8) uint8 {
+ return uint8(*(ip.addr.halves()[(i/8)%2]) >> ((7 - i%8) * 8))
+}
+
+// v6u16 returns the i'th 16-bit word of ip. If ip is an IPv4 address,
+// this accesses the IPv4-mapped IPv6 address form of the IP.
+func (ip Addr) v6u16(i uint8) uint16 {
+ return uint16(*(ip.addr.halves()[(i/4)%2]) >> ((3 - i%4) * 16))
+}
+
+// isZero reports whether ip is the zero value of the IP type.
+// The zero value is not a valid IP address of any type.
+//
+// Note that "0.0.0.0" and "::" are not the zero value. Use IsUnspecified to
+// check for these values instead.
+func (ip Addr) isZero() bool {
+ // Faster than comparing ip == Addr{}, but effectively equivalent,
+ // as there's no way to make an IP with a nil z from this package.
+ return ip.z == z0
+}
+
+// IsValid reports whether the Addr is an initialized address (not the zero Addr).
+//
+// Note that "0.0.0.0" and "::" are both valid values.
+func (ip Addr) IsValid() bool { return ip.z != z0 }
+
+// BitLen returns the number of bits in the IP address:
+// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr.
+//
+// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses
+// and therefore have bit length 128.
+func (ip Addr) BitLen() int {
+ switch ip.z {
+ case z0:
+ return 0
+ case z4:
+ return 32
+ }
+ return 128
+}
+
+// Zone returns ip's IPv6 scoped addressing zone, if any.
+func (ip Addr) Zone() string {
+ if ip.z == nil {
+ return ""
+ }
+ zone, _ := ip.z.Get().(string)
+ return zone
+}
+
+// Compare returns an integer comparing two IPs.
+// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2.
+// The definition of "less than" is the same as the Less method.
+func (ip Addr) Compare(ip2 Addr) int {
+ f1, f2 := ip.BitLen(), ip2.BitLen()
+ if f1 < f2 {
+ return -1
+ }
+ if f1 > f2 {
+ return 1
+ }
+ hi1, hi2 := ip.addr.hi, ip2.addr.hi
+ if hi1 < hi2 {
+ return -1
+ }
+ if hi1 > hi2 {
+ return 1
+ }
+ lo1, lo2 := ip.addr.lo, ip2.addr.lo
+ if lo1 < lo2 {
+ return -1
+ }
+ if lo1 > lo2 {
+ return 1
+ }
+ if ip.Is6() {
+ za, zb := ip.Zone(), ip2.Zone()
+ if za < zb {
+ return -1
+ }
+ if za > zb {
+ return 1
+ }
+ }
+ return 0
+}
+
+// Less reports whether ip sorts before ip2.
+// IP addresses sort first by length, then their address.
+// IPv6 addresses with zones sort just after the same address without a zone.
+func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 }
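+
+// Ordering sketch (illustrative, not part of this change, assuming the
+// standard sort package): addresses sort first by family (IPv4 before IPv6),
+// then numerically, and an IPv6 address with a zone sorts just after the same
+// address without one.
+//
+//    addrs := []Addr{
+//        MustParseAddr("::1"),
+//        MustParseAddr("10.0.0.1"),
+//        MustParseAddr("fe80::1%eth0"),
+//        MustParseAddr("fe80::1"),
+//    }
+//    sort.Slice(addrs, func(i, j int) bool { return addrs[i].Less(addrs[j]) })
+//    // Result: 10.0.0.1, ::1, fe80::1, fe80::1%eth0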
+
+// Is4 reports whether ip is an IPv4 address.
+//
+// It returns false for IPv4-mapped IPv6 addresses. See Addr.Unmap.
+func (ip Addr) Is4() bool {
+ return ip.z == z4
+}
+
+// Is4In6 reports whether ip is an IPv4-mapped IPv6 address.
+func (ip Addr) Is4In6() bool {
+ return ip.Is6() && ip.addr.hi == 0 && ip.addr.lo>>32 == 0xffff
+}
+
+// Is6 reports whether ip is an IPv6 address, including IPv4-mapped
+// IPv6 addresses.
+func (ip Addr) Is6() bool {
+ return ip.z != z0 && ip.z != z4
+}
+
+// Unmap returns ip with any IPv4-mapped IPv6 address prefix removed.
+//
+// That is, if ip is an IPv6 address wrapping an IPv4 address, it
+// returns the wrapped IPv4 address. Otherwise it returns ip unmodified.
+func (ip Addr) Unmap() Addr {
+ if ip.Is4In6() {
+ ip.z = z4
+ }
+ return ip
+}
+
+// WithZone returns an IP that's the same as ip but with the provided
+// zone. If zone is empty, the zone is removed. If ip is an IPv4
+// address, WithZone is a no-op and returns ip unchanged.
+func (ip Addr) WithZone(zone string) Addr {
+ if !ip.Is6() {
+ return ip
+ }
+ if zone == "" {
+ ip.z = z6noz
+ return ip
+ }
+ ip.z = intern.GetByString(zone)
+ return ip
+}
+
+// withoutZone unconditionally strips the zone from ip.
+// It's similar to WithZone, but small enough to be inlinable.
+func (ip Addr) withoutZone() Addr {
+ if !ip.Is6() {
+ return ip
+ }
+ ip.z = z6noz
+ return ip
+}
+
+// hasZone reports whether ip has an IPv6 zone.
+func (ip Addr) hasZone() bool {
+ return ip.z != z0 && ip.z != z4 && ip.z != z6noz
+}
+
+// IsLinkLocalUnicast reports whether ip is a link-local unicast address.
+func (ip Addr) IsLinkLocalUnicast() bool {
+ // Dynamic Configuration of IPv4 Link-Local Addresses
+ // https://datatracker.ietf.org/doc/html/rfc3927#section-2.1
+ if ip.Is4() {
+ return ip.v4(0) == 169 && ip.v4(1) == 254
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.v6u16(0)&0xffc0 == 0xfe80
+ }
+ return false // zero value
+}
+
+// IsLoopback reports whether ip is a loopback address.
+func (ip Addr) IsLoopback() bool {
+ // Requirements for Internet Hosts -- Communication Layers (3.2.1.3 Addressing)
+ // https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3
+ if ip.Is4() {
+ return ip.v4(0) == 127
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.addr.hi == 0 && ip.addr.lo == 1
+ }
+ return false // zero value
+}
+
+// IsMulticast reports whether ip is a multicast address.
+func (ip Addr) IsMulticast() bool {
+ // Host Extensions for IP Multicasting (4. HOST GROUP ADDRESSES)
+ // https://datatracker.ietf.org/doc/html/rfc1112#section-4
+ if ip.Is4() {
+ return ip.v4(0)&0xf0 == 0xe0
+ }
+ // IP Version 6 Addressing Architecture (2.4 Address Type Identification)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4
+ if ip.Is6() {
+ return ip.addr.hi>>(64-8) == 0xff // ip.v6(0) == 0xff
+ }
+ return false // zero value
+}
+
+// IsInterfaceLocalMulticast reports whether ip is an IPv6 interface-local
+// multicast address.
+func (ip Addr) IsInterfaceLocalMulticast() bool {
+ // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
+ if ip.Is6() {
+ return ip.v6u16(0)&0xff0f == 0xff01
+ }
+ return false // zero value
+}
+
+// IsLinkLocalMulticast reports whether ip is a link-local multicast address.
+func (ip Addr) IsLinkLocalMulticast() bool {
+ // IPv4 Multicast Guidelines (4. Local Network Control Block (224.0.0/24))
+ // https://datatracker.ietf.org/doc/html/rfc5771#section-4
+ if ip.Is4() {
+ return ip.v4(0) == 224 && ip.v4(1) == 0 && ip.v4(2) == 0
+ }
+ // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses)
+ // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1
+ if ip.Is6() {
+ return ip.v6u16(0)&0xff0f == 0xff02
+ }
+ return false // zero value
+}
+
+// IsGlobalUnicast reports whether ip is a global unicast address.
+//
+// It returns true for IPv6 addresses which fall outside of the current
+// IANA-allocated 2000::/3 global unicast space, with the exception of the
+// link-local address space. It also returns true even if ip is in the IPv4
+// private address space or IPv6 unique local address space.
+// It returns false for the zero Addr.
+//
+// For reference, see RFC 1122, RFC 4291, and RFC 4632.
+func (ip Addr) IsGlobalUnicast() bool {
+ if ip.z == z0 {
+ // Invalid or zero-value.
+ return false
+ }
+
+ // Match package net's IsGlobalUnicast logic. Notably private IPv4 addresses
+ // and ULA IPv6 addresses are still considered "global unicast".
+ if ip.Is4() && (ip == IPv4Unspecified() || ip == AddrFrom4([4]byte{255, 255, 255, 255})) {
+ return false
+ }
+
+ return ip != IPv6Unspecified() &&
+ !ip.IsLoopback() &&
+ !ip.IsMulticast() &&
+ !ip.IsLinkLocalUnicast()
+}
+
+// IsPrivate reports whether ip is a private address, according to RFC 1918
+// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether
+// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the
+// same as net.IP.IsPrivate.
+func (ip Addr) IsPrivate() bool {
+ // Match the stdlib's IsPrivate logic.
+ if ip.Is4() {
+ // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
+ // private IPv4 address subnets.
+ return ip.v4(0) == 10 ||
+ (ip.v4(0) == 172 && ip.v4(1)&0xf0 == 16) ||
+ (ip.v4(0) == 192 && ip.v4(1) == 168)
+ }
+
+ if ip.Is6() {
+ // RFC 4193 allocates fc00::/7 as the unique local unicast IPv6 address
+ // subnet.
+ return ip.v6(0)&0xfe == 0xfc
+ }
+
+ return false // zero value
+}
+
+// IsUnspecified reports whether ip is an unspecified address, either the IPv4
+// address "0.0.0.0" or the IPv6 address "::".
+//
+// Note that the zero Addr is not an unspecified address.
+func (ip Addr) IsUnspecified() bool {
+ return ip == IPv4Unspecified() || ip == IPv6Unspecified()
+}
+
+// Prefix keeps only the top b bits of IP, producing a Prefix
+// of the specified length.
+// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error.
+// Otherwise, if bits is less than zero or greater than ip.BitLen(),
+// Prefix returns an error.
+func (ip Addr) Prefix(b int) (Prefix, error) {
+ if b < 0 {
+ return Prefix{}, errors.New("negative Prefix bits")
+ }
+ effectiveBits := b
+ switch ip.z {
+ case z0:
+ return Prefix{}, nil
+ case z4:
+ if b > 32 {
+ return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv4")
+ }
+ effectiveBits += 96
+ default:
+ if b > 128 {
+ return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv6")
+ }
+ }
+ ip.addr = ip.addr.and(mask6(effectiveBits))
+ return PrefixFrom(ip, b), nil
+}
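+
+// Masking sketch (illustrative, not part of this change):
+//
+//    ip := MustParseAddr("192.0.2.77")
+//    p, _ := ip.Prefix(24)   // 192.0.2.0/24: the low 8 host bits are zeroed
+//    _, err := ip.Prefix(33) // err != nil: 33 > 32 for an IPv4 address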
+
+const (
+ netIPv4len = 4
+ netIPv6len = 16
+)
+
+// As16 returns the IP address in its 16-byte representation.
+// IPv4 addresses are returned as IPv4-mapped IPv6 addresses.
+// IPv6 addresses with zones are returned without their zone (use the
+// Zone method to get it).
+// The zero Addr returns all zeroes.
+func (ip Addr) As16() (a16 [16]byte) {
+ bePutUint64(a16[:8], ip.addr.hi)
+ bePutUint64(a16[8:], ip.addr.lo)
+ return a16
+}
+
+// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation.
+// If ip is the zero Addr or an IPv6 address, As4 panics.
+// Note that 0.0.0.0 is not the zero Addr.
+func (ip Addr) As4() (a4 [4]byte) {
+ if ip.z == z4 || ip.Is4In6() {
+ bePutUint32(a4[:], uint32(ip.addr.lo))
+ return a4
+ }
+ if ip.z == z0 {
+ panic("As4 called on IP zero value")
+ }
+ panic("As4 called on IPv6 address")
+}
+
+// AsSlice returns an IPv4 or IPv6 address in its respective 4-byte or 16-byte representation.
+func (ip Addr) AsSlice() []byte {
+ switch ip.z {
+ case z0:
+ return nil
+ case z4:
+ var ret [4]byte
+ bePutUint32(ret[:], uint32(ip.addr.lo))
+ return ret[:]
+ default:
+ var ret [16]byte
+ bePutUint64(ret[:8], ip.addr.hi)
+ bePutUint64(ret[8:], ip.addr.lo)
+ return ret[:]
+ }
+}
+
+// Next returns the address following ip.
+// If there is none, it returns the zero Addr.
+func (ip Addr) Next() Addr {
+ ip.addr = ip.addr.addOne()
+ if ip.Is4() {
+ if uint32(ip.addr.lo) == 0 {
+ // Overflowed.
+ return Addr{}
+ }
+ } else {
+ if ip.addr.isZero() {
+ // Overflowed
+ return Addr{}
+ }
+ }
+ return ip
+}
+
+// Prev returns the address before ip.
+// If there is none, it returns the zero Addr.
+func (ip Addr) Prev() Addr {
+ if ip.Is4() {
+ if uint32(ip.addr.lo) == 0 {
+ return Addr{}
+ }
+ } else if ip.addr.isZero() {
+ return Addr{}
+ }
+ ip.addr = ip.addr.subOne()
+ return ip
+}
+
+// String returns the string form of the IP address ip.
+// It returns one of 5 forms:
+//
+// - "invalid IP", if ip is the zero Addr
+// - IPv4 dotted decimal ("192.0.2.1")
+// - IPv6 ("2001:db8::1")
+// - "::ffff:1.2.3.4" (if Is4In6)
+// - IPv6 with zone ("fe80:db8::1%eth0")
+//
+// Note that unlike package net's IP.String method,
+// IPv4-mapped IPv6 addresses format with a "::ffff:"
+// prefix before the dotted quad.
+func (ip Addr) String() string {
+ switch ip.z {
+ case z0:
+ return "invalid IP"
+ case z4:
+ return ip.string4()
+ default:
+ if ip.Is4In6() {
+ if z := ip.Zone(); z != "" {
+ return "::ffff:" + ip.Unmap().string4() + "%" + z
+ } else {
+ return "::ffff:" + ip.Unmap().string4()
+ }
+ }
+ return ip.string6()
+ }
+}
+
+// AppendTo appends a text encoding of ip,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (ip Addr) AppendTo(b []byte) []byte {
+ switch ip.z {
+ case z0:
+ return b
+ case z4:
+ return ip.appendTo4(b)
+ default:
+ if ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = ip.Unmap().appendTo4(b)
+ if z := ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ return b
+ }
+ return ip.appendTo6(b)
+ }
+}
+
+// digits is a string of the hex digits from 0 to f. It's used in
+// appendDecimal and appendHex to format IP addresses.
+const digits = "0123456789abcdef"
+
+// appendDecimal appends the decimal string representation of x to b.
+func appendDecimal(b []byte, x uint8) []byte {
+ // Using this function rather than strconv.AppendUint makes IPv4
+ // string building 2x faster.
+
+ if x >= 100 {
+ b = append(b, digits[x/100])
+ }
+ if x >= 10 {
+ b = append(b, digits[x/10%10])
+ }
+ return append(b, digits[x%10])
+}
+
+// appendHex appends the hex string representation of x to b.
+func appendHex(b []byte, x uint16) []byte {
+ // Using this function rather than strconv.AppendUint makes IPv6
+ // string building 2x faster.
+
+ if x >= 0x1000 {
+ b = append(b, digits[x>>12])
+ }
+ if x >= 0x100 {
+ b = append(b, digits[x>>8&0xf])
+ }
+ if x >= 0x10 {
+ b = append(b, digits[x>>4&0xf])
+ }
+ return append(b, digits[x&0xf])
+}
+
+// appendHexPad appends the fully padded hex string representation of x to b.
+func appendHexPad(b []byte, x uint16) []byte {
+ return append(b, digits[x>>12], digits[x>>8&0xf], digits[x>>4&0xf], digits[x&0xf])
+}
+
+func (ip Addr) string4() string {
+ const max = len("255.255.255.255")
+ ret := make([]byte, 0, max)
+ ret = ip.appendTo4(ret)
+ return string(ret)
+}
+
+func (ip Addr) appendTo4(ret []byte) []byte {
+ ret = appendDecimal(ret, ip.v4(0))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(1))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(2))
+ ret = append(ret, '.')
+ ret = appendDecimal(ret, ip.v4(3))
+ return ret
+}
+
+// string6 formats ip in IPv6 textual representation. It follows the
+// guidelines in section 4 of RFC 5952
+// (https://tools.ietf.org/html/rfc5952#section-4): no unnecessary
+// zeros, use :: to elide the longest run of zeros, and don't use ::
+// to compact a single zero field.
+func (ip Addr) string6() string {
+ // Use a zone with a "plausibly long" name, so that most zone-ful
+ // IP addresses won't require additional allocation.
+ //
+ // The compiler does a cool optimization here, where ret ends up
+ // stack-allocated and so the only allocation this function does
+ // is to construct the returned string. As such, it's okay to be a
+ // bit greedy here, size-wise.
+ const max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
+ ret := make([]byte, 0, max)
+ ret = ip.appendTo6(ret)
+ return string(ret)
+}
+
+func (ip Addr) appendTo6(ret []byte) []byte {
+ zeroStart, zeroEnd := uint8(255), uint8(255)
+ for i := uint8(0); i < 8; i++ {
+ j := i
+ for j < 8 && ip.v6u16(j) == 0 {
+ j++
+ }
+ if l := j - i; l >= 2 && l > zeroEnd-zeroStart {
+ zeroStart, zeroEnd = i, j
+ }
+ }
+
+ for i := uint8(0); i < 8; i++ {
+ if i == zeroStart {
+ ret = append(ret, ':', ':')
+ i = zeroEnd
+ if i >= 8 {
+ break
+ }
+ } else if i > 0 {
+ ret = append(ret, ':')
+ }
+
+ ret = appendHex(ret, ip.v6u16(i))
+ }
+
+ if ip.z != z6noz {
+ ret = append(ret, '%')
+ ret = append(ret, ip.Zone()...)
+ }
+ return ret
+}
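+
+// Formatting sketch (illustrative, not part of this change): per the RFC 5952
+// rules above, only a run of two or more zero fields is elided, and the
+// longest such run wins.
+//
+//    MustParseAddr("2001:db8:0:0:0:0:2:1").String() // "2001:db8::2:1"
+//    MustParseAddr("2001:db8:0:1:1:1:1:1").String() // "2001:db8:0:1:1:1:1:1" (single zero field kept)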
+
+// StringExpanded is like String but IPv6 addresses are expanded with leading
+// zeroes and no "::" compression. For example, "2001:db8::1" becomes
+// "2001:0db8:0000:0000:0000:0000:0000:0001".
+func (ip Addr) StringExpanded() string {
+ switch ip.z {
+ case z0, z4:
+ return ip.String()
+ }
+
+ const size = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ ret := make([]byte, 0, size)
+ for i := uint8(0); i < 8; i++ {
+ if i > 0 {
+ ret = append(ret, ':')
+ }
+
+ ret = appendHexPad(ret, ip.v6u16(i))
+ }
+
+ if ip.z != z6noz {
+ // The addition of a zone will cause a second allocation, but when there
+ // is no zone the ret slice will be stack allocated.
+ ret = append(ret, '%')
+ ret = append(ret, ip.Zone()...)
+ }
+ return string(ret)
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+// The encoding is the same as returned by String, with one exception:
+// if ip is the zero Addr, the encoding is the empty string.
+func (ip Addr) MarshalText() ([]byte, error) {
+ switch ip.z {
+ case z0:
+ return []byte(""), nil
+ case z4:
+ max := len("255.255.255.255")
+ b := make([]byte, 0, max)
+ return ip.appendTo4(b), nil
+ default:
+ max := len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0")
+ b := make([]byte, 0, max)
+ if ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = ip.Unmap().appendTo4(b)
+ if z := ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ return b, nil
+ }
+ return ip.appendTo6(b), nil
+ }
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// The IP address is expected in a form accepted by ParseAddr.
+//
+// If text is empty, UnmarshalText sets *ip to the zero Addr and
+// returns no error.
+func (ip *Addr) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *ip = Addr{}
+ return nil
+ }
+ var err error
+ *ip, err = ParseAddr(string(text))
+ return err
+}
+
+func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte {
+ var b []byte
+ switch ip.z {
+ case z0:
+ b = make([]byte, trailingBytes)
+ case z4:
+ b = make([]byte, 4+trailingBytes)
+ bePutUint32(b, uint32(ip.addr.lo))
+ default:
+ z := ip.Zone()
+ b = make([]byte, 16+len(z)+trailingBytes)
+ bePutUint64(b[:8], ip.addr.hi)
+ bePutUint64(b[8:], ip.addr.lo)
+ copy(b[16:], z)
+ }
+ return b
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns a zero-length slice for the zero Addr,
+// the 4-byte form for an IPv4 address,
+// and the 16-byte form with zone appended for an IPv6 address.
+func (ip Addr) MarshalBinary() ([]byte, error) {
+ return ip.marshalBinaryWithTrailingBytes(0), nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (ip *Addr) UnmarshalBinary(b []byte) error {
+ n := len(b)
+ switch {
+ case n == 0:
+ *ip = Addr{}
+ return nil
+ case n == 4:
+ *ip = AddrFrom4([4]byte(b))
+ return nil
+ case n == 16:
+ *ip = AddrFrom16([16]byte(b))
+ return nil
+ case n > 16:
+ *ip = AddrFrom16([16]byte(b[:16])).WithZone(string(b[16:]))
+ return nil
+ }
+ return errors.New("unexpected slice size")
+}
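+
+// Wire-format sketch (illustrative, not part of this change): the binary form
+// is 0, 4, or 16+len(zone) bytes, and UnmarshalBinary reverses MarshalBinary.
+//
+//    b, _ := MustParseAddr("fe80::1%eth0").MarshalBinary()
+//    // len(b) == 20: 16 address bytes followed by the zone "eth0"
+//    var a Addr
+//    _ = a.UnmarshalBinary(b) // a == fe80::1%eth0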
+
+// AddrPort is an IP and a port number.
+type AddrPort struct {
+ ip Addr
+ port uint16
+}
+
+// AddrPortFrom returns an AddrPort with the provided IP and port.
+// It does not allocate.
+func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} }
+
+// Addr returns p's IP address.
+func (p AddrPort) Addr() Addr { return p.ip }
+
+// Port returns p's port.
+func (p AddrPort) Port() uint16 { return p.port }
+
+// splitAddrPort splits s into an IP address string and a port
+// string. It splits strings shaped like "foo:bar" or "[foo]:bar",
+// without further validating the substrings. v6 indicates whether the
+// ip string should parse as an IPv6 address or an IPv4 address, in
+// order for s to be a valid ip:port string.
+func splitAddrPort(s string) (ip, port string, v6 bool, err error) {
+ i := stringsLastIndexByte(s, ':')
+ if i == -1 {
+ return "", "", false, errors.New("not an ip:port")
+ }
+
+ ip, port = s[:i], s[i+1:]
+ if len(ip) == 0 {
+ return "", "", false, errors.New("no IP")
+ }
+ if len(port) == 0 {
+ return "", "", false, errors.New("no port")
+ }
+ if ip[0] == '[' {
+ if len(ip) < 2 || ip[len(ip)-1] != ']' {
+ return "", "", false, errors.New("missing ]")
+ }
+ ip = ip[1 : len(ip)-1]
+ v6 = true
+ }
+
+ return ip, port, v6, nil
+}
+
+// ParseAddrPort parses s as an AddrPort.
+//
+// It doesn't do any name resolution: both the address and the port
+// must be numeric.
+func ParseAddrPort(s string) (AddrPort, error) {
+ var ipp AddrPort
+ ip, port, v6, err := splitAddrPort(s)
+ if err != nil {
+ return ipp, err
+ }
+ port16, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ return ipp, errors.New("invalid port " + strconv.Quote(port) + " parsing " + strconv.Quote(s))
+ }
+ ipp.port = uint16(port16)
+ ipp.ip, err = ParseAddr(ip)
+ if err != nil {
+ return AddrPort{}, err
+ }
+ if v6 && ipp.ip.Is4() {
+ return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", square brackets can only be used with IPv6 addresses")
+ } else if !v6 && ipp.ip.Is6() {
+ return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", IPv6 addresses must be surrounded by square brackets")
+ }
+ return ipp, nil
+}
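+
+// Illustrative usage (a sketch, not part of this change): both halves must be
+// numeric, and IPv6 addresses must be bracketed.
+//
+//    ParseAddrPort("192.0.2.1:80")     // ok
+//    ParseAddrPort("[2001:db8::1]:80") // ok
+//    ParseAddrPort("2001:db8::1:80")   // error: unbracketed IPv6
+//    ParseAddrPort("example.com:80")   // error: no name resolution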
+
+// MustParseAddrPort calls ParseAddrPort(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParseAddrPort(s string) AddrPort {
+ ip, err := ParseAddrPort(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+// IsValid reports whether p.Addr() is valid.
+// All ports are valid, including zero.
+func (p AddrPort) IsValid() bool { return p.ip.IsValid() }
+
+func (p AddrPort) String() string {
+ switch p.ip.z {
+ case z0:
+ return "invalid AddrPort"
+ case z4:
+ a := p.ip.As4()
+ buf := make([]byte, 0, 21)
+ for i := range a {
+ buf = strconv.AppendUint(buf, uint64(a[i]), 10)
+ buf = append(buf, "...:"[i])
+ }
+ buf = strconv.AppendUint(buf, uint64(p.port), 10)
+ return string(buf)
+ default:
+ // TODO: this could be more efficient allocation-wise:
+ return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port)))
+ }
+}
+
+func joinHostPort(host, port string) string {
+ // We assume that host is a literal IPv6 address if host has
+ // colons.
+ if bytealg.IndexByteString(host, ':') >= 0 {
+ return "[" + host + "]:" + port
+ }
+ return host + ":" + port
+}
+
+// AppendTo appends a text encoding of p,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (p AddrPort) AppendTo(b []byte) []byte {
+ switch p.ip.z {
+ case z0:
+ return b
+ case z4:
+ b = p.ip.appendTo4(b)
+ default:
+ if p.ip.Is4In6() {
+ b = append(b, "[::ffff:"...)
+ b = p.ip.Unmap().appendTo4(b)
+ if z := p.ip.Zone(); z != "" {
+ b = append(b, '%')
+ b = append(b, z...)
+ }
+ } else {
+ b = append(b, '[')
+ b = p.ip.appendTo6(b)
+ }
+ b = append(b, ']')
+ }
+ b = append(b, ':')
+ b = strconv.AppendUint(b, uint64(p.port), 10)
+ return b
+}
+
+// MarshalText implements the encoding.TextMarshaler interface. The
+// encoding is the same as returned by String, with one exception: if
+// p.Addr() is the zero Addr, the encoding is the empty string.
+func (p AddrPort) MarshalText() ([]byte, error) {
+ var max int
+ switch p.ip.z {
+ case z0:
+ case z4:
+ max = len("255.255.255.255:65535")
+ default:
+ max = len("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535")
+ }
+ b := make([]byte, 0, max)
+ b = p.AppendTo(b)
+ return b, nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler
+// interface. The AddrPort is expected in a form
+// generated by MarshalText or accepted by ParseAddrPort.
+func (p *AddrPort) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *p = AddrPort{}
+ return nil
+ }
+ var err error
+ *p, err = ParseAddrPort(string(text))
+ return err
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns Addr.MarshalBinary with an additional two bytes appended
+// containing the port in little-endian.
+func (p AddrPort) MarshalBinary() ([]byte, error) {
+ b := p.Addr().marshalBinaryWithTrailingBytes(2)
+ lePutUint16(b[len(b)-2:], p.Port())
+ return b, nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (p *AddrPort) UnmarshalBinary(b []byte) error {
+ if len(b) < 2 {
+ return errors.New("unexpected slice size")
+ }
+ var addr Addr
+ err := addr.UnmarshalBinary(b[:len(b)-2])
+ if err != nil {
+ return err
+ }
+ *p = AddrPortFrom(addr, leUint16(b[len(b)-2:]))
+ return nil
+}
+
+// Prefix is an IP address prefix (CIDR) representing an IP network.
+//
+// The first Bits() bits of Addr() are specified. The remaining bits match any address.
+// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6.
+type Prefix struct {
+ ip Addr
+
+ // bitsPlusOne stores the prefix bit length plus one.
+ // A Prefix is valid if and only if bitsPlusOne is non-zero.
+ bitsPlusOne uint8
+}
+
+// PrefixFrom returns a Prefix with the provided IP address and bit
+// prefix length.
+//
+// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask
+// off the host bits of ip.
+//
+// If bits is less than zero or greater than ip.BitLen, Prefix.Bits
+// will return an invalid value -1.
+func PrefixFrom(ip Addr, bits int) Prefix {
+ var bitsPlusOne uint8
+ if !ip.isZero() && bits >= 0 && bits <= ip.BitLen() {
+ bitsPlusOne = uint8(bits) + 1
+ }
+ return Prefix{
+ ip: ip.withoutZone(),
+ bitsPlusOne: bitsPlusOne,
+ }
+}
+
+// Addr returns p's IP address.
+func (p Prefix) Addr() Addr { return p.ip }
+
+// Bits returns p's prefix length.
+//
+// It reports -1 if invalid.
+func (p Prefix) Bits() int { return int(p.bitsPlusOne) - 1 }
+
+// IsValid reports whether p.Bits() has a valid range for p.Addr().
+// If p.Addr() is the zero Addr, IsValid returns false.
+// Note that if p is the zero Prefix, then p.IsValid() == false.
+func (p Prefix) IsValid() bool { return p.bitsPlusOne > 0 }
+
+func (p Prefix) isZero() bool { return p == Prefix{} }
+
+// IsSingleIP reports whether p contains exactly one IP.
+func (p Prefix) IsSingleIP() bool { return p.IsValid() && p.Bits() == p.ip.BitLen() }
+
+// ParsePrefix parses s as an IP address prefix.
+// The string can be in the form "192.168.1.0/24" or "2001:db8::/32",
+// the CIDR notation defined in RFC 4632 and RFC 4291.
+// IPv6 zones are not permitted in prefixes, and an error will be returned if a
+// zone is present.
+//
+// Note that masked address bits are not zeroed. Use Masked for that.
+func ParsePrefix(s string) (Prefix, error) {
+ i := stringsLastIndexByte(s, '/')
+ if i < 0 {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'")
+ }
+ ip, err := ParseAddr(s[:i])
+ if err != nil {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): " + err.Error())
+ }
+ // IPv6 zones are not allowed: https://go.dev/issue/51899
+ if ip.Is6() && ip.z != z6noz {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): IPv6 zones cannot be present in a prefix")
+ }
+
+ bitsStr := s[i+1:]
+ bits, err := strconv.Atoi(bitsStr)
+ if err != nil {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): bad bits after slash: " + strconv.Quote(bitsStr))
+ }
+ maxBits := 32
+ if ip.Is6() {
+ maxBits = 128
+ }
+ if bits < 0 || bits > maxBits {
+ return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): prefix length out of range")
+ }
+ return PrefixFrom(ip, bits), nil
+}
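+
+// Illustrative usage (a sketch, not part of this change): host bits are kept
+// as written; use Masked to canonicalize.
+//
+//    p, _ := ParsePrefix("192.0.2.77/24")
+//    p.Addr()   // 192.0.2.77 (host bits retained)
+//    p.Masked() // 192.0.2.0/24
+//    _, err := ParsePrefix("fe80::1%eth0/64") // err != nil: zones not allowed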
+
+// MustParsePrefix calls ParsePrefix(s) and panics on error.
+// It is intended for use in tests with hard-coded strings.
+func MustParsePrefix(s string) Prefix {
+ ip, err := ParsePrefix(s)
+ if err != nil {
+ panic(err)
+ }
+ return ip
+}
+
+// Masked returns p in its canonical form, with all but the high
+// p.Bits() bits of p.Addr() masked off.
+//
+// If p is zero or otherwise invalid, Masked returns the zero Prefix.
+func (p Prefix) Masked() Prefix {
+ m, _ := p.ip.Prefix(p.Bits())
+ return m
+}
+
+// Contains reports whether the network p includes ip.
+//
+// An IPv4 address will not match an IPv6 prefix.
+// An IPv4-mapped IPv6 address will not match an IPv4 prefix.
+// A zero-value IP will not match any prefix.
+// If ip has an IPv6 zone, Contains returns false,
+// because Prefixes strip zones.
+func (p Prefix) Contains(ip Addr) bool {
+ if !p.IsValid() || ip.hasZone() {
+ return false
+ }
+ if f1, f2 := p.ip.BitLen(), ip.BitLen(); f1 == 0 || f2 == 0 || f1 != f2 {
+ return false
+ }
+ if ip.Is4() {
+ // xor the IP addresses together; mismatched bits are now ones.
+ // Shift away the number of bits we don't care about.
+ // Shifts in Go are more efficient if the compiler can prove
+ // that the shift amount is smaller than the width of the shifted type (64 here).
+ // We know that p.bits is in the range 0..32 because p is Valid;
+ // the compiler doesn't know that, so mask with 63 to help it.
+ // Now truncate to 32 bits, because this is IPv4.
+ // If all the bits we care about are equal, the result will be zero.
+ return uint32((ip.addr.lo^p.ip.addr.lo)>>((32-p.Bits())&63)) == 0
+ } else {
+ // xor the IP addresses together.
+ // Mask away the bits we don't care about.
+ // If all the bits we care about are equal, the result will be zero.
+ return ip.addr.xor(p.ip.addr).and(mask6(p.Bits())).isZero()
+ }
+}
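+
+// Containment sketch (illustrative, not part of this change): the address
+// families must match exactly, and an IPv4-mapped IPv6 address counts as IPv6.
+//
+//    p := MustParsePrefix("10.0.0.0/8")
+//    p.Contains(MustParseAddr("10.1.2.3"))        // true
+//    p.Contains(MustParseAddr("::ffff:10.1.2.3")) // false: IPv6 form, IPv4 prefix
+//    MustParsePrefix("::/0").Contains(MustParseAddr("10.1.2.3")) // false: IPv4 addr, IPv6 prefix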
+
+// Overlaps reports whether p and o contain any IP addresses in common.
+//
+// If p and o are of different address families or either have a zero
+// IP, it reports false. Like the Contains method, a prefix with an
+// IPv4-mapped IPv6 address is still treated as an IPv6 mask.
+func (p Prefix) Overlaps(o Prefix) bool {
+ if !p.IsValid() || !o.IsValid() {
+ return false
+ }
+ if p == o {
+ return true
+ }
+ if p.ip.Is4() != o.ip.Is4() {
+ return false
+ }
+ var minBits int
+ if pb, ob := p.Bits(), o.Bits(); pb < ob {
+ minBits = pb
+ } else {
+ minBits = ob
+ }
+ if minBits == 0 {
+ return true
+ }
+ // One of these Prefix calls might look redundant, but we don't require
+ // that p and o values are normalized (via Prefix.Masked) first,
+ // so the Prefix call on the one that's already minBits serves to zero
+ // out any remaining bits in IP.
+ var err error
+ if p, err = p.ip.Prefix(minBits); err != nil {
+ return false
+ }
+ if o, err = o.ip.Prefix(minBits); err != nil {
+ return false
+ }
+ return p.ip == o.ip
+}
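+
+// Overlap sketch (illustrative, not part of this change):
+//
+//    a := MustParsePrefix("10.0.0.0/8")
+//    b := MustParsePrefix("10.1.0.0/16")
+//    c := MustParsePrefix("192.168.0.0/16")
+//    a.Overlaps(b) // true: b is a subset of a
+//    a.Overlaps(c) // false: disjoint ranges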
+
+// AppendTo appends a text encoding of p,
+// as generated by MarshalText,
+// to b and returns the extended buffer.
+func (p Prefix) AppendTo(b []byte) []byte {
+ if p.isZero() {
+ return b
+ }
+ if !p.IsValid() {
+ return append(b, "invalid Prefix"...)
+ }
+
+ // p.ip is non-nil, because p is valid.
+ if p.ip.z == z4 {
+ b = p.ip.appendTo4(b)
+ } else {
+ if p.ip.Is4In6() {
+ b = append(b, "::ffff:"...)
+ b = p.ip.Unmap().appendTo4(b)
+ } else {
+ b = p.ip.appendTo6(b)
+ }
+ }
+
+ b = append(b, '/')
+ b = appendDecimal(b, uint8(p.Bits()))
+ return b
+}
+
+// MarshalText implements the encoding.TextMarshaler interface.
+// The encoding is the same as returned by String, with one exception:
+// if p is the zero value, the encoding is the empty string.
+func (p Prefix) MarshalText() ([]byte, error) {
+ var max int
+ switch p.ip.z {
+ case z0:
+ case z4:
+ max = len("255.255.255.255/32")
+ default:
+ max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0/128")
+ }
+ b := make([]byte, 0, max)
+ b = p.AppendTo(b)
+ return b, nil
+}
+
+// UnmarshalText implements the encoding.TextUnmarshaler interface.
+// The IP address is expected in a form accepted by ParsePrefix
+// or generated by MarshalText.
+func (p *Prefix) UnmarshalText(text []byte) error {
+ if len(text) == 0 {
+ *p = Prefix{}
+ return nil
+ }
+ var err error
+ *p, err = ParsePrefix(string(text))
+ return err
+}
+
+// MarshalBinary implements the encoding.BinaryMarshaler interface.
+// It returns Addr.MarshalBinary with an additional byte appended
+// containing the prefix bits.
+func (p Prefix) MarshalBinary() ([]byte, error) {
+ b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1)
+ b[len(b)-1] = uint8(p.Bits())
+ return b, nil
+}
+
+// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
+// It expects data in the form generated by MarshalBinary.
+func (p *Prefix) UnmarshalBinary(b []byte) error {
+ if len(b) < 1 {
+ return errors.New("unexpected slice size")
+ }
+ var addr Addr
+ err := addr.UnmarshalBinary(b[:len(b)-1])
+ if err != nil {
+ return err
+ }
+ *p = PrefixFrom(addr, int(b[len(b)-1]))
+ return nil
+}
+
+// String returns the CIDR notation of p: "<ip>/<bits>".
+func (p Prefix) String() string {
+ if !p.IsValid() {
+ return "invalid Prefix"
+ }
+ return p.ip.String() + "/" + itoa.Itoa(p.Bits())
+}
diff --git a/src/net/netip/netip_pkg_test.go b/src/net/netip/netip_pkg_test.go
new file mode 100644
index 0000000..2c9a2e6
--- /dev/null
+++ b/src/net/netip/netip_pkg_test.go
@@ -0,0 +1,365 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "bytes"
+ "encoding"
+ "encoding/json"
+ "strings"
+ "testing"
+)
+
+var (
+ mustPrefix = MustParsePrefix
+ mustIP = MustParseAddr
+)
+
+func TestPrefixValid(t *testing.T) {
+ v4 := MustParseAddr("1.2.3.4")
+ v6 := MustParseAddr("::1")
+ tests := []struct {
+ ipp Prefix
+ want bool
+ }{
+ {PrefixFrom(v4, -2), false},
+ {PrefixFrom(v4, -1), false},
+ {PrefixFrom(v4, 0), true},
+ {PrefixFrom(v4, 32), true},
+ {PrefixFrom(v4, 33), false},
+
+ {PrefixFrom(v6, -2), false},
+ {PrefixFrom(v6, -1), false},
+ {PrefixFrom(v6, 0), true},
+ {PrefixFrom(v6, 32), true},
+ {PrefixFrom(v6, 128), true},
+ {PrefixFrom(v6, 129), false},
+
+ {PrefixFrom(Addr{}, -2), false},
+ {PrefixFrom(Addr{}, -1), false},
+ {PrefixFrom(Addr{}, 0), false},
+ {PrefixFrom(Addr{}, 32), false},
+ {PrefixFrom(Addr{}, 128), false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.IsValid()
+ if got != tt.want {
+ t.Errorf("(%v).IsValid() = %v want %v", tt.ipp, got, tt.want)
+ }
+
+ // Test that there is only one invalid Prefix representation per Addr.
+ invalid := PrefixFrom(tt.ipp.Addr(), -1)
+ if !got && tt.ipp != invalid {
+ t.Errorf("(%v == %v) = false, want true", tt.ipp, invalid)
+ }
+ }
+}
+
+var nextPrevTests = []struct {
+ ip Addr
+ next Addr
+ prev Addr
+}{
+ {mustIP("10.0.0.1"), mustIP("10.0.0.2"), mustIP("10.0.0.0")},
+ {mustIP("10.0.0.255"), mustIP("10.0.1.0"), mustIP("10.0.0.254")},
+ {mustIP("127.0.0.1"), mustIP("127.0.0.2"), mustIP("127.0.0.0")},
+ {mustIP("254.255.255.255"), mustIP("255.0.0.0"), mustIP("254.255.255.254")},
+ {mustIP("255.255.255.255"), Addr{}, mustIP("255.255.255.254")},
+ {mustIP("0.0.0.0"), mustIP("0.0.0.1"), Addr{}},
+ {mustIP("::"), mustIP("::1"), Addr{}},
+ {mustIP("::%x"), mustIP("::1%x"), Addr{}},
+ {mustIP("::1"), mustIP("::2"), mustIP("::")},
+ {mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), Addr{}, mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")},
+}
+
+func TestIPNextPrev(t *testing.T) {
+ doNextPrev(t)
+
+ for _, ip := range []Addr{
+ mustIP("0.0.0.0"),
+ mustIP("::"),
+ } {
+ got := ip.Prev()
+ if !got.isZero() {
+ t.Errorf("IP(%v).Prev = %v; want zero", ip, got)
+ }
+ }
+
+ var allFF [16]byte
+ for i := range allFF {
+ allFF[i] = 0xff
+ }
+
+ for _, ip := range []Addr{
+ mustIP("255.255.255.255"),
+ AddrFrom16(allFF),
+ } {
+ got := ip.Next()
+ if !got.isZero() {
+ t.Errorf("IP(%v).Next = %v; want zero", ip, got)
+ }
+ }
+}
+
+func BenchmarkIPNextPrev(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ doNextPrev(b)
+ }
+}
+
+func doNextPrev(t testing.TB) {
+ for _, tt := range nextPrevTests {
+ gnext, gprev := tt.ip.Next(), tt.ip.Prev()
+ if gnext != tt.next {
+ t.Errorf("IP(%v).Next = %v; want %v", tt.ip, gnext, tt.next)
+ }
+ if gprev != tt.prev {
+ t.Errorf("IP(%v).Prev = %v; want %v", tt.ip, gprev, tt.prev)
+ }
+ if !tt.ip.Next().isZero() && tt.ip.Next().Prev() != tt.ip {
+ t.Errorf("IP(%v).Next.Prev = %v; want %v", tt.ip, tt.ip.Next().Prev(), tt.ip)
+ }
+ if !tt.ip.Prev().isZero() && tt.ip.Prev().Next() != tt.ip {
+ t.Errorf("IP(%v).Prev.Next = %v; want %v", tt.ip, tt.ip.Prev().Next(), tt.ip)
+ }
+ }
+}
+
+func TestIPBitLen(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want int
+ }{
+ {Addr{}, 0},
+ {mustIP("0.0.0.0"), 32},
+ {mustIP("10.0.0.1"), 32},
+ {mustIP("::"), 128},
+ {mustIP("fed0::1"), 128},
+ {mustIP("::ffff:10.0.0.1"), 128},
+ }
+ for _, tt := range tests {
+ got := tt.ip.BitLen()
+ if got != tt.want {
+ t.Errorf("BitLen(%v) = %d; want %d", tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestPrefixContains(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ ip Addr
+ want bool
+ }{
+ {mustPrefix("9.8.7.6/0"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.8.6.4"), true},
+ {mustPrefix("9.8.7.6/16"), mustIP("9.9.7.6"), false},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.6"), true},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false},
+ {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false},
+ {mustPrefix("::1/0"), mustIP("::1"), true},
+ {mustPrefix("::1/0"), mustIP("::2"), true},
+ {mustPrefix("::1/127"), mustIP("::1"), true},
+ {mustPrefix("::1/127"), mustIP("::2"), false},
+ {mustPrefix("::1/128"), mustIP("::1"), true},
+ {mustPrefix("::1/127"), mustIP("::2"), false},
+ // Zones ignored: https://go.dev/issue/51899
+ {Prefix{mustIP("1.2.3.4").WithZone("a"), 32 + 1}, mustIP("1.2.3.4"), true},
+ {Prefix{mustIP("::1").WithZone("a"), 128 + 1}, mustIP("::1"), true},
+ // invalid IP
+ {mustPrefix("::1/0"), Addr{}, false},
+ {mustPrefix("1.2.3.4/0"), Addr{}, false},
+ // invalid Prefix
+ {PrefixFrom(mustIP("::1"), 129), mustIP("::1"), false},
+ {PrefixFrom(mustIP("1.2.3.4"), 33), mustIP("1.2.3.4"), false},
+ {PrefixFrom(Addr{}, 0), mustIP("1.2.3.4"), false},
+ {PrefixFrom(Addr{}, 32), mustIP("1.2.3.4"), false},
+ {PrefixFrom(Addr{}, 128), mustIP("::1"), false},
+ // wrong IP family
+ {mustPrefix("::1/0"), mustIP("1.2.3.4"), false},
+ {mustPrefix("1.2.3.4/0"), mustIP("::1"), false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.Contains(tt.ip)
+ if got != tt.want {
+ t.Errorf("(%v).Contains(%v) = %v want %v", tt.ipp, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestParseIPError(t *testing.T) {
+ tests := []struct {
+ ip string
+ errstr string
+ }{
+ {
+ ip: "localhost",
+ },
+ {
+ ip: "500.0.0.1",
+ errstr: "field has value >255",
+ },
+ {
+ ip: "::gggg%eth0",
+ errstr: "must have at least one digit",
+ },
+ {
+ ip: "fe80::1cc0:3e8c:119f:c2e1%",
+ errstr: "zone must be a non-empty string",
+ },
+ {
+ ip: "%eth0",
+ errstr: "missing IPv6 address",
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.ip, func(t *testing.T) {
+ _, err := ParseAddr(test.ip)
+ if err == nil {
+ t.Fatal("no error")
+ }
+ if _, ok := err.(parseAddrError); !ok {
+ t.Errorf("error type is %T, want parseAddrError", err)
+ }
+ if test.errstr == "" {
+ test.errstr = "unable to parse IP"
+ }
+ if got := err.Error(); !strings.Contains(got, test.errstr) {
+ t.Errorf("error is missing substring %q: %s", test.errstr, got)
+ }
+ })
+ }
+}
+
+func TestParseAddrPort(t *testing.T) {
+ tests := []struct {
+ in string
+ want AddrPort
+ wantErr bool
+ }{
+ {in: "1.2.3.4:1234", want: AddrPort{mustIP("1.2.3.4"), 1234}},
+ {in: "1.1.1.1:123456", wantErr: true},
+ {in: "1.1.1.1:-123", wantErr: true},
+ {in: "[::1]:1234", want: AddrPort{mustIP("::1"), 1234}},
+ {in: "[1.2.3.4]:1234", wantErr: true},
+ {in: "fe80::1:1234", wantErr: true},
+ {in: ":0", wantErr: true}, // if we need to parse this form, there should be a separate function that explicitly allows it
+ }
+ for _, test := range tests {
+ t.Run(test.in, func(t *testing.T) {
+ got, err := ParseAddrPort(test.in)
+ if err != nil {
+ if test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if got != test.want {
+ t.Errorf("got %v; want %v", got, test.want)
+ }
+ if got.String() != test.in {
+ t.Errorf("String = %q; want %q", got.String(), test.in)
+ }
+ })
+
+ t.Run(test.in+"/AppendTo", func(t *testing.T) {
+ got, err := ParseAddrPort(test.in)
+ if err == nil {
+ testAppendToMarshal(t, got)
+ }
+ })
+
+ // MarshalText and UnmarshalText mostly behave like
+ // ParseAddrPort and String. Divergent behavior is handled in
+ // TestAddrPortMarshalUnmarshal.
+ t.Run(test.in+"/Marshal", func(t *testing.T) {
+ var got AddrPort
+ jsin := `"` + test.in + `"`
+ err := json.Unmarshal([]byte(jsin), &got)
+ if err != nil {
+ if test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if got != test.want {
+ t.Errorf("got %v; want %v", got, test.want)
+ }
+ gotb, err := json.Marshal(got)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(gotb) != jsin {
+ t.Errorf("Marshal = %q; want %q", string(gotb), jsin)
+ }
+ })
+ }
+}
+
+func TestAddrPortMarshalUnmarshal(t *testing.T) {
+ tests := []struct {
+ in string
+ want AddrPort
+ }{
+ {"", AddrPort{}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.in, func(t *testing.T) {
+ orig := `"` + test.in + `"`
+
+ var ipp AddrPort
+ if err := json.Unmarshal([]byte(orig), &ipp); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ ippb, err := json.Marshal(ipp)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ back := string(ippb)
+ if orig != back {
+ t.Errorf("Marshal = %q; want %q", back, orig)
+ }
+
+ testAppendToMarshal(t, ipp)
+ })
+ }
+}
+
+type appendMarshaler interface {
+ encoding.TextMarshaler
+ AppendTo([]byte) []byte
+}
+
+// testAppendToMarshal tests that x's AppendTo and MarshalText methods yield the same results.
+// x's MarshalText method must not return an error.
+func testAppendToMarshal(t *testing.T, x appendMarshaler) {
+ t.Helper()
+ m, err := x.MarshalText()
+ if err != nil {
+ t.Fatalf("(%v).MarshalText: %v", x, err)
+ }
+ a := make([]byte, 0, len(m))
+ a = x.AppendTo(a)
+ if !bytes.Equal(m, a) {
+ t.Errorf("(%v).MarshalText = %q, (%v).AppendTo = %q", x, m, x, a)
+ }
+}
+
+func TestIPv6Accessor(t *testing.T) {
+ var a [16]byte
+ for i := range a {
+ a[i] = uint8(i) + 1
+ }
+ ip := AddrFrom16(a)
+ for i := range a {
+ if got, want := ip.v6(uint8(i)), uint8(i)+1; got != want {
+ t.Errorf("v6(%v) = %v; want %v", i, got, want)
+ }
+ }
+}
diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go
new file mode 100644
index 0000000..0f80bb0
--- /dev/null
+++ b/src/net/netip/netip_test.go
@@ -0,0 +1,2029 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "internal/intern"
+ "internal/testenv"
+ "net"
+ . "net/netip"
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+)
+
+var long = flag.Bool("long", false, "run long tests")
+
+type uint128 = Uint128
+
+var (
+ mustPrefix = MustParsePrefix
+ mustIP = MustParseAddr
+ mustIPPort = MustParseAddrPort
+)
+
+func TestParseAddr(t *testing.T) {
+ var validIPs = []struct {
+ in string
+ ip Addr // output of ParseAddr()
+ str string // output of String(). If "", use in.
+ wantErr string
+ }{
+ // Basic zero IPv4 address.
+ {
+ in: "0.0.0.0",
+ ip: MkAddr(Mk128(0, 0xffff00000000), Z4),
+ },
+ // Basic non-zero IPv4 address.
+ {
+ in: "192.168.140.255",
+ ip: MkAddr(Mk128(0, 0xffffc0a88cff), Z4),
+ },
+ // IPv4 address in Windows-style "print all the digits" form.
+ {
+ in: "010.000.015.001",
+ wantErr: `ParseAddr("010.000.015.001"): IPv4 field has octet with leading zero`,
+ },
+ // IPv4 address with a silly amount of leading zeros.
+ {
+ in: "000001.00000002.00000003.000000004",
+ wantErr: `ParseAddr("000001.00000002.00000003.000000004"): IPv4 field has octet with leading zero`,
+ },
+ // 4-in-6 with octet with leading zero
+ {
+ in: "::ffff:1.2.03.4",
+ wantErr: `ParseAddr("::ffff:1.2.03.4"): ParseAddr("1.2.03.4"): IPv4 field has octet with leading zero (at "1.2.03.4")`,
+ },
+ // Basic zero IPv6 address.
+ {
+ in: "::",
+ ip: MkAddr(Mk128(0, 0), Z6noz),
+ },
+ // Localhost IPv6.
+ {
+ in: "::1",
+ ip: MkAddr(Mk128(0, 1), Z6noz),
+ },
+ // Fully expanded IPv6 address.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), Z6noz),
+ },
+ // IPv6 with elided fields in the middle.
+ {
+ in: "fd7a:115c::626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115c00000000, 0x00000000626b430b), Z6noz),
+ },
+ // IPv6 with elided fields at the end.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96::",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd9600000000), Z6noz),
+ },
+ // IPv6 with single elided field at the end.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b::",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b0000), Z6noz),
+ str: "fd7a:115c:a1e0:ab12:4843:cd96:626b:0",
+ },
+ // IPv6 with single elided field in the middle.
+ {
+ in: "fd7a:115c:a1e0::4843:cd96:626b:430b",
+ ip: MkAddr(Mk128(0xfd7a115ca1e00000, 0x4843cd96626b430b), Z6noz),
+ str: "fd7a:115c:a1e0:0:4843:cd96:626b:430b",
+ },
+ // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6)
+ {
+ in: "::ffff:192.168.140.255",
+ ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), Z6noz),
+ str: "::ffff:192.168.140.255",
+ },
+ // IPv6 with a zone specifier.
+ {
+ in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0",
+ ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), intern.Get("eth0")),
+ },
+ // IPv6 with dotted decimal and zone specifier.
+ {
+ in: "1:2::ffff:192.168.140.255%eth1",
+ ip: MkAddr(Mk128(0x0001000200000000, 0x0000ffffc0a88cff), intern.Get("eth1")),
+ str: "1:2::ffff:c0a8:8cff%eth1",
+ },
+ // 4-in-6 with zone
+ {
+ in: "::ffff:192.168.140.255%eth1",
+ ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), intern.Get("eth1")),
+ str: "::ffff:192.168.140.255%eth1",
+ },
+ // IPv6 with capital letters.
+ {
+ in: "FD9E:1A04:F01D::1",
+ ip: MkAddr(Mk128(0xfd9e1a04f01d0000, 0x1), Z6noz),
+ str: "fd9e:1a04:f01d::1",
+ },
+ }
+
+ for _, test := range validIPs {
+ t.Run(test.in, func(t *testing.T) {
+ got, err := ParseAddr(test.in)
+ if err != nil {
+ if err.Error() == test.wantErr {
+ return
+ }
+ t.Fatal(err)
+ }
+ if test.wantErr != "" {
+ t.Fatalf("wanted error %q; got none", test.wantErr)
+ }
+ if got != test.ip {
+ t.Errorf("got %#v, want %#v", got, test.ip)
+ }
+
+ // Check that ParseAddr is a pure function.
+ got2, err := ParseAddr(test.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != got2 {
+ t.Errorf("ParseAddr(%q) got 2 different results: %#v, %#v", test.in, got, got2)
+ }
+
+ // Check that ParseAddr(ip.String()) is the identity function.
+ s := got.String()
+ got3, err := ParseAddr(s)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != got3 {
+ t.Errorf("ParseAddr(%q) != ParseAddr(ParseIP(%q).String()). Got %#v, want %#v", test.in, test.in, got3, got)
+ }
+
+ // Check that the slow-but-readable parser produces the same result.
+ slow, err := parseIPSlow(test.in)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got != slow {
+ t.Errorf("ParseAddr(%q) = %#v, parseIPSlow(%q) = %#v", test.in, got, test.in, slow)
+ }
+
+ // Check that the parsed IP formats as expected.
+ s = got.String()
+ wants := test.str
+ if wants == "" {
+ wants = test.in
+ }
+ if s != wants {
+ t.Errorf("ParseAddr(%q).String() got %q, want %q", test.in, s, wants)
+ }
+
+ // Check that AppendTo matches MarshalText.
+ TestAppendToMarshal(t, got)
+
+ // Check that MarshalText/UnmarshalText work similarly to
+ // ParseAddr/String (see TestAddrMarshalUnmarshal for
+ // marshal-specific behavior that's not common with
+ // ParseAddr/String).
+ js := `"` + test.in + `"`
+ var jsgot Addr
+ if err := json.Unmarshal([]byte(js), &jsgot); err != nil {
+ t.Fatal(err)
+ }
+ if jsgot != got {
+ t.Errorf("json.Unmarshal(%q) = %#v, want %#v", test.in, jsgot, got)
+ }
+ jsb, err := json.Marshal(jsgot)
+ if err != nil {
+ t.Fatal(err)
+ }
+ jswant := `"` + wants + `"`
+ jsback := string(jsb)
+ if jsback != jswant {
+ t.Errorf("Marshal(Unmarshal(%q)) = %s, want %s", test.in, jsback, jswant)
+ }
+ })
+ }
+
+ var invalidIPs = []string{
+ // Empty string
+ "",
+ // Garbage non-IP
+ "bad",
+ // Single number. Some parsers accept this as an IPv4 address in
+ // big-endian uint32 form, but we don't.
+ "1234",
+ // IPv4 with a zone specifier
+ "1.2.3.4%eth0",
+ // IPv4 field must have at least one digit
+ ".1.2.3",
+ "1.2.3.",
+ "1..2.3",
+ // IPv4 address too long
+ "1.2.3.4.5",
+ // IPv4 in dotted octal form
+ "0300.0250.0214.0377",
+ // IPv4 in dotted hex form
+ "0xc0.0xa8.0x8c.0xff",
+ // IPv4 in class B form
+ "192.168.12345",
+ // IPv4 in class B form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.0.1",
+ // IPv4 in class A form
+ "192.1234567",
+ // IPv4 in class A form, with a small enough number to be
+ // parseable as a regular dotted decimal field.
+ "127.1",
+ // IPv4 field has value >255
+ "192.168.300.1",
+ // IPv4 with too many fields
+ "192.168.0.1.5.6",
+ // IPv6 with not enough fields
+ "1:2:3:4:5:6:7",
+ // IPv6 with too many fields
+ "1:2:3:4:5:6:7:8:9",
+ // IPv6 with 8 fields and a :: expander
+ "1:2:3:4::5:6:7:8",
+ // IPv6 with a field bigger than 2 bytes
+ "fe801::1",
+ // IPv6 with non-hex values in field
+ "fe80:tail:scal:e::",
+ // IPv6 with a zone delimiter but no zone.
+ "fe80::1%",
+ // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4.
+ "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255",
+ // IPv6 with invalid embedded IPv4.
+ "::ffff:192.168.140.bad",
+ // IPv6 with multiple ellipsis ::.
+ "fe80::1::1",
+ // IPv6 with invalid non hex/colon character.
+ "fe80:1?:1",
+ // IPv6 with truncated bytes after single colon.
+ "fe80:",
+ }
+
+ for _, s := range invalidIPs {
+ t.Run(s, func(t *testing.T) {
+ got, err := ParseAddr(s)
+ if err == nil {
+ t.Errorf("ParseAddr(%q) = %#v, want error", s, got)
+ }
+
+ slow, err := parseIPSlow(s)
+ if err == nil {
+ t.Errorf("parseIPSlow(%q) = %#v, want error", s, slow)
+ }
+
+ std := net.ParseIP(s)
+ if std != nil {
+ t.Errorf("net.ParseIP(%q) = %#v, want error", s, std)
+ }
+
+ if s == "" {
+ // Don't test unmarshaling of "" here, do it in
+ // TestAddrMarshalUnmarshal.
+ return
+ }
+ var jsgot Addr
+ js := []byte(`"` + s + `"`)
+ if err := json.Unmarshal(js, &jsgot); err == nil {
+ t.Errorf("json.Unmarshal(%q) = %#v, want error", s, jsgot)
+ }
+ })
+ }
+}
+
+func TestAddrFromSlice(t *testing.T) {
+ tests := []struct {
+ ip []byte
+ wantAddr Addr
+ wantOK bool
+ }{
+ {
+ ip: []byte{10, 0, 0, 1},
+ wantAddr: mustIP("10.0.0.1"),
+ wantOK: true,
+ },
+ {
+ ip: []byte{0xfe, 0x80, 15: 0x01},
+ wantAddr: mustIP("fe80::01"),
+ wantOK: true,
+ },
+ {
+ ip: []byte{0, 1, 2},
+ wantAddr: Addr{},
+ wantOK: false,
+ },
+ {
+ ip: nil,
+ wantAddr: Addr{},
+ wantOK: false,
+ },
+ }
+ for _, tt := range tests {
+ addr, ok := AddrFromSlice(tt.ip)
+ if ok != tt.wantOK || addr != tt.wantAddr {
+ t.Errorf("AddrFromSlice(%#v) = %#v, %v, want %#v, %v", tt.ip, addr, ok, tt.wantAddr, tt.wantOK)
+ }
+ }
+}
+
+func TestIPv4Constructors(t *testing.T) {
+ if AddrFrom4([4]byte{1, 2, 3, 4}) != MustParseAddr("1.2.3.4") {
+ t.Errorf("don't match")
+ }
+}
+
+func TestAddrMarshalUnmarshalBinary(t *testing.T) {
+ tests := []struct {
+ ip string
+ wantSize int
+ }{
+ {"", 0}, // zero IP
+ {"1.2.3.4", 4},
+ {"fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", 16},
+ {"::ffff:c000:0280", 16},
+ {"::ffff:c000:0280%eth0", 20},
+ }
+ for _, tc := range tests {
+ var ip Addr
+ if len(tc.ip) > 0 {
+ ip = mustIP(tc.ip)
+ }
+ b, err := ip.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.ip, len(b), tc.wantSize)
+ }
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if ip != ip2 {
+ t.Fatalf("got %v; want %v", ip2, ip)
+ }
+ }
+
+ // Cannot unmarshal from unexpected IP length.
+ for _, n := range []int{3, 5} {
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected IP length %d", n)
+ }
+ }
+}
+
+func TestAddrPortMarshalTextString(t *testing.T) {
+ tests := []struct {
+ in AddrPort
+ want string
+ }{
+ {mustIPPort("1.2.3.4:80"), "1.2.3.4:80"},
+ {mustIPPort("[1::CAFE]:80"), "[1::cafe]:80"},
+ {mustIPPort("[1::CAFE%en0]:80"), "[1::cafe%en0]:80"},
+ {mustIPPort("[::FFFF:192.168.140.255]:80"), "[::ffff:192.168.140.255]:80"},
+ {mustIPPort("[::FFFF:192.168.140.255%en0]:80"), "[::ffff:192.168.140.255%en0]:80"},
+ }
+ for i, tt := range tests {
+ if got := tt.in.String(); got != tt.want {
+ t.Errorf("%d. for (%v, %v) String = %q; want %q", i, tt.in.Addr(), tt.in.Port(), got, tt.want)
+ }
+ mt, err := tt.in.MarshalText()
+ if err != nil {
+ t.Errorf("%d. for (%v, %v) MarshalText error: %v", i, tt.in.Addr(), tt.in.Port(), err)
+ continue
+ }
+ if string(mt) != tt.want {
+ t.Errorf("%d. for (%v, %v) MarshalText = %q; want %q", i, tt.in.Addr(), tt.in.Port(), mt, tt.want)
+ }
+ }
+}
+
+func TestAddrPortMarshalUnmarshalBinary(t *testing.T) {
+ tests := []struct {
+ ipport string
+ wantSize int
+ }{
+ {"1.2.3.4:51820", 4 + 2},
+ {"[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80", 16 + 2},
+ {"[::ffff:c000:0280]:65535", 16 + 2},
+ {"[::ffff:c000:0280%eth0]:1", 20 + 2},
+ }
+ for _, tc := range tests {
+ var ipport AddrPort
+ if len(tc.ipport) > 0 {
+ ipport = mustIPPort(tc.ipport)
+ }
+ b, err := ipport.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.ipport, len(b), tc.wantSize)
+ }
+ var ipport2 AddrPort
+ if err := ipport2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if ipport != ipport2 {
+ t.Fatalf("got %v; want %v", ipport2, ipport)
+ }
+ }
+
+ // Cannot unmarshal from unexpected lengths.
+ for _, n := range []int{3, 7} {
+ var ipport2 AddrPort
+ if err := ipport2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected length %d", n)
+ }
+ }
+}
+
+func TestPrefixMarshalTextString(t *testing.T) {
+ tests := []struct {
+ in Prefix
+ want string
+ }{
+ {mustPrefix("1.2.3.4/24"), "1.2.3.4/24"},
+ {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"},
+ {mustPrefix("::ffff:c000:0280/96"), "::ffff:192.0.2.128/96"},
+ {mustPrefix("::ffff:192.168.140.255/8"), "::ffff:192.168.140.255/8"},
+ {PrefixFrom(mustIP("::ffff:c000:0280").WithZone("eth0"), 37), "::ffff:192.0.2.128/37"}, // Zone should be stripped
+ }
+ for i, tt := range tests {
+ if got := tt.in.String(); got != tt.want {
+ t.Errorf("%d. for %v String = %q; want %q", i, tt.in, got, tt.want)
+ }
+ mt, err := tt.in.MarshalText()
+ if err != nil {
+ t.Errorf("%d. for %v MarshalText error: %v", i, tt.in, err)
+ continue
+ }
+ if string(mt) != tt.want {
+ t.Errorf("%d. for %v MarshalText = %q; want %q", i, tt.in, mt, tt.want)
+ }
+ }
+}
+
+func TestPrefixMarshalUnmarshalBinary(t *testing.T) {
+ type testCase struct {
+ prefix Prefix
+ wantSize int
+ }
+ tests := []testCase{
+ {mustPrefix("1.2.3.4/24"), 4 + 1},
+ {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), 16 + 1},
+ {mustPrefix("::ffff:c000:0280/96"), 16 + 1},
+ {PrefixFrom(mustIP("::ffff:c000:0280").WithZone("eth0"), 37), 16 + 1}, // Zone should be stripped
+ }
+ tests = append(tests,
+ testCase{PrefixFrom(tests[0].prefix.Addr(), 33), tests[0].wantSize},
+ testCase{PrefixFrom(tests[1].prefix.Addr(), 129), tests[1].wantSize})
+ for _, tc := range tests {
+ prefix := tc.prefix
+ b, err := prefix.MarshalBinary()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(b) != tc.wantSize {
+ t.Fatalf("%q encoded to size %d; want %d", tc.prefix, len(b), tc.wantSize)
+ }
+ var prefix2 Prefix
+ if err := prefix2.UnmarshalBinary(b); err != nil {
+ t.Fatal(err)
+ }
+ if prefix != prefix2 {
+ t.Fatalf("got %v; want %v", prefix2, prefix)
+ }
+ }
+
+ // Cannot unmarshal from unexpected lengths.
+ for _, n := range []int{3, 6} {
+ var prefix2 Prefix
+ if err := prefix2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil {
+ t.Fatalf("unmarshaled from unexpected length %d", n)
+ }
+ }
+}
+
+func TestAddrMarshalUnmarshal(t *testing.T) {
+ // This only tests the cases where Marshal/Unmarshal diverges from
+ // the behavior of ParseAddr/String. For the rest of the test cases,
+ // see TestParseAddr above.
+ orig := `""`
+ var ip Addr
+ if err := json.Unmarshal([]byte(orig), &ip); err != nil {
+ t.Fatalf("Unmarshal(%q) got error %v", orig, err)
+ }
+ if ip != (Addr{}) {
+ t.Errorf("Unmarshal(%q) is not the zero Addr", orig)
+ }
+
+ jsb, err := json.Marshal(ip)
+ if err != nil {
+ t.Fatalf("Marshal(%v) got error %v", ip, err)
+ }
+ back := string(jsb)
+ if back != orig {
+ t.Errorf("Marshal(Unmarshal(%q)) got %q, want %q", orig, back, orig)
+ }
+}
+
+func TestAddrFrom16(t *testing.T) {
+ tests := []struct {
+ name string
+ in [16]byte
+ want Addr
+ }{
+ {
+ name: "v6-raw",
+ in: [...]byte{15: 1},
+ want: MkAddr(Mk128(0, 1), Z6noz),
+ },
+ {
+ name: "v4-raw",
+ in: [...]byte{10: 0xff, 11: 0xff, 12: 1, 13: 2, 14: 3, 15: 4},
+ want: MkAddr(Mk128(0, 0xffff01020304), Z6noz),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := AddrFrom16(tt.in)
+ if got != tt.want {
+ t.Errorf("got %#v; want %#v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIPProperties(t *testing.T) {
+ var (
+ nilIP Addr
+
+ unicast4 = mustIP("192.0.2.1")
+ unicast6 = mustIP("2001:db8::1")
+ unicastZone6 = mustIP("2001:db8::1%eth0")
+ unicast6Unassigned = mustIP("4000::1") // not in 2000::/3.
+
+ multicast4 = mustIP("224.0.0.1")
+ multicast6 = mustIP("ff02::1")
+ multicastZone6 = mustIP("ff02::1%eth0")
+
+ llu4 = mustIP("169.254.0.1")
+ llu6 = mustIP("fe80::1")
+ llu6Last = mustIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
+ lluZone6 = mustIP("fe80::1%eth0")
+
+ loopback4 = mustIP("127.0.0.1")
+
+ ilm6 = mustIP("ff01::1")
+ ilmZone6 = mustIP("ff01::1%eth0")
+
+ private4a = mustIP("10.0.0.1")
+ private4b = mustIP("172.16.0.1")
+ private4c = mustIP("192.168.1.1")
+ private6 = mustIP("fd00::1")
+ )
+
+ tests := []struct {
+ name string
+ ip Addr
+ globalUnicast bool
+ interfaceLocalMulticast bool
+ linkLocalMulticast bool
+ linkLocalUnicast bool
+ loopback bool
+ multicast bool
+ private bool
+ unspecified bool
+ }{
+ {
+ name: "nil",
+ ip: nilIP,
+ },
+ {
+ name: "unicast v4Addr",
+ ip: unicast4,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6Addr",
+ ip: unicast6,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6AddrZone",
+ ip: unicastZone6,
+ globalUnicast: true,
+ },
+ {
+ name: "unicast v6Addr unassigned",
+ ip: unicast6Unassigned,
+ globalUnicast: true,
+ },
+ {
+ name: "multicast v4Addr",
+ ip: multicast4,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "multicast v6Addr",
+ ip: multicast6,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "multicast v6AddrZone",
+ ip: multicastZone6,
+ linkLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "link-local unicast v4Addr",
+ ip: llu4,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6Addr",
+ ip: llu6,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6Addr upper bound",
+ ip: llu6Last,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "link-local unicast v6AddrZone",
+ ip: lluZone6,
+ linkLocalUnicast: true,
+ },
+ {
+ name: "loopback v4Addr",
+ ip: loopback4,
+ loopback: true,
+ },
+ {
+ name: "loopback v6Addr",
+ ip: IPv6Loopback(),
+ loopback: true,
+ },
+ {
+ name: "interface-local multicast v6Addr",
+ ip: ilm6,
+ interfaceLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "interface-local multicast v6AddrZone",
+ ip: ilmZone6,
+ interfaceLocalMulticast: true,
+ multicast: true,
+ },
+ {
+ name: "private v4Addr 10/8",
+ ip: private4a,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v4Addr 172.16/12",
+ ip: private4b,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v4Addr 192.168/16",
+ ip: private4c,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "private v6Addr",
+ ip: private6,
+ globalUnicast: true,
+ private: true,
+ },
+ {
+ name: "unspecified v4Addr",
+ ip: IPv4Unspecified(),
+ unspecified: true,
+ },
+ {
+ name: "unspecified v6Addr",
+ ip: IPv6Unspecified(),
+ unspecified: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gu := tt.ip.IsGlobalUnicast()
+ if gu != tt.globalUnicast {
+ t.Errorf("IsGlobalUnicast(%v) = %v; want %v", tt.ip, gu, tt.globalUnicast)
+ }
+
+ ilm := tt.ip.IsInterfaceLocalMulticast()
+ if ilm != tt.interfaceLocalMulticast {
+ t.Errorf("IsInterfaceLocalMulticast(%v) = %v; want %v", tt.ip, ilm, tt.interfaceLocalMulticast)
+ }
+
+ llu := tt.ip.IsLinkLocalUnicast()
+ if llu != tt.linkLocalUnicast {
+ t.Errorf("IsLinkLocalUnicast(%v) = %v; want %v", tt.ip, llu, tt.linkLocalUnicast)
+ }
+
+ llm := tt.ip.IsLinkLocalMulticast()
+ if llm != tt.linkLocalMulticast {
+ t.Errorf("IsLinkLocalMulticast(%v) = %v; want %v", tt.ip, llm, tt.linkLocalMulticast)
+ }
+
+ lo := tt.ip.IsLoopback()
+ if lo != tt.loopback {
+ t.Errorf("IsLoopback(%v) = %v; want %v", tt.ip, lo, tt.loopback)
+ }
+
+ multicast := tt.ip.IsMulticast()
+ if multicast != tt.multicast {
+ t.Errorf("IsMulticast(%v) = %v; want %v", tt.ip, multicast, tt.multicast)
+ }
+
+ private := tt.ip.IsPrivate()
+ if private != tt.private {
+ t.Errorf("IsPrivate(%v) = %v; want %v", tt.ip, private, tt.private)
+ }
+
+ unspecified := tt.ip.IsUnspecified()
+ if unspecified != tt.unspecified {
+ t.Errorf("IsUnspecified(%v) = %v; want %v", tt.ip, unspecified, tt.unspecified)
+ }
+ })
+ }
+}
+
+func TestAddrWellKnown(t *testing.T) {
+ tests := []struct {
+ name string
+ ip Addr
+ std net.IP
+ }{
+ {
+ name: "IPv6 link-local all nodes",
+ ip: IPv6LinkLocalAllNodes(),
+ std: net.IPv6linklocalallnodes,
+ },
+ {
+ name: "IPv6 link-local all routers",
+ ip: IPv6LinkLocalAllRouters(),
+ std: net.IPv6linklocalallrouters,
+ },
+ {
+ name: "IPv6 loopback",
+ ip: IPv6Loopback(),
+ std: net.IPv6loopback,
+ },
+ {
+ name: "IPv6 unspecified",
+ ip: IPv6Unspecified(),
+ std: net.IPv6unspecified,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ want := tt.std.String()
+ got := tt.ip.String()
+
+ if got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ })
+ }
+}
+
+func TestLessCompare(t *testing.T) {
+ tests := []struct {
+ a, b Addr
+ want bool
+ }{
+ {Addr{}, Addr{}, false},
+ {Addr{}, mustIP("1.2.3.4"), true},
+ {mustIP("1.2.3.4"), Addr{}, false},
+
+ {mustIP("1.2.3.4"), mustIP("0102:0304::0"), true},
+ {mustIP("0102:0304::0"), mustIP("1.2.3.4"), false},
+ {mustIP("1.2.3.4"), mustIP("1.2.3.4"), false},
+
+ {mustIP("::1"), mustIP("::2"), true},
+ {mustIP("::1"), mustIP("::1%foo"), true},
+ {mustIP("::1%foo"), mustIP("::2"), true},
+ {mustIP("::2"), mustIP("::3"), true},
+
+ {mustIP("::"), mustIP("0.0.0.0"), false},
+ {mustIP("0.0.0.0"), mustIP("::"), true},
+
+ {mustIP("::1%a"), mustIP("::1%b"), true},
+ {mustIP("::1%a"), mustIP("::1%a"), false},
+ {mustIP("::1%b"), mustIP("::1%a"), false},
+ }
+ for _, tt := range tests {
+ got := tt.a.Less(tt.b)
+ if got != tt.want {
+ t.Errorf("Less(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want)
+ }
+ cmp := tt.a.Compare(tt.b)
+ if got && cmp != -1 {
+ t.Errorf("Less(%q, %q) = true, but Compare = %v (not -1)", tt.a, tt.b, cmp)
+ }
+ if cmp < -1 || cmp > 1 {
+ t.Errorf("bogus Compare return value %v", cmp)
+ }
+ if cmp == 0 && tt.a != tt.b {
+ t.Errorf("Compare(%q, %q) = 0; but not equal", tt.a, tt.b)
+ }
+ if cmp == 1 && !tt.b.Less(tt.a) {
+ t.Errorf("Compare(%q, %q) = 1; but b.Less(a) isn't true", tt.a, tt.b)
+ }
+
+ // Also check inverse.
+ if got == tt.want && got {
+ got2 := tt.b.Less(tt.a)
+ if got2 {
+ t.Errorf("Less(%q, %q) was correctly %v, but so was Less(%q, %q)", tt.a, tt.b, got, tt.b, tt.a)
+ }
+ }
+ }
+
+ // And just sort.
+ values := []Addr{
+ mustIP("::1"),
+ mustIP("::2"),
+ Addr{},
+ mustIP("1.2.3.4"),
+ mustIP("8.8.8.8"),
+ mustIP("::1%foo"),
+ }
+ sort.Slice(values, func(i, j int) bool { return values[i].Less(values[j]) })
+ got := fmt.Sprintf("%s", values)
+ want := `[invalid IP 1.2.3.4 8.8.8.8 ::1 ::1%foo ::2]`
+ if got != want {
+ t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want)
+ }
+}
+
+func TestIPStringExpanded(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ s string
+ }{
+ {
+ ip: Addr{},
+ s: "invalid IP",
+ },
+ {
+ ip: mustIP("192.0.2.1"),
+ s: "192.0.2.1",
+ },
+ {
+ ip: mustIP("::ffff:192.0.2.1"),
+ s: "0000:0000:0000:0000:0000:ffff:c000:0201",
+ },
+ {
+ ip: mustIP("2001:db8::1"),
+ s: "2001:0db8:0000:0000:0000:0000:0000:0001",
+ },
+ {
+ ip: mustIP("2001:db8::1%eth0"),
+ s: "2001:0db8:0000:0000:0000:0000:0000:0001%eth0",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.ip.String(), func(t *testing.T) {
+ want := tt.s
+ got := tt.ip.StringExpanded()
+
+ if got != want {
+ t.Fatalf("got %s, want %s", got, want)
+ }
+ })
+ }
+}
+
+func TestPrefixMasking(t *testing.T) {
+ type subtest struct {
+ ip Addr
+ bits uint8
+ p Prefix
+ ok bool
+ }
+
+ // makeIPv6 produces a set of IPv6 subtests with an optional zone identifier.
+ makeIPv6 := func(zone string) []subtest {
+ if zone != "" {
+ zone = "%" + zone
+ }
+
+ return []subtest{
+ {
+ ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)),
+ bits: 255,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)),
+ bits: 32,
+ p: mustPrefix("2001:db8::/32"),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("fe80::dead:beef:dead:beef%s", zone)),
+ bits: 96,
+ p: mustPrefix("fe80::dead:beef:0:0/96"),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("aaaa::%s", zone)),
+ bits: 4,
+ p: mustPrefix("a000::/4"),
+ ok: true,
+ },
+ {
+ ip: mustIP(fmt.Sprintf("::%s", zone)),
+ bits: 63,
+ p: mustPrefix("::/63"),
+ ok: true,
+ },
+ }
+ }
+
+ tests := []struct {
+ family string
+ subtests []subtest
+ }{
+ {
+ family: "nil",
+ subtests: []subtest{
+ {
+ bits: 255,
+ ok: true,
+ },
+ {
+ bits: 16,
+ ok: true,
+ },
+ },
+ },
+ {
+ family: "IPv4",
+ subtests: []subtest{
+ {
+ ip: mustIP("192.0.2.0"),
+ bits: 255,
+ },
+ {
+ ip: mustIP("192.0.2.0"),
+ bits: 16,
+ p: mustPrefix("192.0.0.0/16"),
+ ok: true,
+ },
+ {
+ ip: mustIP("255.255.255.255"),
+ bits: 20,
+ p: mustPrefix("255.255.240.0/20"),
+ ok: true,
+ },
+ {
+ // Partially masking one byte that contains both
+ // 1s and 0s on either side of the mask limit.
+ ip: mustIP("100.98.156.66"),
+ bits: 10,
+ p: mustPrefix("100.64.0.0/10"),
+ ok: true,
+ },
+ },
+ },
+ {
+ family: "IPv6",
+ subtests: makeIPv6(""),
+ },
+ {
+ family: "IPv6 zone",
+ subtests: makeIPv6("eth0"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.family, func(t *testing.T) {
+ for _, st := range tt.subtests {
+ t.Run(st.p.String(), func(t *testing.T) {
+ // Ensure st.ip is not mutated.
+ orig := st.ip.String()
+
+ p, err := st.ip.Prefix(int(st.bits))
+ if st.ok && err != nil {
+ t.Fatalf("failed to produce prefix: %v", err)
+ }
+ if !st.ok && err == nil {
+ t.Fatal("expected an error, but none occurred")
+ }
+ if err != nil {
+ t.Logf("err: %v", err)
+ return
+ }
+
+ if !reflect.DeepEqual(p, st.p) {
+ t.Errorf("prefix = %q, want %q", p, st.p)
+ }
+
+ if got := st.ip.String(); got != orig {
+ t.Errorf("IP was mutated: %q, want %q", got, orig)
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestPrefixMarshalUnmarshal(t *testing.T) {
+ tests := []string{
+ "",
+ "1.2.3.4/32",
+ "0.0.0.0/0",
+ "::/0",
+ "::1/128",
+ "2001:db8::/32",
+ }
+
+ for _, s := range tests {
+ t.Run(s, func(t *testing.T) {
+ // Ensure that JSON (and by extension, text) marshaling is
+ // sane by round-tripping quoted input.
+ orig := `"` + s + `"`
+
+ var p Prefix
+ if err := json.Unmarshal([]byte(orig), &p); err != nil {
+ t.Fatalf("failed to unmarshal: %v", err)
+ }
+
+ pb, err := json.Marshal(p)
+ if err != nil {
+ t.Fatalf("failed to marshal: %v", err)
+ }
+
+ back := string(pb)
+ if orig != back {
+ t.Errorf("Marshal = %q; want %q", back, orig)
+ }
+ })
+ }
+}
+
+func TestPrefixUnmarshalTextNonZero(t *testing.T) {
+ ip := mustPrefix("fe80::/64")
+ if err := ip.UnmarshalText([]byte("xxx")); err == nil {
+ t.Fatal("unmarshaled into non-empty Prefix")
+ }
+}
+
+func TestIs4AndIs6(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ is4 bool
+ is6 bool
+ }{
+ {Addr{}, false, false},
+ {mustIP("1.2.3.4"), true, false},
+ {mustIP("127.0.0.2"), true, false},
+ {mustIP("::1"), false, true},
+ {mustIP("::ffff:192.0.2.128"), false, true},
+ {mustIP("::fffe:c000:0280"), false, true},
+ {mustIP("::1%eth0"), false, true},
+ }
+ for _, tt := range tests {
+ got4 := tt.ip.Is4()
+ if got4 != tt.is4 {
+ t.Errorf("Is4(%q) = %v; want %v", tt.ip, got4, tt.is4)
+ }
+
+ got6 := tt.ip.Is6()
+ if got6 != tt.is6 {
+ t.Errorf("Is6(%q) = %v; want %v", tt.ip, got6, tt.is6)
+ }
+ }
+}
+
+func TestIs4In6(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want bool
+ wantUnmap Addr
+ }{
+ {Addr{}, false, Addr{}},
+ {mustIP("::ffff:c000:0280"), true, mustIP("192.0.2.128")},
+ {mustIP("::ffff:192.0.2.128"), true, mustIP("192.0.2.128")},
+ {mustIP("::ffff:192.0.2.128%eth0"), true, mustIP("192.0.2.128")},
+ {mustIP("::fffe:c000:0280"), false, mustIP("::fffe:c000:0280")},
+ {mustIP("::ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("::ffff:7f01:0203"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0:0000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0:000000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("0:0:0:0::ffff:127.1.2.3"), true, mustIP("127.1.2.3")},
+ {mustIP("::1"), false, mustIP("::1")},
+ {mustIP("1.2.3.4"), false, mustIP("1.2.3.4")},
+ }
+ for _, tt := range tests {
+ got := tt.ip.Is4In6()
+ if got != tt.want {
+ t.Errorf("Is4In6(%q) = %v; want %v", tt.ip, got, tt.want)
+ }
+ u := tt.ip.Unmap()
+ if u != tt.wantUnmap {
+ t.Errorf("Unmap(%q) = %v; want %v", tt.ip, u, tt.wantUnmap)
+ }
+ }
+}
+
+func TestPrefixMasked(t *testing.T) {
+ tests := []struct {
+ prefix Prefix
+ masked Prefix
+ }{
+ {
+ prefix: mustPrefix("192.168.0.255/24"),
+ masked: mustPrefix("192.168.0.0/24"),
+ },
+ {
+ prefix: mustPrefix("2100::/3"),
+ masked: mustPrefix("2000::/3"),
+ },
+ {
+ prefix: PrefixFrom(mustIP("2000::"), 129),
+ masked: Prefix{},
+ },
+ {
+ prefix: PrefixFrom(mustIP("1.2.3.4"), 33),
+ masked: Prefix{},
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix.String(), func(t *testing.T) {
+ got := test.prefix.Masked()
+ if got != test.masked {
+ t.Errorf("Masked=%s, want %s", got, test.masked)
+ }
+ })
+ }
+}
+
+func TestPrefix(t *testing.T) {
+ tests := []struct {
+ prefix string
+ ip Addr
+ bits int
+ str string
+ contains []Addr
+ notContains []Addr
+ }{
+ {
+ prefix: "192.168.0.0/24",
+ ip: mustIP("192.168.0.0"),
+ bits: 24,
+ contains: mustIPs("192.168.0.1", "192.168.0.55"),
+ notContains: mustIPs("192.168.1.1", "1.1.1.1"),
+ },
+ {
+ prefix: "192.168.1.1/32",
+ ip: mustIP("192.168.1.1"),
+ bits: 32,
+ contains: mustIPs("192.168.1.1"),
+ notContains: mustIPs("192.168.1.2"),
+ },
+ {
+ prefix: "100.64.0.0/10", // CGNAT range; prefix not multiple of 8
+ ip: mustIP("100.64.0.0"),
+ bits: 10,
+ contains: mustIPs("100.64.0.0", "100.64.0.1", "100.81.251.94", "100.100.100.100", "100.127.255.254", "100.127.255.255"),
+ notContains: mustIPs("100.63.255.255", "100.128.0.0"),
+ },
+ {
+ prefix: "2001:db8::/96",
+ ip: mustIP("2001:db8::"),
+ bits: 96,
+ contains: mustIPs("2001:db8::aaaa:bbbb", "2001:db8::1"),
+ notContains: mustIPs("2001:db8::1:aaaa:bbbb", "2001:db9::"),
+ },
+ {
+ prefix: "0.0.0.0/0",
+ ip: mustIP("0.0.0.0"),
+ bits: 0,
+ contains: mustIPs("192.168.0.1", "1.1.1.1"),
+ notContains: append(mustIPs("2001:db8::1"), Addr{}),
+ },
+ {
+ prefix: "::/0",
+ ip: mustIP("::"),
+ bits: 0,
+ contains: mustIPs("::1", "2001:db8::1"),
+ notContains: mustIPs("192.0.2.1"),
+ },
+ {
+ prefix: "2000::/3",
+ ip: mustIP("2000::"),
+ bits: 3,
+ contains: mustIPs("2001:db8::1"),
+ notContains: mustIPs("fe80::1"),
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix, func(t *testing.T) {
+ prefix, err := ParsePrefix(test.prefix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if prefix.Addr() != test.ip {
+ t.Errorf("IP=%s, want %s", prefix.Addr(), test.ip)
+ }
+ if prefix.Bits() != test.bits {
+ t.Errorf("bits=%d, want %d", prefix.Bits(), test.bits)
+ }
+ for _, ip := range test.contains {
+ if !prefix.Contains(ip) {
+ t.Errorf("does not contain %s", ip)
+ }
+ }
+ for _, ip := range test.notContains {
+ if prefix.Contains(ip) {
+ t.Errorf("contains %s", ip)
+ }
+ }
+ want := test.str
+ if want == "" {
+ want = test.prefix
+ }
+ if got := prefix.String(); got != want {
+ t.Errorf("prefix.String()=%q, want %q", got, want)
+ }
+
+ TestAppendToMarshal(t, prefix)
+ })
+ }
+}
+
+func TestPrefixFromInvalidBits(t *testing.T) {
+ v4 := MustParseAddr("1.2.3.4")
+ v6 := MustParseAddr("66::66")
+ tests := []struct {
+ ip Addr
+ in, want int
+ }{
+ {v4, 0, 0},
+ {v6, 0, 0},
+ {v4, 1, 1},
+ {v4, 33, -1},
+ {v6, 33, 33},
+ {v6, 127, 127},
+ {v6, 128, 128},
+ {v4, 254, -1},
+ {v4, 255, -1},
+ {v4, -1, -1},
+ {v6, -1, -1},
+ {v4, -5, -1},
+ {v6, -5, -1},
+ }
+ for _, tt := range tests {
+ p := PrefixFrom(tt.ip, tt.in)
+ if got := p.Bits(); got != tt.want {
+ t.Errorf("for (%v, %v), Bits out = %v; want %v", tt.ip, tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestParsePrefixAllocs(t *testing.T) {
+ tests := []struct {
+ ip string
+ slash string
+ }{
+ {"192.168.1.0", "/24"},
+ {"aaaa:bbbb:cccc::", "/24"},
+ }
+ for _, test := range tests {
+ prefix := test.ip + test.slash
+ t.Run(prefix, func(t *testing.T) {
+ ipAllocs := int(testing.AllocsPerRun(5, func() {
+ ParseAddr(test.ip)
+ }))
+ prefixAllocs := int(testing.AllocsPerRun(5, func() {
+ ParsePrefix(prefix)
+ }))
+ if got := prefixAllocs - ipAllocs; got != 0 {
+ t.Errorf("allocs=%d, want 0", got)
+ }
+ })
+ }
+}
+
+func TestParsePrefixError(t *testing.T) {
+ tests := []struct {
+ prefix string
+ errstr string
+ }{
+ {
+ prefix: "192.168.0.0",
+ errstr: "no '/'",
+ },
+ {
+ prefix: "1.257.1.1/24",
+ errstr: "value >255",
+ },
+ {
+ prefix: "1.1.1.0/q",
+ errstr: "bad bits",
+ },
+ {
+ prefix: "1.1.1.0/-1",
+ errstr: "out of range",
+ },
+ {
+ prefix: "1.1.1.0/33",
+ errstr: "out of range",
+ },
+ {
+ prefix: "2001::/129",
+ errstr: "out of range",
+ },
+ // Zones are not allowed: https://go.dev/issue/51899
+ {
+ prefix: "1.1.1.0%a/24",
+ errstr: "unexpected character",
+ },
+ {
+ prefix: "2001:db8::%a/32",
+ errstr: "zones cannot be present",
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.prefix, func(t *testing.T) {
+ _, err := ParsePrefix(test.prefix)
+ if err == nil {
+ t.Fatal("no error")
+ }
+ if got := err.Error(); !strings.Contains(got, test.errstr) {
+ t.Errorf("error is missing substring %q: %s", test.errstr, got)
+ }
+ })
+ }
+}
+
+func TestPrefixIsSingleIP(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ want bool
+ }{
+ {ipp: mustPrefix("127.0.0.1/32"), want: true},
+ {ipp: mustPrefix("127.0.0.1/31"), want: false},
+ {ipp: mustPrefix("127.0.0.1/0"), want: false},
+ {ipp: mustPrefix("::1/128"), want: true},
+ {ipp: mustPrefix("::1/127"), want: false},
+ {ipp: mustPrefix("::1/0"), want: false},
+ {ipp: Prefix{}, want: false},
+ }
+ for _, tt := range tests {
+ got := tt.ipp.IsSingleIP()
+ if got != tt.want {
+ t.Errorf("IsSingleIP(%v) = %v want %v", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func mustIPs(strs ...string) []Addr {
+ var res []Addr
+ for _, s := range strs {
+ res = append(res, mustIP(s))
+ }
+ return res
+}
+
+func BenchmarkBinaryMarshalRoundTrip(b *testing.B) {
+ b.ReportAllocs()
+ tests := []struct {
+ name string
+ ip string
+ }{
+ {"ipv4", "1.2.3.4"},
+ {"ipv6", "2001:db8::1"},
+ {"ipv6+zone", "2001:db8::1%eth0"},
+ }
+ for _, tc := range tests {
+ b.Run(tc.name, func(b *testing.B) {
+ ip := mustIP(tc.ip)
+ for i := 0; i < b.N; i++ {
+ bt, err := ip.MarshalBinary()
+ if err != nil {
+ b.Fatal(err)
+ }
+ var ip2 Addr
+ if err := ip2.UnmarshalBinary(bt); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkStdIPv4(b *testing.B) {
+ b.ReportAllocs()
+ ips := []net.IP{}
+ for i := 0; i < b.N; i++ {
+ ip := net.IPv4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv4(b *testing.B) {
+ b.ReportAllocs()
+ ips := []Addr{}
+ for i := 0; i < b.N; i++ {
+ ip := IPv4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+// ip4i was one of the possible representations of IP that came up in
+// design discussions: it inlines IPv4 addresses, but falls back to an
+// "overflow" interface value for IPv6 or IPv6 + zone. It is kept here
+// only for benchmarking.
+type ip4i struct {
+ ip4 [4]byte
+ flags1 byte
+ flags2 byte
+ flags3 byte
+ flags4 byte
+ ipv6 any
+}
+
+func newip4i_v4(a, b, c, d byte) ip4i {
+ return ip4i{ip4: [4]byte{a, b, c, d}}
+}
+
+// BenchmarkIPv4_inline benchmarks the candidate representation, ip4i.
+func BenchmarkIPv4_inline(b *testing.B) {
+ b.ReportAllocs()
+ ips := []ip4i{}
+ for i := 0; i < b.N; i++ {
+ ip := newip4i_v4(8, 8, 8, 8)
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkStdIPv6(b *testing.B) {
+ b.ReportAllocs()
+ ips := []net.IP{}
+ for i := 0; i < b.N; i++ {
+ ip := net.ParseIP("2001:db8::1")
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv6(b *testing.B) {
+ b.ReportAllocs()
+ ips := []Addr{}
+ for i := 0; i < b.N; i++ {
+ ip := mustIP("2001:db8::1")
+ ips = ips[:0]
+ for i := 0; i < 100; i++ {
+ ips = append(ips, ip)
+ }
+ }
+}
+
+func BenchmarkIPv4Contains(b *testing.B) {
+ b.ReportAllocs()
+ prefix := PrefixFrom(IPv4(192, 168, 1, 0), 24)
+ ip := IPv4(192, 168, 1, 1)
+ for i := 0; i < b.N; i++ {
+ prefix.Contains(ip)
+ }
+}
+
+func BenchmarkIPv6Contains(b *testing.B) {
+ b.ReportAllocs()
+ prefix := MustParsePrefix("::1/128")
+ ip := MustParseAddr("::1")
+ for i := 0; i < b.N; i++ {
+ prefix.Contains(ip)
+ }
+}
+
+var parseBenchInputs = []struct {
+ name string
+ ip string
+}{
+ {"v4", "192.168.1.1"},
+ {"v6", "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b"},
+ {"v6_ellipsis", "fd7a:115c::626b:430b"},
+ {"v6_v4", "::ffff:192.168.140.255"},
+ {"v6_zone", "1:2::ffff:192.168.140.255%eth1"},
+}
+
+func BenchmarkParseAddr(b *testing.B) {
+ sinkInternValue = intern.Get("eth1") // Pin to not benchmark the intern package
+ for _, test := range parseBenchInputs {
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkIP, _ = ParseAddr(test.ip)
+ }
+ })
+ }
+}
+
+func BenchmarkStdParseIP(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkStdIP = net.ParseIP(test.ip)
+ }
+ })
+ }
+}
+
+func BenchmarkIPString(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ip.String()
+ }
+ })
+ }
+}
+
+func BenchmarkIPStringExpanded(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ip.StringExpanded()
+ }
+ })
+ }
+}
+
+func BenchmarkIPMarshalText(b *testing.B) {
+ b.ReportAllocs()
+ ip := MustParseAddr("66.55.44.33")
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ip.MarshalText()
+ }
+}
+
+func BenchmarkAddrPortString(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ ipp := AddrPortFrom(ip, 60000)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkString = ipp.String()
+ }
+ })
+ }
+}
+
+func BenchmarkAddrPortMarshalText(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ ip := MustParseAddr(test.ip)
+ ipp := AddrPortFrom(ip, 60000)
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ipp.MarshalText()
+ }
+ })
+ }
+}
+
+func BenchmarkPrefixMasking(b *testing.B) {
+ tests := []struct {
+ name string
+ ip Addr
+ bits int
+ }{
+ {
+ name: "IPv4 /32",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 32,
+ },
+ {
+ name: "IPv4 /17",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 17,
+ },
+ {
+ name: "IPv4 /0",
+ ip: IPv4(192, 0, 2, 0),
+ bits: 0,
+ },
+ {
+ name: "IPv6 /128",
+ ip: mustIP("2001:db8::1"),
+ bits: 128,
+ },
+ {
+ name: "IPv6 /65",
+ ip: mustIP("2001:db8::1"),
+ bits: 65,
+ },
+ {
+ name: "IPv6 /0",
+ ip: mustIP("2001:db8::1"),
+ bits: 0,
+ },
+ {
+ name: "IPv6 zone /128",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 128,
+ },
+ {
+ name: "IPv6 zone /65",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 65,
+ },
+ {
+ name: "IPv6 zone /0",
+ ip: mustIP("2001:db8::1%eth0"),
+ bits: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ b.Run(tt.name, func(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ sinkPrefix, _ = tt.ip.Prefix(tt.bits)
+ }
+ })
+ }
+}
+
+func BenchmarkPrefixMarshalText(b *testing.B) {
+ b.ReportAllocs()
+ ipp := MustParsePrefix("66.55.44.33/22")
+ for i := 0; i < b.N; i++ {
+ sinkBytes, _ = ipp.MarshalText()
+ }
+}
+
+func BenchmarkParseAddrPort(b *testing.B) {
+ for _, test := range parseBenchInputs {
+ var ipp string
+ if strings.HasPrefix(test.name, "v6") {
+ ipp = fmt.Sprintf("[%s]:1234", test.ip)
+ } else {
+ ipp = fmt.Sprintf("%s:1234", test.ip)
+ }
+ b.Run(test.name, func(b *testing.B) {
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ sinkAddrPort, _ = ParseAddrPort(ipp)
+ }
+ })
+ }
+}
+
+func TestAs4(t *testing.T) {
+ tests := []struct {
+ ip Addr
+ want [4]byte
+ wantPanic bool
+ }{
+ {
+ ip: mustIP("1.2.3.4"),
+ want: [4]byte{1, 2, 3, 4},
+ },
+ {
+ ip: AddrFrom16(mustIP("1.2.3.4").As16()), // IPv4-in-IPv6
+ want: [4]byte{1, 2, 3, 4},
+ },
+ {
+ ip: mustIP("0.0.0.0"),
+ want: [4]byte{0, 0, 0, 0},
+ },
+ {
+ ip: Addr{},
+ wantPanic: true,
+ },
+ {
+ ip: mustIP("::1"),
+ wantPanic: true,
+ },
+ }
+ as4 := func(ip Addr) (v [4]byte, gotPanic bool) {
+ defer func() {
+ if recover() != nil {
+ gotPanic = true
+ return
+ }
+ }()
+ v = ip.As4()
+ return
+ }
+ for i, tt := range tests {
+ got, gotPanic := as4(tt.ip)
+ if gotPanic != tt.wantPanic {
+ t.Errorf("%d. panic on %v = %v; want %v", i, tt.ip, gotPanic, tt.wantPanic)
+ continue
+ }
+ if got != tt.want {
+ t.Errorf("%d. %v = %v; want %v", i, tt.ip, got, tt.want)
+ }
+ }
+}
+
+func TestPrefixOverlaps(t *testing.T) {
+ pfx := mustPrefix
+ tests := []struct {
+ a, b Prefix
+ want bool
+ }{
+ {Prefix{}, pfx("1.2.0.0/16"), false}, // first zero
+ {pfx("1.2.0.0/16"), Prefix{}, false}, // second zero
+ {pfx("::0/3"), pfx("0.0.0.0/3"), false}, // different families
+
+ {pfx("1.2.0.0/16"), pfx("1.2.0.0/16"), true}, // equal
+
+ {pfx("1.2.0.0/16"), pfx("1.2.3.0/24"), true},
+ {pfx("1.2.3.0/24"), pfx("1.2.0.0/16"), true},
+
+ {pfx("1.2.0.0/16"), pfx("1.2.3.0/32"), true},
+ {pfx("1.2.3.0/32"), pfx("1.2.0.0/16"), true},
+
+ // Match /0 either order
+ {pfx("1.2.3.0/32"), pfx("0.0.0.0/0"), true},
+ {pfx("0.0.0.0/0"), pfx("1.2.3.0/32"), true},
+
+ {pfx("1.2.3.0/32"), pfx("5.5.5.5/0"), true}, // normalization not required; /0 means true
+
+ // IPv6 overlapping
+ {pfx("5::1/128"), pfx("5::0/8"), true},
+ {pfx("5::0/8"), pfx("5::1/128"), true},
+
+ // IPv6 not overlapping
+ {pfx("1::1/128"), pfx("2::2/128"), false},
+ {pfx("0100::0/8"), pfx("::1/128"), false},
+
+ // IPv4-mapped IPv6 addresses should not overlap with IPv4.
+ {PrefixFrom(AddrFrom16(mustIP("1.2.0.0").As16()), 16), pfx("1.2.3.0/24"), false},
+
+ // Invalid prefixes
+ {PrefixFrom(mustIP("1.2.3.4"), 33), pfx("1.2.3.0/24"), false},
+ {PrefixFrom(mustIP("2000::"), 129), pfx("2000::/64"), false},
+ }
+ for i, tt := range tests {
+ if got := tt.a.Overlaps(tt.b); got != tt.want {
+ t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.a, tt.b, got, tt.want)
+ }
+ // Overlaps is commutative
+ if got := tt.b.Overlaps(tt.a); got != tt.want {
+ t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+// Sink variables are here to force the compiler to not elide
+// seemingly useless work in benchmarks and allocation tests. If you
+// were to just `_ = foo()` within a test function, the compiler could
+// correctly deduce that foo() does nothing and doesn't need to be
+// called. By writing results to a global variable, we hide that fact
+// from the compiler and force it to keep the code under test.
+var (
+ sinkIP Addr
+ sinkStdIP net.IP
+ sinkAddrPort AddrPort
+ sinkPrefix Prefix
+ sinkPrefixSlice []Prefix
+ sinkInternValue *intern.Value
+ sinkIP16 [16]byte
+ sinkIP4 [4]byte
+ sinkBool bool
+ sinkString string
+ sinkBytes []byte
+ sinkUDPAddr = &net.UDPAddr{IP: make(net.IP, 0, 16)}
+)
+
+func TestNoAllocs(t *testing.T) {
+ // Wrappers that panic on error, to prove that our alloc-free
+ // methods are returning successfully.
+ panicIP := func(ip Addr, err error) Addr {
+ if err != nil {
+ panic(err)
+ }
+ return ip
+ }
+ panicPfx := func(pfx Prefix, err error) Prefix {
+ if err != nil {
+ panic(err)
+ }
+ return pfx
+ }
+ panicIPP := func(ipp AddrPort, err error) AddrPort {
+ if err != nil {
+ panic(err)
+ }
+ return ipp
+ }
+ test := func(name string, f func()) {
+ t.Run(name, func(t *testing.T) {
+ n := testing.AllocsPerRun(1000, f)
+ if n != 0 {
+ t.Fatalf("allocs = %d; want 0", int(n))
+ }
+ })
+ }
+
+ // Addr constructors
+ test("IPv4", func() { sinkIP = IPv4(1, 2, 3, 4) })
+ test("AddrFrom4", func() { sinkIP = AddrFrom4([4]byte{1, 2, 3, 4}) })
+ test("AddrFrom16", func() { sinkIP = AddrFrom16([16]byte{}) })
+ test("ParseAddr/4", func() { sinkIP = panicIP(ParseAddr("1.2.3.4")) })
+ test("ParseAddr/6", func() { sinkIP = panicIP(ParseAddr("::1")) })
+ test("MustParseAddr", func() { sinkIP = MustParseAddr("1.2.3.4") })
+ test("IPv6LinkLocalAllNodes", func() { sinkIP = IPv6LinkLocalAllNodes() })
+ test("IPv6LinkLocalAllRouters", func() { sinkIP = IPv6LinkLocalAllRouters() })
+ test("IPv6Loopback", func() { sinkIP = IPv6Loopback() })
+ test("IPv6Unspecified", func() { sinkIP = IPv6Unspecified() })
+
+ // Addr methods
+ test("Addr.IsZero", func() { sinkBool = MustParseAddr("1.2.3.4").IsZero() })
+ test("Addr.BitLen", func() { sinkBool = MustParseAddr("1.2.3.4").BitLen() == 8 })
+ test("Addr.Zone/4", func() { sinkBool = MustParseAddr("1.2.3.4").Zone() == "" })
+ test("Addr.Zone/6", func() { sinkBool = MustParseAddr("fe80::1").Zone() == "" })
+ test("Addr.Zone/6zone", func() { sinkBool = MustParseAddr("fe80::1%zone").Zone() == "" })
+ test("Addr.Compare", func() {
+ a := MustParseAddr("1.2.3.4")
+ b := MustParseAddr("2.3.4.5")
+ sinkBool = a.Compare(b) == 0
+ })
+ test("Addr.Less", func() {
+ a := MustParseAddr("1.2.3.4")
+ b := MustParseAddr("2.3.4.5")
+ sinkBool = a.Less(b)
+ })
+ test("Addr.Is4", func() { sinkBool = MustParseAddr("1.2.3.4").Is4() })
+ test("Addr.Is6", func() { sinkBool = MustParseAddr("fe80::1").Is6() })
+ test("Addr.Is4In6", func() { sinkBool = MustParseAddr("fe80::1").Is4In6() })
+ test("Addr.Unmap", func() { sinkIP = MustParseAddr("ffff::2.3.4.5").Unmap() })
+ test("Addr.WithZone", func() { sinkIP = MustParseAddr("fe80::1").WithZone("") })
+ test("Addr.IsGlobalUnicast", func() { sinkBool = MustParseAddr("2001:db8::1").IsGlobalUnicast() })
+ test("Addr.IsInterfaceLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsInterfaceLocalMulticast() })
+ test("Addr.IsLinkLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalMulticast() })
+ test("Addr.IsLinkLocalUnicast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalUnicast() })
+ test("Addr.IsLoopback", func() { sinkBool = MustParseAddr("fe80::1").IsLoopback() })
+ test("Addr.IsMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsMulticast() })
+ test("Addr.IsPrivate", func() { sinkBool = MustParseAddr("fd00::1").IsPrivate() })
+ test("Addr.IsUnspecified", func() { sinkBool = IPv6Unspecified().IsUnspecified() })
+ test("Addr.Prefix/4", func() { sinkPrefix = panicPfx(MustParseAddr("1.2.3.4").Prefix(20)) })
+ test("Addr.Prefix/6", func() { sinkPrefix = panicPfx(MustParseAddr("fe80::1").Prefix(64)) })
+ test("Addr.As16", func() { sinkIP16 = MustParseAddr("1.2.3.4").As16() })
+ test("Addr.As4", func() { sinkIP4 = MustParseAddr("1.2.3.4").As4() })
+ test("Addr.Next", func() { sinkIP = MustParseAddr("1.2.3.4").Next() })
+ test("Addr.Prev", func() { sinkIP = MustParseAddr("1.2.3.4").Prev() })
+
+ // AddrPort constructors
+ test("AddrPortFrom", func() { sinkAddrPort = AddrPortFrom(IPv4(1, 2, 3, 4), 22) })
+ test("ParseAddrPort", func() { sinkAddrPort = panicIPP(ParseAddrPort("[::1]:1234")) })
+ test("MustParseAddrPort", func() { sinkAddrPort = MustParseAddrPort("[::1]:1234") })
+
+ // Prefix constructors
+ test("PrefixFrom", func() { sinkPrefix = PrefixFrom(IPv4(1, 2, 3, 4), 32) })
+ test("ParsePrefix/4", func() { sinkPrefix = panicPfx(ParsePrefix("1.2.3.4/20")) })
+ test("ParsePrefix/6", func() { sinkPrefix = panicPfx(ParsePrefix("fe80::1/64")) })
+ test("MustParsePrefix", func() { sinkPrefix = MustParsePrefix("1.2.3.4/20") })
+
+ // Prefix methods
+ test("Prefix.Contains", func() { sinkBool = MustParsePrefix("1.2.3.0/24").Contains(MustParseAddr("1.2.3.4")) })
+ test("Prefix.Overlaps", func() {
+ a, b := MustParsePrefix("1.2.3.0/24"), MustParsePrefix("1.2.0.0/16")
+ sinkBool = a.Overlaps(b)
+ })
+ test("Prefix.IsZero", func() { sinkBool = MustParsePrefix("1.2.0.0/16").IsZero() })
+ test("Prefix.IsSingleIP", func() { sinkBool = MustParsePrefix("1.2.3.4/32").IsSingleIP() })
+ test("Prefix.Masked", func() { sinkPrefix = MustParsePrefix("1.2.3.4/16").Masked() })
+}
+
+func TestAddrStringAllocs(t *testing.T) {
+ tests := []struct {
+ name string
+ ip Addr
+ wantAllocs int
+ }{
+ {"zero", Addr{}, 0},
+ {"ipv4", MustParseAddr("192.168.1.1"), 1},
+ {"ipv6", MustParseAddr("2001:db8::1"), 1},
+ {"ipv6+zone", MustParseAddr("2001:db8::1%eth0"), 1},
+ {"ipv4-in-ipv6", MustParseAddr("::ffff:192.168.1.1"), 1},
+ {"ipv4-in-ipv6+zone", MustParseAddr("::ffff:192.168.1.1%eth0"), 1},
+ }
+ optimizationOff := testenv.OptimizationOff()
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ if optimizationOff && strings.HasPrefix(tc.name, "ipv4-in-ipv6") {
+ // Optimizations are required to remove some allocs.
+ t.Skipf("skipping on %v", testenv.Builder())
+ }
+ allocs := int(testing.AllocsPerRun(1000, func() {
+ sinkString = tc.ip.String()
+ }))
+ if allocs != tc.wantAllocs {
+ t.Errorf("allocs=%d, want %d", allocs, tc.wantAllocs)
+ }
+ })
+ }
+}
+
+func TestPrefixString(t *testing.T) {
+ tests := []struct {
+ ipp Prefix
+ want string
+ }{
+ {Prefix{}, "invalid Prefix"},
+ {PrefixFrom(Addr{}, 8), "invalid Prefix"},
+ {PrefixFrom(MustParseAddr("1.2.3.4"), 88), "invalid Prefix"},
+ }
+
+ for _, tt := range tests {
+ if got := tt.ipp.String(); got != tt.want {
+ t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func TestInvalidAddrPortString(t *testing.T) {
+ tests := []struct {
+ ipp AddrPort
+ want string
+ }{
+ {AddrPort{}, "invalid AddrPort"},
+ {AddrPortFrom(Addr{}, 80), "invalid AddrPort"},
+ }
+
+ for _, tt := range tests {
+ if got := tt.ipp.String(); got != tt.want {
+ t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want)
+ }
+ }
+}
+
+func TestAsSlice(t *testing.T) {
+ tests := []struct {
+ in Addr
+ want []byte
+ }{
+ {in: Addr{}, want: nil},
+ {in: mustIP("1.2.3.4"), want: []byte{1, 2, 3, 4}},
+ {in: mustIP("ffff::1"), want: []byte{0xff, 0xff, 15: 1}},
+ }
+
+ for _, test := range tests {
+ got := test.in.AsSlice()
+ if !bytes.Equal(got, test.want) {
+ t.Errorf("%v.AsSlice() = %v want %v", test.in, got, test.want)
+ }
+ }
+}
+
+var sink16 [16]byte
+
+func BenchmarkAs16(b *testing.B) {
+ addr := MustParseAddr("1::10")
+ for i := 0; i < b.N; i++ {
+ sink16 = addr.As16()
+ }
+}
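
The allocation tests above combine two ingredients: package-level sink variables that keep the compiler from eliding the work under test, and testing.AllocsPerRun to average allocations over many runs. A minimal standalone sketch of that pattern (the sink name is illustrative, not part of the package):

	package main

	import (
		"fmt"
		"net/netip"
		"testing"
	)

	// Writing the result to a package-level sink stops the compiler from
	// deducing that the parse inside the closure is dead code.
	var sinkAddr netip.Addr

	func main() {
		allocs := testing.AllocsPerRun(1000, func() {
			sinkAddr = netip.MustParseAddr("1.2.3.4")
		})
		fmt.Println("average allocs per run:", allocs) // 0, per TestNoAllocs above
	}
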
diff --git a/src/net/netip/slow_test.go b/src/net/netip/slow_test.go
new file mode 100644
index 0000000..d7c8025
--- /dev/null
+++ b/src/net/netip/slow_test.go
@@ -0,0 +1,190 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip_test
+
+import (
+ "fmt"
+ . "net/netip"
+ "strconv"
+ "strings"
+)
+
+// zeros is a slice of eight stringified zeros. It's used in
+// parseIPSlow to construct slices of specific amounts of zero fields,
+// from 1 to 8.
+var zeros = []string{"0", "0", "0", "0", "0", "0", "0", "0"}
+
+// parseIPSlow is like ParseAddr, but aims for readability over
+// speed. It's the reference implementation used for correctness
+// checking, and the one against which we measure optimized parsers.
+//
+// parseIPSlow understands the following forms of IP addresses:
+// - Regular IPv4: 1.2.3.4
+// - IPv4 with many leading zeros: 0000001.0000002.0000003.0000004
+// - Regular IPv6: 1111:2222:3333:4444:5555:6666:7777:8888
+// - IPv6 with many leading zeros: 00000001:0000002:0000003:0000004:0000005:0000006:0000007:0000008
+// - IPv6 with zero blocks elided: 1111:2222::7777:8888
+// - IPv6 with trailing 32 bits expressed as IPv4: 1111:2222:3333:4444:5555:6666:77.77.88.88
+//
+// It does not process the following IP address forms, which have been
+// varyingly accepted by some programs due to an under-specification
+// of the shapes of IPv4 addresses:
+//
+// - IPv4 as a single 32-bit uint: 16909060 (same as "1.2.3.4")
+// - IPv4 with octal numbers: 0300.0250.0.01 (same as "192.168.0.1")
+// - IPv4 with hex numbers: 0xc0.0xa8.0x0.0x1 (same as "192.168.0.1")
+// - IPv4 in "class-B style": 1.2.52 (same as "1.2.3.4")
+// - IPv4 in "class-A style": 1.564 (same as "1.2.3.4")
+func parseIPSlow(s string) (Addr, error) {
+ // Identify and strip out the zone, if any. There should be 0 or 1
+ // '%' in the string.
+ var zone string
+ fs := strings.Split(s, "%")
+ switch len(fs) {
+ case 1:
+ // No zone, that's fine.
+ case 2:
+ s, zone = fs[0], fs[1]
+ if zone == "" {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): no zone after zone specifier", s)
+ }
+ default:
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): too many zone specifiers", s) // TODO: less specific?
+ }
+
+ // IPv4 by itself is easy to do in a helper.
+ if strings.Count(s, ":") == 0 {
+ if zone != "" {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): IPv4 addresses cannot have a zone", s)
+ }
+ return parseIPv4Slow(s)
+ }
+
+ normal, err := normalizeIPv6Slow(s)
+ if err != nil {
+ return Addr{}, err
+ }
+
+ // At this point, we've normalized the address back into 8 hex
+ // fields of 16 bits each. Parse that.
+ fs = strings.Split(normal, ":")
+ if len(fs) != 8 {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): wrong size address", s)
+ }
+ var ret [16]byte
+ for i, f := range fs {
+ a, b, err := parseWord(f)
+ if err != nil {
+ return Addr{}, err
+ }
+ ret[i*2] = a
+ ret[i*2+1] = b
+ }
+
+ return AddrFrom16(ret).WithZone(zone), nil
+}
+
+// normalizeIPv6Slow expands s, which is assumed to be an IPv6
+// address, to its canonical text form.
+//
+// The canonical form of an IPv6 address is 8 colon-separated fields,
+// where each field should be a hex value from 0 to ffff. This
+// function does not verify the contents of each field.
+//
+// This function performs two transformations:
+// - The last 32 bits of an IPv6 address may be represented in
+// IPv4-style dotted quad form, as in 1:2:3:4:5:6:7.8.9.10. That
+// address is transformed to its hex equivalent,
+// e.g. 1:2:3:4:5:6:708:90a.
+// - An address may contain one "::", which expands into as many
+// 16-bit blocks of zeros as needed to make the address its correct
+// full size. For example, fe80::1:2 expands to fe80:0:0:0:0:0:1:2.
+//
+// Both short forms may be present in a single address,
+// e.g. fe80::1.2.3.4.
+func normalizeIPv6Slow(orig string) (string, error) {
+ s := orig
+
+ // Find and convert an IPv4 address in the final field, if any.
+ i := strings.LastIndex(s, ":")
+ if i == -1 {
+ return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig)
+ }
+ if strings.Contains(s[i+1:], ".") {
+ ip, err := parseIPv4Slow(s[i+1:])
+ if err != nil {
+ return "", err
+ }
+ a4 := ip.As4()
+ s = fmt.Sprintf("%s:%02x%02x:%02x%02x", s[:i], a4[0], a4[1], a4[2], a4[3])
+ }
+
+ // Find and expand a ::, if any.
+ fs := strings.Split(s, "::")
+ switch len(fs) {
+ case 1:
+ // No ::, nothing to do.
+ case 2:
+ lhs, rhs := fs[0], fs[1]
+ // Found a ::, figure out how many zero blocks need to be
+ // inserted.
+ nblocks := strings.Count(lhs, ":") + strings.Count(rhs, ":")
+ if lhs != "" {
+ nblocks++
+ }
+ if rhs != "" {
+ nblocks++
+ }
+ if nblocks > 7 {
+ return "", fmt.Errorf("netaddr.ParseIP(%q): address too long", orig)
+ }
+ fs = nil
+ // Either side of the :: can be empty. We don't want empty
+ // fields to feature in the final normalized address.
+ if lhs != "" {
+ fs = append(fs, lhs)
+ }
+ fs = append(fs, zeros[:8-nblocks]...)
+ if rhs != "" {
+ fs = append(fs, rhs)
+ }
+ s = strings.Join(fs, ":")
+ default:
+ // Too many ::
+ return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig)
+ }
+
+ return s, nil
+}
+
+// parseIPv4Slow parses and returns an IPv4 address in dotted quad
+// form, e.g. "192.168.0.1". It is slow but easy to read, and the
+// reference implementation against which we compare faster
+// implementations for correctness.
+func parseIPv4Slow(s string) (Addr, error) {
+ fs := strings.Split(s, ".")
+ if len(fs) != 4 {
+ return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", s)
+ }
+ var ret [4]byte
+ for i := range ret {
+ val, err := strconv.ParseUint(fs[i], 10, 8)
+ if err != nil {
+ return Addr{}, err
+ }
+ ret[i] = uint8(val)
+ }
+ return AddrFrom4(ret), nil
+}
+
+// parseWord converts a 16-bit hex string into its corresponding
+// two-byte value.
+func parseWord(s string) (byte, byte, error) {
+ ret, err := strconv.ParseUint(s, 16, 16)
+ if err != nil {
+ return 0, 0, err
+ }
+ return uint8(ret >> 8), uint8(ret), nil
+}
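
The two rewrites described in the normalizeIPv6Slow comment (a trailing dotted-quad IPv4 converted to hex, and "::" expanded into the missing zero fields) can also be observed through the public net/netip API; this sketch only illustrates the normalization and does not call the unexported helpers above:

	package main

	import (
		"fmt"
		"net/netip"
	)

	func main() {
		// "fe80::1.2.3.4" uses both short forms: a trailing IPv4 dotted quad
		// and an elided "::" run of zero fields.
		a := netip.MustParseAddr("fe80::1.2.3.4")
		fmt.Println(a)                  // fe80::102:304
		fmt.Println(a.StringExpanded()) // fe80:0000:0000:0000:0000:0000:0102:0304
	}
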
diff --git a/src/net/netip/uint128.go b/src/net/netip/uint128.go
new file mode 100644
index 0000000..b1605af
--- /dev/null
+++ b/src/net/netip/uint128.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import "math/bits"
+
+// uint128 represents a uint128 using two uint64s.
+//
+// When the methods below mention a bit number, bit 0 is the most
+// significant bit (in hi) and bit 127 is the lowest (lo&1).
+type uint128 struct {
+ hi uint64
+ lo uint64
+}
+
+// mask6 returns a uint128 bitmask with the topmost n bits of a
+// 128-bit number set and the remaining bits clear.
+func mask6(n int) uint128 {
+ return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)}
+}
+
+// isZero reports whether u == 0.
+//
+// It's faster than u == (uint128{}) because the compiler (as of Go
+// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in
+// its eq alg's generated code.
+func (u uint128) isZero() bool { return u.hi|u.lo == 0 }
+
+// and returns the bitwise AND of u and m (u&m).
+func (u uint128) and(m uint128) uint128 {
+ return uint128{u.hi & m.hi, u.lo & m.lo}
+}
+
+// xor returns the bitwise XOR of u and m (u^m).
+func (u uint128) xor(m uint128) uint128 {
+ return uint128{u.hi ^ m.hi, u.lo ^ m.lo}
+}
+
+// or returns the bitwise OR of u and m (u|m).
+func (u uint128) or(m uint128) uint128 {
+ return uint128{u.hi | m.hi, u.lo | m.lo}
+}
+
+// not returns the bitwise NOT of u.
+func (u uint128) not() uint128 {
+ return uint128{^u.hi, ^u.lo}
+}
+
+// subOne returns u - 1.
+func (u uint128) subOne() uint128 {
+ lo, borrow := bits.Sub64(u.lo, 1, 0)
+ return uint128{u.hi - borrow, lo}
+}
+
+// addOne returns u + 1.
+func (u uint128) addOne() uint128 {
+ lo, carry := bits.Add64(u.lo, 1, 0)
+ return uint128{u.hi + carry, lo}
+}
+
+// halves returns the two uint64 halves of the uint128.
+//
+// Logically, think of it as returning two uint64s.
+// It only returns pointers for inlining reasons on 32-bit platforms.
+func (u *uint128) halves() [2]*uint64 {
+ return [2]*uint64{&u.hi, &u.lo}
+}
+
+// bitsSetFrom returns a copy of u with the given bit
+// and all subsequent ones set.
+func (u uint128) bitsSetFrom(bit uint8) uint128 {
+ return u.or(mask6(int(bit)).not())
+}
+
+// bitsClearedFrom returns a copy of u with the given bit
+// and all subsequent ones cleared.
+func (u uint128) bitsClearedFrom(bit uint8) uint128 {
+ return u.and(mask6(int(bit)))
+}
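
uint128 and mask6 are unexported, so the following standalone sketch reproduces the same masking arithmetic purely for illustration, showing how the mask for a /65 prefix splits across the two uint64 halves (mask6Halves is a hypothetical name, not part of the package):

	package main

	import "fmt"

	// mask6Halves mirrors the mask6 formula above: the topmost n bits of a
	// 128-bit value, returned as its high and low uint64 halves.
	func mask6Halves(n int) (hi, lo uint64) {
		return ^(^uint64(0) >> n), ^uint64(0) << (128 - n)
	}

	func main() {
		hi, lo := mask6Halves(65)           // mask for a /65 prefix
		fmt.Printf("%016x %016x\n", hi, lo) // ffffffffffffffff 8000000000000000
	}
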
diff --git a/src/net/netip/uint128_test.go b/src/net/netip/uint128_test.go
new file mode 100644
index 0000000..dd1ae0e
--- /dev/null
+++ b/src/net/netip/uint128_test.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package netip
+
+import (
+ "testing"
+)
+
+func TestUint128AddSub(t *testing.T) {
+ const add1 = 1
+ const sub1 = -1
+ tests := []struct {
+ in uint128
+ op int // +1 or -1 to add vs subtract
+ want uint128
+ }{
+ {uint128{0, 0}, add1, uint128{0, 1}},
+ {uint128{0, 1}, add1, uint128{0, 2}},
+ {uint128{1, 0}, add1, uint128{1, 1}},
+ {uint128{0, ^uint64(0)}, add1, uint128{1, 0}},
+ {uint128{^uint64(0), ^uint64(0)}, add1, uint128{0, 0}},
+
+ {uint128{0, 0}, sub1, uint128{^uint64(0), ^uint64(0)}},
+ {uint128{0, 1}, sub1, uint128{0, 0}},
+ {uint128{0, 2}, sub1, uint128{0, 1}},
+ {uint128{1, 0}, sub1, uint128{0, ^uint64(0)}},
+ {uint128{1, 1}, sub1, uint128{1, 0}},
+ }
+ for _, tt := range tests {
+ var got uint128
+ switch tt.op {
+ case add1:
+ got = tt.in.addOne()
+ case sub1:
+ got = tt.in.subOne()
+ default:
+ panic("bogus op")
+ }
+ if got != tt.want {
+ t.Errorf("%v add %d = %v; want %v", tt.in, tt.op, got, tt.want)
+ }
+ }
+}
+
+func TestBitsSetFrom(t *testing.T) {
+ tests := []struct {
+ bit uint8
+ want uint128
+ }{
+ {0, uint128{^uint64(0), ^uint64(0)}},
+ {1, uint128{^uint64(0) >> 1, ^uint64(0)}},
+ {63, uint128{1, ^uint64(0)}},
+ {64, uint128{0, ^uint64(0)}},
+ {65, uint128{0, ^uint64(0) >> 1}},
+ {127, uint128{0, 1}},
+ {128, uint128{0, 0}},
+ }
+ for _, tt := range tests {
+ var zero uint128
+ got := zero.bitsSetFrom(tt.bit)
+ if got != tt.want {
+ t.Errorf("0.bitsSetFrom(%d) = %064b want %064b", tt.bit, got, tt.want)
+ }
+ }
+}
+
+func TestBitsClearedFrom(t *testing.T) {
+ tests := []struct {
+ bit uint8
+ want uint128
+ }{
+ {0, uint128{0, 0}},
+ {1, uint128{1 << 63, 0}},
+ {63, uint128{^uint64(0) &^ 1, 0}},
+ {64, uint128{^uint64(0), 0}},
+ {65, uint128{^uint64(0), 1 << 63}},
+ {127, uint128{^uint64(0), ^uint64(0) &^ 1}},
+ {128, uint128{^uint64(0), ^uint64(0)}},
+ }
+ for _, tt := range tests {
+ ones := uint128{^uint64(0), ^uint64(0)}
+ got := ones.bitsClearedFrom(tt.bit)
+ if got != tt.want {
+ t.Errorf("ones.bitsClearedFrom(%d) = %064b want %064b", tt.bit, got, tt.want)
+ }
+ }
+}
diff --git a/src/net/nss.go b/src/net/nss.go
new file mode 100644
index 0000000..092b515
--- /dev/null
+++ b/src/net/nss.go
@@ -0,0 +1,249 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "internal/bytealg"
+ "os"
+ "sync"
+ "time"
+)
+
+const (
+ nssConfigPath = "/etc/nsswitch.conf"
+)
+
+var nssConfig nsswitchConfig
+
+type nsswitchConfig struct {
+ initOnce sync.Once // guards init of nsswitchConfig
+
+ // ch is used as a semaphore that only allows one lookup at a
+ // time to recheck nsswitch.conf
+ ch chan struct{} // guards lastChecked and modTime
+ lastChecked time.Time // last time nsswitch.conf was checked
+
+ mu sync.Mutex // protects nssConf
+ nssConf *nssConf
+}
+
+func getSystemNSS() *nssConf {
+ nssConfig.tryUpdate()
+ nssConfig.mu.Lock()
+ conf := nssConfig.nssConf
+ nssConfig.mu.Unlock()
+ return conf
+}
+
+// init initializes conf and is only called via conf.initOnce.
+func (conf *nsswitchConfig) init() {
+ conf.nssConf = parseNSSConfFile("/etc/nsswitch.conf")
+ conf.lastChecked = time.Now()
+ conf.ch = make(chan struct{}, 1)
+}
+
+// tryUpdate tries to update conf.
+func (conf *nsswitchConfig) tryUpdate() {
+ conf.initOnce.Do(conf.init)
+
+ // Ensure only one update at a time checks nsswitch.conf
+ if !conf.tryAcquireSema() {
+ return
+ }
+ defer conf.releaseSema()
+
+ now := time.Now()
+ if conf.lastChecked.After(now.Add(-5 * time.Second)) {
+ return
+ }
+ conf.lastChecked = now
+
+ var mtime time.Time
+ if fi, err := os.Stat(nssConfigPath); err == nil {
+ mtime = fi.ModTime()
+ }
+ if mtime.Equal(conf.nssConf.mtime) {
+ return
+ }
+
+ nssConf := parseNSSConfFile(nssConfigPath)
+ conf.mu.Lock()
+ conf.nssConf = nssConf
+ conf.mu.Unlock()
+}
+
+func (conf *nsswitchConfig) acquireSema() {
+ conf.ch <- struct{}{}
+}
+
+func (conf *nsswitchConfig) tryAcquireSema() bool {
+ select {
+ case conf.ch <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (conf *nsswitchConfig) releaseSema() {
+ <-conf.ch
+}
+
+// nssConf represents the state of the machine's /etc/nsswitch.conf file.
+type nssConf struct {
+ mtime time.Time // time of nsswitch.conf modification
+ err error // any error encountered opening or parsing the file
+ sources map[string][]nssSource // keyed by database (e.g. "hosts")
+}
+
+type nssSource struct {
+ source string // e.g. "compat", "files", "mdns4_minimal"
+ criteria []nssCriterion
+}
+
+// standardCriteria reports whether all specified criteria have the
+// default status actions.
+func (s nssSource) standardCriteria() bool {
+ for i, crit := range s.criteria {
+ if !crit.standardStatusAction(i == len(s.criteria)-1) {
+ return false
+ }
+ }
+ return true
+}
+
+// nssCriterion is the parsed structure of one of the criteria in brackets
+// after an NSS source name.
+type nssCriterion struct {
+ negate bool // if "!" was present
+ status string // e.g. "success", "unavail" (lowercase)
+ action string // e.g. "return", "continue" (lowercase)
+}
+
+// standardStatusAction reports whether c is equivalent to not
+// specifying the criterion at all. last is whether this criterion is the
+// last in the list.
+func (c nssCriterion) standardStatusAction(last bool) bool {
+ if c.negate {
+ return false
+ }
+ var def string
+ switch c.status {
+ case "success":
+ def = "return"
+ case "notfound", "unavail", "tryagain":
+ def = "continue"
+ default:
+ // Unknown status
+ return false
+ }
+ if last && c.action == "return" {
+ return true
+ }
+ return c.action == def
+}
+
+func parseNSSConfFile(file string) *nssConf {
+ f, err := open(file)
+ if err != nil {
+ return &nssConf{err: err}
+ }
+ defer f.close()
+ mtime, _, err := f.stat()
+ if err != nil {
+ return &nssConf{err: err}
+ }
+
+ conf := parseNSSConf(f)
+ conf.mtime = mtime
+ return conf
+}
+
+func parseNSSConf(f *file) *nssConf {
+ conf := new(nssConf)
+ for line, ok := f.readLine(); ok; line, ok = f.readLine() {
+ line = trimSpace(removeComment(line))
+ if len(line) == 0 {
+ continue
+ }
+ colon := bytealg.IndexByteString(line, ':')
+ if colon == -1 {
+ conf.err = errors.New("no colon on line")
+ return conf
+ }
+ db := trimSpace(line[:colon])
+ srcs := line[colon+1:]
+ for {
+ srcs = trimSpace(srcs)
+ if len(srcs) == 0 {
+ break
+ }
+ sp := bytealg.IndexByteString(srcs, ' ')
+ var src string
+ if sp == -1 {
+ src = srcs
+ srcs = "" // done
+ } else {
+ src = srcs[:sp]
+ srcs = trimSpace(srcs[sp+1:])
+ }
+ var criteria []nssCriterion
+ // See if there's a criteria block in brackets.
+ if len(srcs) > 0 && srcs[0] == '[' {
+ bclose := bytealg.IndexByteString(srcs, ']')
+ if bclose == -1 {
+ conf.err = errors.New("unclosed criterion bracket")
+ return conf
+ }
+ var err error
+ criteria, err = parseCriteria(srcs[1:bclose])
+ if err != nil {
+ conf.err = errors.New("invalid criteria: " + srcs[1:bclose])
+ return conf
+ }
+ srcs = srcs[bclose+1:]
+ }
+ if conf.sources == nil {
+ conf.sources = make(map[string][]nssSource)
+ }
+ conf.sources[db] = append(conf.sources[db], nssSource{
+ source: src,
+ criteria: criteria,
+ })
+ }
+ }
+ return conf
+}
+
+// parses "foo=bar !foo=bar"
+func parseCriteria(x string) (c []nssCriterion, err error) {
+ err = foreachField(x, func(f string) error {
+ not := false
+ if len(f) > 0 && f[0] == '!' {
+ not = true
+ f = f[1:]
+ }
+ if len(f) < 3 {
+ return errors.New("criterion too short")
+ }
+ eq := bytealg.IndexByteString(f, '=')
+ if eq == -1 {
+ return errors.New("criterion lacks equal sign")
+ }
+ if hasUpperCase(f) {
+ lower := []byte(f)
+ lowerASCIIBytes(lower)
+ f = string(lower)
+ }
+ c = append(c, nssCriterion{
+ negate: not,
+ status: f[:eq],
+ action: f[eq+1:],
+ })
+ return nil
+ })
+ return
+}
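
An in-package sketch (the helper name is hypothetical) of what the bracketed action syntax parses to: the "!" sets negate, and status/action are lower-cased, matching the expectations in nss_test.go below.

    package net

    import "fmt"

    // exampleNSSCriteria is illustrative only.
    func exampleNSSCriteria() {
        crit, err := parseCriteria("!UNAVAIL=return")
        if err != nil {
            fmt.Println("parse error:", err)
            return
        }
        fmt.Printf("%+v\n", crit[0]) // {negate:true status:unavail action:return}
    }
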
diff --git a/src/net/nss_test.go b/src/net/nss_test.go
new file mode 100644
index 0000000..94e6b5f
--- /dev/null
+++ b/src/net/nss_test.go
@@ -0,0 +1,172 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package net
+
+import (
+ "reflect"
+ "testing"
+ "time"
+)
+
+const ubuntuTrustyAvahi = `# /etc/nsswitch.conf
+#
+# Example configuration of GNU Name Service Switch functionality.
+# If you have the `glibc-doc-reference' and `info' packages installed, try:
+# `info libc "Name Service Switch"' for information about this file.
+
+passwd: compat
+group: compat
+shadow: compat
+
+hosts: files mdns4_minimal [NOTFOUND=return] dns mdns4
+networks: files
+
+protocols: db files
+services: db files
+ethers: db files
+rpc: db files
+
+netgroup: nis
+`
+
+func TestParseNSSConf(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ in string
+ want *nssConf
+ }{
+ {
+ name: "no_newline",
+ in: "foo: a b",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {{source: "a"}, {source: "b"}},
+ },
+ },
+ },
+ {
+ name: "newline",
+ in: "foo: a b\n",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {{source: "a"}, {source: "b"}},
+ },
+ },
+ },
+ {
+ name: "whitespace",
+ in: " foo:a b \n",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {{source: "a"}, {source: "b"}},
+ },
+ },
+ },
+ {
+ name: "comment1",
+ in: " foo:a b#c\n",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {{source: "a"}, {source: "b"}},
+ },
+ },
+ },
+ {
+ name: "comment2",
+ in: " foo:a b #c \n",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {{source: "a"}, {source: "b"}},
+ },
+ },
+ },
+ {
+ name: "crit",
+ in: " foo:a b [!a=b X=Y ] c#d \n",
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "foo": {
+ {source: "a"},
+ {
+ source: "b",
+ criteria: []nssCriterion{
+ {
+ negate: true,
+ status: "a",
+ action: "b",
+ },
+ {
+ status: "x",
+ action: "y",
+ },
+ },
+ },
+ {source: "c"},
+ },
+ },
+ },
+ },
+
+ // Ubuntu Trusty w/ avahi-daemon, libavahi-* etc installed.
+ {
+ name: "ubuntu_trusty_avahi",
+ in: ubuntuTrustyAvahi,
+ want: &nssConf{
+ sources: map[string][]nssSource{
+ "passwd": {{source: "compat"}},
+ "group": {{source: "compat"}},
+ "shadow": {{source: "compat"}},
+ "hosts": {
+ {source: "files"},
+ {
+ source: "mdns4_minimal",
+ criteria: []nssCriterion{
+ {
+ negate: false,
+ status: "notfound",
+ action: "return",
+ },
+ },
+ },
+ {source: "dns"},
+ {source: "mdns4"},
+ },
+ "networks": {{source: "files"}},
+ "protocols": {
+ {source: "db"},
+ {source: "files"},
+ },
+ "services": {
+ {source: "db"},
+ {source: "files"},
+ },
+ "ethers": {
+ {source: "db"},
+ {source: "files"},
+ },
+ "rpc": {
+ {source: "db"},
+ {source: "files"},
+ },
+ "netgroup": {
+ {source: "nis"},
+ },
+ },
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ gotConf := nssStr(t, tt.in)
+ gotConf.mtime = time.Time{} // ignore mtime in comparison
+ if !reflect.DeepEqual(gotConf, tt.want) {
+ t.Errorf("%s: mismatch\n got %#v\nwant %#v", tt.name, gotConf, tt.want)
+ }
+ }
+}
diff --git a/src/net/packetconn_test.go b/src/net/packetconn_test.go
new file mode 100644
index 0000000..dc0c14b
--- /dev/null
+++ b/src/net/packetconn_test.go
@@ -0,0 +1,151 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "os"
+ "testing"
+)
+
+// The full stack test cases for IPConn have been moved to the
+// following:
+// golang.org/x/net/ipv4
+// golang.org/x/net/ipv6
+// golang.org/x/net/icmp
+
+func packetConnTestData(t *testing.T, network string) ([]byte, func()) {
+ if !testableNetwork(network) {
+ return nil, func() { t.Logf("skipping %s test", network) }
+ }
+ return []byte("PACKETCONN TEST"), nil
+}
+
+func TestPacketConn(t *testing.T) {
+ var packetConnTests = []struct {
+ net string
+ addr1 string
+ addr2 string
+ }{
+ {"udp", "127.0.0.1:0", "127.0.0.1:0"},
+ {"unixgram", testUnixAddr(t), testUnixAddr(t)},
+ }
+
+ closer := func(c PacketConn, net, addr1, addr2 string) {
+ c.Close()
+ switch net {
+ case "unixgram":
+ os.Remove(addr1)
+ os.Remove(addr2)
+ }
+ }
+
+ for _, tt := range packetConnTests {
+ wb, skipOrFatalFn := packetConnTestData(t, tt.net)
+ if skipOrFatalFn != nil {
+ skipOrFatalFn()
+ continue
+ }
+
+ c1, err := ListenPacket(tt.net, tt.addr1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer closer(c1, tt.net, tt.addr1, tt.addr2)
+ c1.LocalAddr()
+
+ c2, err := ListenPacket(tt.net, tt.addr2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer closer(c2, tt.net, tt.addr1, tt.addr2)
+ c2.LocalAddr()
+ rb2 := make([]byte, 128)
+
+ if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c2.ReadFrom(rb2); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c2.WriteTo(wb, c1.LocalAddr()); err != nil {
+ t.Fatal(err)
+ }
+ rb1 := make([]byte, 128)
+ if _, _, err := c1.ReadFrom(rb1); err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestConnAndPacketConn(t *testing.T) {
+ var packetConnTests = []struct {
+ net string
+ addr1 string
+ addr2 string
+ }{
+ {"udp", "127.0.0.1:0", "127.0.0.1:0"},
+ {"unixgram", testUnixAddr(t), testUnixAddr(t)},
+ }
+
+ closer := func(c PacketConn, net, addr1, addr2 string) {
+ c.Close()
+ switch net {
+ case "unixgram":
+ os.Remove(addr1)
+ os.Remove(addr2)
+ }
+ }
+
+ for _, tt := range packetConnTests {
+ var wb []byte
+ wb, skipOrFatalFn := packetConnTestData(t, tt.net)
+ if skipOrFatalFn != nil {
+ skipOrFatalFn()
+ continue
+ }
+
+ c1, err := ListenPacket(tt.net, tt.addr1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer closer(c1, tt.net, tt.addr1, tt.addr2)
+ c1.LocalAddr()
+
+ c2, err := Dial(tt.net, c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ c2.LocalAddr()
+ c2.RemoteAddr()
+
+ if _, err := c2.Write(wb); err != nil {
+ t.Fatal(err)
+ }
+ rb1 := make([]byte, 128)
+ if _, _, err := c1.ReadFrom(rb1); err != nil {
+ t.Fatal(err)
+ }
+ var dst Addr
+ switch tt.net {
+ case "unixgram":
+ continue
+ default:
+ dst = c2.LocalAddr()
+ }
+ if _, err := c1.WriteTo(wb, dst); err != nil {
+ t.Fatal(err)
+ }
+ rb2 := make([]byte, 128)
+ if _, err := c2.Read(rb2); err != nil {
+ t.Fatal(err)
+ }
+ }
+}
diff --git a/src/net/parse.go b/src/net/parse.go
new file mode 100644
index 0000000..fbc5014
--- /dev/null
+++ b/src/net/parse.go
@@ -0,0 +1,319 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Simple file i/o and string manipulation, to avoid
+// depending on strconv and bufio and strings.
+
+package net
+
+import (
+ "internal/bytealg"
+ "io"
+ "os"
+ "time"
+)
+
+type file struct {
+ file *os.File
+ data []byte
+ atEOF bool
+}
+
+func (f *file) close() { f.file.Close() }
+
+func (f *file) getLineFromData() (s string, ok bool) {
+ data := f.data
+ i := 0
+ for i = 0; i < len(data); i++ {
+ if data[i] == '\n' {
+ s = string(data[0:i])
+ ok = true
+ // move data
+ i++
+ n := len(data) - i
+ copy(data[0:], data[i:])
+ f.data = data[0:n]
+ return
+ }
+ }
+ if f.atEOF && len(f.data) > 0 {
+ // EOF, return all we have
+ s = string(data)
+ f.data = f.data[0:0]
+ ok = true
+ }
+ return
+}
+
+func (f *file) readLine() (s string, ok bool) {
+ if s, ok = f.getLineFromData(); ok {
+ return
+ }
+ if len(f.data) < cap(f.data) {
+ ln := len(f.data)
+ n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)])
+ if n >= 0 {
+ f.data = f.data[0 : ln+n]
+ }
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ f.atEOF = true
+ }
+ }
+ s, ok = f.getLineFromData()
+ return
+}
+
+func (f *file) stat() (mtime time.Time, size int64, err error) {
+ st, err := f.file.Stat()
+ if err != nil {
+ return time.Time{}, 0, err
+ }
+ return st.ModTime(), st.Size(), nil
+}
+
+func open(name string) (*file, error) {
+ fd, err := os.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ return &file{fd, make([]byte, 0, 64*1024), false}, nil
+}
+
+func stat(name string) (mtime time.Time, size int64, err error) {
+ st, err := os.Stat(name)
+ if err != nil {
+ return time.Time{}, 0, err
+ }
+ return st.ModTime(), st.Size(), nil
+}
+
+// Count occurrences in s of any bytes in t.
+func countAnyByte(s string, t string) int {
+ n := 0
+ for i := 0; i < len(s); i++ {
+ if bytealg.IndexByteString(t, s[i]) >= 0 {
+ n++
+ }
+ }
+ return n
+}
+
+// Split s at any bytes in t.
+func splitAtBytes(s string, t string) []string {
+ a := make([]string, 1+countAnyByte(s, t))
+ n := 0
+ last := 0
+ for i := 0; i < len(s); i++ {
+ if bytealg.IndexByteString(t, s[i]) >= 0 {
+ if last < i {
+ a[n] = s[last:i]
+ n++
+ }
+ last = i + 1
+ }
+ }
+ if last < len(s) {
+ a[n] = s[last:]
+ n++
+ }
+ return a[0:n]
+}
+
+func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") }
+
+// Bigger than we need, not too big to worry about overflow
+const big = 0xFFFFFF
+
+// Decimal to integer.
+// Returns number, characters consumed, success.
+func dtoi(s string) (n int, i int, ok bool) {
+ n = 0
+ for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
+ n = n*10 + int(s[i]-'0')
+ if n >= big {
+ return big, i, false
+ }
+ }
+ if i == 0 {
+ return 0, 0, false
+ }
+ return n, i, true
+}
+
+// Hexadecimal to integer.
+// Returns number, characters consumed, success.
+func xtoi(s string) (n int, i int, ok bool) {
+ n = 0
+ for i = 0; i < len(s); i++ {
+ if '0' <= s[i] && s[i] <= '9' {
+ n *= 16
+ n += int(s[i] - '0')
+ } else if 'a' <= s[i] && s[i] <= 'f' {
+ n *= 16
+ n += int(s[i]-'a') + 10
+ } else if 'A' <= s[i] && s[i] <= 'F' {
+ n *= 16
+ n += int(s[i]-'A') + 10
+ } else {
+ break
+ }
+ if n >= big {
+ return 0, i, false
+ }
+ }
+ if i == 0 {
+ return 0, i, false
+ }
+ return n, i, true
+}
+
+// xtoi2 converts the next two hex digits of s into a byte.
+// If s is longer than 2 bytes then the third byte must be e.
+// If the first two bytes of s are not hex digits or the third byte
+// does not match e, false is returned.
+func xtoi2(s string, e byte) (byte, bool) {
+ if len(s) > 2 && s[2] != e {
+ return 0, false
+ }
+ n, ei, ok := xtoi(s[:2])
+ return byte(n), ok && ei == 2
+}
+
+// Convert i to a hexadecimal string. Leading zeros are not printed.
+func appendHex(dst []byte, i uint32) []byte {
+ if i == 0 {
+ return append(dst, '0')
+ }
+ for j := 7; j >= 0; j-- {
+ v := i >> uint(j*4)
+ if v > 0 {
+ dst = append(dst, hexDigit[v&0xf])
+ }
+ }
+ return dst
+}
+
+// Number of occurrences of b in s.
+func count(s string, b byte) int {
+ n := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] == b {
+ n++
+ }
+ }
+ return n
+}
+
+// Index of rightmost occurrence of b in s.
+func last(s string, b byte) int {
+ i := len(s)
+ for i--; i >= 0; i-- {
+ if s[i] == b {
+ break
+ }
+ }
+ return i
+}
+
+// hasUpperCase reports whether the given string contains at least one upper-case letter.
+func hasUpperCase(s string) bool {
+ for i := range s {
+ if 'A' <= s[i] && s[i] <= 'Z' {
+ return true
+ }
+ }
+ return false
+}
+
+// lowerASCIIBytes makes x ASCII lowercase in-place.
+func lowerASCIIBytes(x []byte) {
+ for i, b := range x {
+ if 'A' <= b && b <= 'Z' {
+ x[i] += 'a' - 'A'
+ }
+ }
+}
+
+// lowerASCII returns the ASCII lowercase version of b.
+func lowerASCII(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+}
+
+// trimSpace returns x without any leading or trailing ASCII whitespace.
+func trimSpace(x string) string {
+ for len(x) > 0 && isSpace(x[0]) {
+ x = x[1:]
+ }
+ for len(x) > 0 && isSpace(x[len(x)-1]) {
+ x = x[:len(x)-1]
+ }
+ return x
+}
+
+// isSpace reports whether b is an ASCII space character.
+func isSpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+// removeComment returns line, removing any '#' byte and any following
+// bytes.
+func removeComment(line string) string {
+ if i := bytealg.IndexByteString(line, '#'); i != -1 {
+ return line[:i]
+ }
+ return line
+}
+
+// foreachField runs fn on each non-empty run of non-space bytes in x.
+// It returns the first non-nil error returned by fn.
+func foreachField(x string, fn func(field string) error) error {
+ x = trimSpace(x)
+ for len(x) > 0 {
+ sp := bytealg.IndexByteString(x, ' ')
+ if sp == -1 {
+ return fn(x)
+ }
+ if field := trimSpace(x[:sp]); len(field) > 0 {
+ if err := fn(field); err != nil {
+ return err
+ }
+ }
+ x = trimSpace(x[sp+1:])
+ }
+ return nil
+}
+
+// stringsHasSuffix is strings.HasSuffix. It reports whether s ends in
+// suffix.
+func stringsHasSuffix(s, suffix string) bool {
+ return len(s) >= len(suffix) && s[len(s)-len(suffix):] == suffix
+}
+
+// stringsHasSuffixFold reports whether s ends in suffix,
+// ASCII-case-insensitively.
+func stringsHasSuffixFold(s, suffix string) bool {
+ return len(s) >= len(suffix) && stringsEqualFold(s[len(s)-len(suffix):], suffix)
+}
+
+// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix.
+func stringsHasPrefix(s, prefix string) bool {
+ return len(s) >= len(prefix) && s[:len(prefix)] == prefix
+}
+
+// stringsEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func stringsEqualFold(s, t string) bool {
+ if len(s) != len(t) {
+ return false
+ }
+ for i := 0; i < len(s); i++ {
+ if lowerASCII(s[i]) != lowerASCII(t[i]) {
+ return false
+ }
+ }
+ return true
+}
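
A short in-package sketch (helper name hypothetical) of the tiny parsers in action; the commented results follow directly from the definitions above: dtoi stops at the first non-digit and reports how many bytes it consumed, xtoi2 reads exactly two hex digits followed by the expected separator, and getFields splits on ASCII whitespace.

    package net

    import "fmt"

    // exampleTinyParsers is illustrative only.
    func exampleTinyParsers() {
        fmt.Println(dtoi("80/tcp"))        // 80 2 true
        fmt.Println(xtoi2("3a:", ':'))     // 58 true
        fmt.Println(getFields("db files")) // [db files]
    }
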
diff --git a/src/net/parse_test.go b/src/net/parse_test.go
new file mode 100644
index 0000000..7c509a9
--- /dev/null
+++ b/src/net/parse_test.go
@@ -0,0 +1,74 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bufio"
+ "os"
+ "runtime"
+ "testing"
+)
+
+func TestReadLine(t *testing.T) {
+ // /etc/services file does not exist on android, plan9, windows, or wasip1
+ // where it would be required to be mounted from the host file system.
+ switch runtime.GOOS {
+ case "android", "plan9", "windows", "wasip1":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ filename := "/etc/services" // a nice big file
+
+ fd, err := os.Open(filename)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fd.Close()
+ br := bufio.NewReader(fd)
+
+ file, err := open(filename)
+ if file == nil {
+ t.Fatal(err)
+ }
+ defer file.close()
+
+ lineno := 1
+ byteno := 0
+ for {
+ bline, berr := br.ReadString('\n')
+ if n := len(bline); n > 0 {
+ bline = bline[0 : n-1]
+ }
+ line, ok := file.readLine()
+ if (berr != nil) != !ok || bline != line {
+ t.Fatalf("%s:%d (#%d)\nbufio => %q, %v\nnet => %q, %v", filename, lineno, byteno, bline, berr, line, ok)
+ }
+ if !ok {
+ break
+ }
+ lineno++
+ byteno += len(line) + 1
+ }
+}
+
+func TestDtoi(t *testing.T) {
+ for _, tt := range []struct {
+ in string
+ out int
+ off int
+ ok bool
+ }{
+ {"", 0, 0, false},
+ {"0", 0, 1, true},
+ {"65536", 65536, 5, true},
+ {"123456789", big, 8, false},
+ {"-0", 0, 0, false},
+ {"-1234", 0, 0, false},
+ } {
+ n, i, ok := dtoi(tt.in)
+ if n != tt.out || i != tt.off || ok != tt.ok {
+ t.Errorf("got %d, %d, %v; want %d, %d, %v", n, i, ok, tt.out, tt.off, tt.ok)
+ }
+ }
+}
diff --git a/src/net/pipe.go b/src/net/pipe.go
new file mode 100644
index 0000000..f174193
--- /dev/null
+++ b/src/net/pipe.go
@@ -0,0 +1,238 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "io"
+ "os"
+ "sync"
+ "time"
+)
+
+// pipeDeadline is an abstraction for handling timeouts.
+type pipeDeadline struct {
+ mu sync.Mutex // Guards timer and cancel
+ timer *time.Timer
+ cancel chan struct{} // Must be non-nil
+}
+
+func makePipeDeadline() pipeDeadline {
+ return pipeDeadline{cancel: make(chan struct{})}
+}
+
+// set sets the point in time when the deadline will time out.
+// A timeout event is signaled by closing the channel returned by wait.
+// Once a timeout has occurred, the deadline can be refreshed by specifying a
+// t value in the future.
+//
+// A zero value for t prevents timeout.
+func (d *pipeDeadline) set(t time.Time) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.timer != nil && !d.timer.Stop() {
+ <-d.cancel // Wait for the timer callback to finish and close cancel
+ }
+ d.timer = nil
+
+ // If the time is zero, there is no deadline.
+ closed := isClosedChan(d.cancel)
+ if t.IsZero() {
+ if closed {
+ d.cancel = make(chan struct{})
+ }
+ return
+ }
+
+ // Time is in the future, so set up a timer to cancel at that point.
+ if dur := time.Until(t); dur > 0 {
+ if closed {
+ d.cancel = make(chan struct{})
+ }
+ d.timer = time.AfterFunc(dur, func() {
+ close(d.cancel)
+ })
+ return
+ }
+
+ // Time in the past, so close immediately.
+ if !closed {
+ close(d.cancel)
+ }
+}
+
+// wait returns a channel that is closed when the deadline is exceeded.
+func (d *pipeDeadline) wait() chan struct{} {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.cancel
+}
+
+func isClosedChan(c <-chan struct{}) bool {
+ select {
+ case <-c:
+ return true
+ default:
+ return false
+ }
+}
+
+type pipeAddr struct{}
+
+func (pipeAddr) Network() string { return "pipe" }
+func (pipeAddr) String() string { return "pipe" }
+
+type pipe struct {
+ wrMu sync.Mutex // Serialize Write operations
+
+ // Used by local Read to interact with remote Write.
+ // Successful receive on rdRx is always followed by send on rdTx.
+ rdRx <-chan []byte
+ rdTx chan<- int
+
+ // Used by local Write to interact with remote Read.
+ // Successful send on wrTx is always followed by receive on wrRx.
+ wrTx chan<- []byte
+ wrRx <-chan int
+
+ once sync.Once // Protects closing localDone
+ localDone chan struct{}
+ remoteDone <-chan struct{}
+
+ readDeadline pipeDeadline
+ writeDeadline pipeDeadline
+}
+
+// Pipe creates a synchronous, in-memory, full duplex
+// network connection; both ends implement the Conn interface.
+// Reads on one end are matched with writes on the other,
+// copying data directly between the two; there is no internal
+// buffering.
+func Pipe() (Conn, Conn) {
+ cb1 := make(chan []byte)
+ cb2 := make(chan []byte)
+ cn1 := make(chan int)
+ cn2 := make(chan int)
+ done1 := make(chan struct{})
+ done2 := make(chan struct{})
+
+ p1 := &pipe{
+ rdRx: cb1, rdTx: cn1,
+ wrTx: cb2, wrRx: cn2,
+ localDone: done1, remoteDone: done2,
+ readDeadline: makePipeDeadline(),
+ writeDeadline: makePipeDeadline(),
+ }
+ p2 := &pipe{
+ rdRx: cb2, rdTx: cn2,
+ wrTx: cb1, wrRx: cn1,
+ localDone: done2, remoteDone: done1,
+ readDeadline: makePipeDeadline(),
+ writeDeadline: makePipeDeadline(),
+ }
+ return p1, p2
+}
+
+func (*pipe) LocalAddr() Addr { return pipeAddr{} }
+func (*pipe) RemoteAddr() Addr { return pipeAddr{} }
+
+func (p *pipe) Read(b []byte) (int, error) {
+ n, err := p.read(b)
+ if err != nil && err != io.EOF && err != io.ErrClosedPipe {
+ err = &OpError{Op: "read", Net: "pipe", Err: err}
+ }
+ return n, err
+}
+
+func (p *pipe) read(b []byte) (n int, err error) {
+ switch {
+ case isClosedChan(p.localDone):
+ return 0, io.ErrClosedPipe
+ case isClosedChan(p.remoteDone):
+ return 0, io.EOF
+ case isClosedChan(p.readDeadline.wait()):
+ return 0, os.ErrDeadlineExceeded
+ }
+
+ select {
+ case bw := <-p.rdRx:
+ nr := copy(b, bw)
+ p.rdTx <- nr
+ return nr, nil
+ case <-p.localDone:
+ return 0, io.ErrClosedPipe
+ case <-p.remoteDone:
+ return 0, io.EOF
+ case <-p.readDeadline.wait():
+ return 0, os.ErrDeadlineExceeded
+ }
+}
+
+func (p *pipe) Write(b []byte) (int, error) {
+ n, err := p.write(b)
+ if err != nil && err != io.ErrClosedPipe {
+ err = &OpError{Op: "write", Net: "pipe", Err: err}
+ }
+ return n, err
+}
+
+func (p *pipe) write(b []byte) (n int, err error) {
+ switch {
+ case isClosedChan(p.localDone):
+ return 0, io.ErrClosedPipe
+ case isClosedChan(p.remoteDone):
+ return 0, io.ErrClosedPipe
+ case isClosedChan(p.writeDeadline.wait()):
+ return 0, os.ErrDeadlineExceeded
+ }
+
+ p.wrMu.Lock() // Ensure entirety of b is written together
+ defer p.wrMu.Unlock()
+ for once := true; once || len(b) > 0; once = false {
+ select {
+ case p.wrTx <- b:
+ nw := <-p.wrRx
+ b = b[nw:]
+ n += nw
+ case <-p.localDone:
+ return n, io.ErrClosedPipe
+ case <-p.remoteDone:
+ return n, io.ErrClosedPipe
+ case <-p.writeDeadline.wait():
+ return n, os.ErrDeadlineExceeded
+ }
+ }
+ return n, nil
+}
+
+func (p *pipe) SetDeadline(t time.Time) error {
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return io.ErrClosedPipe
+ }
+ p.readDeadline.set(t)
+ p.writeDeadline.set(t)
+ return nil
+}
+
+func (p *pipe) SetReadDeadline(t time.Time) error {
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return io.ErrClosedPipe
+ }
+ p.readDeadline.set(t)
+ return nil
+}
+
+func (p *pipe) SetWriteDeadline(t time.Time) error {
+ if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) {
+ return io.ErrClosedPipe
+ }
+ p.writeDeadline.set(t)
+ return nil
+}
+
+func (p *pipe) Close() error {
+ p.once.Do(func() { close(p.localDone) })
+ return nil
+}
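
A minimal, self-contained sketch of the exported Pipe API above: writes block until the peer reads, so the reader runs in its own goroutine, and an expired deadline surfaces as os.ErrDeadlineExceeded wrapped in *OpError.

    package main

    import (
        "fmt"
        "net"
        "time"
    )

    func main() {
        c1, c2 := net.Pipe()
        done := make(chan struct{})
        go func() {
            defer close(done)
            buf := make([]byte, 16)
            n, err := c2.Read(buf)
            fmt.Printf("peer read %q, err=%v\n", buf[:n], err)
        }()

        c1.SetWriteDeadline(time.Now().Add(time.Second))
        if _, err := c1.Write([]byte("ping")); err != nil {
            fmt.Println("write error:", err)
        }
        c1.Close()
        <-done
    }
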
diff --git a/src/net/pipe_test.go b/src/net/pipe_test.go
new file mode 100644
index 0000000..9cc2414
--- /dev/null
+++ b/src/net/pipe_test.go
@@ -0,0 +1,49 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net_test
+
+import (
+ "io"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/net/nettest"
+)
+
+func TestPipe(t *testing.T) {
+ nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
+ c1, c2 = net.Pipe()
+ stop = func() {
+ c1.Close()
+ c2.Close()
+ }
+ return
+ })
+}
+
+func TestPipeCloseError(t *testing.T) {
+ c1, c2 := net.Pipe()
+ c1.Close()
+
+ if _, err := c1.Read(nil); err != io.ErrClosedPipe {
+ t.Errorf("c1.Read() = %v, want io.ErrClosedPipe", err)
+ }
+ if _, err := c1.Write(nil); err != io.ErrClosedPipe {
+ t.Errorf("c1.Write() = %v, want io.ErrClosedPipe", err)
+ }
+ if err := c1.SetDeadline(time.Time{}); err != io.ErrClosedPipe {
+ t.Errorf("c1.SetDeadline() = %v, want io.ErrClosedPipe", err)
+ }
+ if _, err := c2.Read(nil); err != io.EOF {
+ t.Errorf("c2.Read() = %v, want io.EOF", err)
+ }
+ if _, err := c2.Write(nil); err != io.ErrClosedPipe {
+ t.Errorf("c2.Write() = %v, want io.ErrClosedPipe", err)
+ }
+ if err := c2.SetDeadline(time.Time{}); err != io.ErrClosedPipe {
+ t.Errorf("c2.SetDeadline() = %v, want io.ErrClosedPipe", err)
+ }
+}
diff --git a/src/net/platform_test.go b/src/net/platform_test.go
new file mode 100644
index 0000000..71e9082
--- /dev/null
+++ b/src/net/platform_test.go
@@ -0,0 +1,178 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/testenv"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+var unixEnabledOnAIX bool
+
+func init() {
+ if runtime.GOOS == "aix" {
+ // Unix network isn't properly working on AIX 7.2 with
+ // Technical Level < 2.
+ // The information is retrieved only once in this init()
+ // instead of every time testableNetwork is called.
+ out, _ := exec.Command("oslevel", "-s").Output()
+ if len(out) >= len("7200-XX-ZZ-YYMM") { // AIX 7.2, Tech Level XX, Service Pack ZZ, date YYMM
+ aixVer := string(out[:4])
+ tl, _ := strconv.Atoi(string(out[5:7]))
+ unixEnabledOnAIX = aixVer > "7200" || (aixVer == "7200" && tl >= 2)
+ }
+ }
+}
+
+// testableNetwork reports whether network is testable on the current
+// platform configuration.
+func testableNetwork(network string) bool {
+ net, _, _ := strings.Cut(network, ":")
+ switch net {
+ case "ip+nopriv":
+ case "ip", "ip4", "ip6":
+ switch runtime.GOOS {
+ case "plan9":
+ return false
+ default:
+ if os.Getuid() != 0 {
+ return false
+ }
+ }
+ case "unix", "unixgram":
+ switch runtime.GOOS {
+ case "android", "ios", "plan9", "windows":
+ return false
+ case "aix":
+ return unixEnabledOnAIX
+ }
+ case "unixpacket":
+ switch runtime.GOOS {
+ case "aix", "android", "darwin", "ios", "plan9", "windows":
+ return false
+ }
+ }
+ switch net {
+ case "tcp4", "udp4", "ip4":
+ if !supportsIPv4() {
+ return false
+ }
+ case "tcp6", "udp6", "ip6":
+ if !supportsIPv6() {
+ return false
+ }
+ }
+ return true
+}
+
+// testableAddress reports whether address of network is testable on
+// the current platform configuration.
+func testableAddress(network, address string) bool {
+ switch net, _, _ := strings.Cut(network, ":"); net {
+ case "unix", "unixgram", "unixpacket":
+ // Abstract unix domain sockets, a Linux-ism.
+ if address[0] == '@' && runtime.GOOS != "linux" {
+ return false
+ }
+ }
+ return true
+}
+
+// testableListenArgs reports whether arguments are testable on the
+// current platform configuration.
+func testableListenArgs(network, address, client string) bool {
+ if !testableNetwork(network) || !testableAddress(network, address) {
+ return false
+ }
+
+ var err error
+ var addr Addr
+ switch net, _, _ := strings.Cut(network, ":"); net {
+ case "tcp", "tcp4", "tcp6":
+ addr, err = ResolveTCPAddr("tcp", address)
+ case "udp", "udp4", "udp6":
+ addr, err = ResolveUDPAddr("udp", address)
+ case "ip", "ip4", "ip6":
+ addr, err = ResolveIPAddr("ip", address)
+ default:
+ return true
+ }
+ if err != nil {
+ return false
+ }
+ var ip IP
+ var wildcard bool
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ ip = addr.IP
+ wildcard = addr.isWildcard()
+ case *UDPAddr:
+ ip = addr.IP
+ wildcard = addr.isWildcard()
+ case *IPAddr:
+ ip = addr.IP
+ wildcard = addr.isWildcard()
+ }
+
+ // Test wildcard IP addresses.
+ if wildcard && !testenv.HasExternalNetwork() {
+ return false
+ }
+
+ // Test functionality of IPv4 communication using AF_INET and
+ // IPv6 communication using AF_INET6 sockets.
+ if !supportsIPv4() && ip.To4() != nil {
+ return false
+ }
+ if !supportsIPv6() && ip.To16() != nil && ip.To4() == nil {
+ return false
+ }
+ cip := ParseIP(client)
+ if cip != nil {
+ if !supportsIPv4() && cip.To4() != nil {
+ return false
+ }
+ if !supportsIPv6() && cip.To16() != nil && cip.To4() == nil {
+ return false
+ }
+ }
+
+ // Test functionality of IPv4 communication using AF_INET6
+ // sockets.
+ if !supportsIPv4map() && supportsIPv4() && (network == "tcp" || network == "udp" || network == "ip") && wildcard {
+ // At this point, we prefer IPv4 when ip is nil.
+ // See favoriteAddrFamily for further information.
+ if ip.To16() != nil && ip.To4() == nil && cip.To4() != nil { // a pair of IPv6 server and IPv4 client
+ return false
+ }
+ if (ip.To4() != nil || ip == nil) && cip.To16() != nil && cip.To4() == nil { // a pair of IPv4 server and IPv6 client
+ return false
+ }
+ }
+
+ return true
+}
+
+func condFatalf(t *testing.T, network string, format string, args ...any) {
+ t.Helper()
+ // A few APIs like File and Read/WriteMsg{UDP,IP} are not
+ // fully implemented yet on Plan 9 and Windows.
+ switch runtime.GOOS {
+ case "windows":
+ if network == "file+net" {
+ t.Logf(format, args...)
+ return
+ }
+ case "plan9":
+ t.Logf(format, args...)
+ return
+ }
+ t.Fatalf(format, args...)
+}
diff --git a/src/net/port.go b/src/net/port.go
new file mode 100644
index 0000000..32e7628
--- /dev/null
+++ b/src/net/port.go
@@ -0,0 +1,62 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+// parsePort parses service as a decimal integer and returns the
+// corresponding value as port. It is the caller's responsibility to
+// parse service as a non-decimal integer when needsLookup is true.
+//
+// Some system resolvers will return a valid port number when given a number
+// over 65536 (see https://golang.org/issues/11715). Alas, the parser
+// can't bail early on numbers > 65536. Therefore reasonably large/small
+// numbers are parsed in full and rejected if invalid.
+func parsePort(service string) (port int, needsLookup bool) {
+ if service == "" {
+ // Lock in the legacy behavior that an empty string
+ // means port 0. See golang.org/issue/13610.
+ return 0, false
+ }
+ const (
+ max = uint32(1<<32 - 1)
+ cutoff = uint32(1 << 30)
+ )
+ neg := false
+ if service[0] == '+' {
+ service = service[1:]
+ } else if service[0] == '-' {
+ neg = true
+ service = service[1:]
+ }
+ var n uint32
+ for _, d := range service {
+ if '0' <= d && d <= '9' {
+ d -= '0'
+ } else {
+ return 0, true
+ }
+ if n >= cutoff {
+ n = max
+ break
+ }
+ n *= 10
+ nn := n + uint32(d)
+ if nn < n || nn > max {
+ n = max
+ break
+ }
+ n = nn
+ }
+ if !neg && n >= cutoff {
+ port = int(cutoff - 1)
+ } else if neg && n > cutoff {
+ port = int(cutoff)
+ } else {
+ port = int(n)
+ }
+ if neg {
+ port = -port
+ }
+ return port, false
+}
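
An in-package sketch (helper name hypothetical) of parsePort's clamping behavior, matching the table in port_test.go below: out-of-range numbers are clamped rather than rejected, and non-numeric services are deferred to a lookup.

    package net

    import "fmt"

    // exampleParsePort is illustrative only.
    func exampleParsePort() {
        fmt.Println(parsePort(""))           // 0 false
        fmt.Println(parsePort("65536"))      // 65536 false (validated later)
        fmt.Println(parsePort("1073741825")) // 1073741823 false (clamped)
        fmt.Println(parsePort("https"))      // 0 true (needs lookup)
    }
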
diff --git a/src/net/port_test.go b/src/net/port_test.go
new file mode 100644
index 0000000..e0bdb42
--- /dev/null
+++ b/src/net/port_test.go
@@ -0,0 +1,52 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "testing"
+
+var parsePortTests = []struct {
+ service string
+ port int
+ needsLookup bool
+}{
+ {"", 0, false},
+
+ // Decimal number literals
+ {"-1073741825", -1 << 30, false},
+ {"-1073741824", -1 << 30, false},
+ {"-1073741823", -(1<<30 - 1), false},
+ {"-123456789", -123456789, false},
+ {"-1", -1, false},
+ {"-0", 0, false},
+ {"0", 0, false},
+ {"+0", 0, false},
+ {"+1", 1, false},
+ {"65535", 65535, false},
+ {"65536", 65536, false},
+ {"123456789", 123456789, false},
+ {"1073741822", 1<<30 - 2, false},
+ {"1073741823", 1<<30 - 1, false},
+ {"1073741824", 1<<30 - 1, false},
+ {"1073741825", 1<<30 - 1, false},
+
+ // Others
+ {"abc", 0, true},
+ {"9pfs", 0, true},
+ {"123badport", 0, true},
+ {"bad123port", 0, true},
+ {"badport123", 0, true},
+ {"123456789badport", 0, true},
+ {"-2147483649badport", 0, true},
+ {"2147483649badport", 0, true},
+}
+
+func TestParsePort(t *testing.T) {
+ // The following test cases are cribbed from the strconv package.
+ for _, tt := range parsePortTests {
+ if port, needsLookup := parsePort(tt.service); port != tt.port || needsLookup != tt.needsLookup {
+ t.Errorf("parsePort(%q) = %d, %t; want %d, %t", tt.service, port, needsLookup, tt.port, tt.needsLookup)
+ }
+ }
+}
diff --git a/src/net/port_unix.go b/src/net/port_unix.go
new file mode 100644
index 0000000..0b2ea3e
--- /dev/null
+++ b/src/net/port_unix.go
@@ -0,0 +1,57 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1
+
+// Read system port mappings from /etc/services
+
+package net
+
+import (
+ "internal/bytealg"
+ "sync"
+)
+
+var onceReadServices sync.Once
+
+func readServices() {
+ file, err := open("/etc/services")
+ if err != nil {
+ return
+ }
+ defer file.close()
+
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ // "http 80/tcp www www-http # World Wide Web HTTP"
+ if i := bytealg.IndexByteString(line, '#'); i >= 0 {
+ line = line[:i]
+ }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ portnet := f[1] // "80/tcp"
+ port, j, ok := dtoi(portnet)
+ if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' {
+ continue
+ }
+ netw := portnet[j+1:] // "tcp"
+ m, ok1 := services[netw]
+ if !ok1 {
+ m = make(map[string]int)
+ services[netw] = m
+ }
+ for i := 0; i < len(f); i++ {
+ if i != 1 { // f[1] was port/net
+ m[f[i]] = port
+ }
+ }
+ }
+}
+
+// goLookupPort is the native Go implementation of LookupPort.
+func goLookupPort(network, service string) (port int, err error) {
+ onceReadServices.Do(readServices)
+ return lookupPortMap(network, service)
+}
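
For callers, the exported entry point is LookupPort; whether it reaches goLookupPort depends on resolver selection, which is an assumption in this sketch (the cgo resolver may answer instead).

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        port, err := net.LookupPort("tcp", "http")
        fmt.Println(port, err) // typically 80 <nil> on systems with /etc/services
    }
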
diff --git a/src/net/protoconn_test.go b/src/net/protoconn_test.go
new file mode 100644
index 0000000..c566807
--- /dev/null
+++ b/src/net/protoconn_test.go
@@ -0,0 +1,350 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements API tests across platforms and will never have a build
+// tag.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "internal/testenv"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+)
+
+// The full stack test cases for IPConn have been moved to the
+// following:
+// golang.org/x/net/ipv4
+// golang.org/x/net/ipv6
+// golang.org/x/net/icmp
+
+func TestTCPListenerSpecificMethods(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln, err := ListenTCP("tcp4", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ ln.Addr()
+ ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+
+ if c, err := ln.Accept(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatal(err)
+ }
+ } else {
+ c.Close()
+ }
+ if c, err := ln.AcceptTCP(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatal(err)
+ }
+ } else {
+ c.Close()
+ }
+
+ if f, err := ln.File(); err != nil {
+ condFatalf(t, "file+net", "%v", err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestTCPConnSpecificMethods(t *testing.T) {
+ la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln, err := ListenTCP("tcp4", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ch := make(chan error, 1)
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ls.Listener, ch) }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ ra, err := ResolveTCPAddr("tcp4", ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := DialTCP("tcp4", nil, ra)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ c.SetKeepAlive(false)
+ c.SetKeepAlivePeriod(3 * time.Second)
+ c.SetLinger(0)
+ c.SetNoDelay(false)
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+ if _, err := c.Write([]byte("TCPCONN TEST")); err != nil {
+ t.Fatal(err)
+ }
+ rb := make([]byte, 128)
+ if _, err := c.Read(rb); err != nil {
+ t.Fatal(err)
+ }
+
+ for err := range ch {
+ t.Error(err)
+ }
+}
+
+func TestUDPConnSpecificMethods(t *testing.T) {
+ la, err := ResolveUDPAddr("udp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenUDP("udp4", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+ c.SetReadBuffer(2048)
+ c.SetWriteBuffer(2048)
+
+ wb := []byte("UDPCONN TEST")
+ rb := make([]byte, 128)
+ if _, err := c.WriteToUDP(wb, c.LocalAddr().(*UDPAddr)); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c.ReadFromUDP(rb); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
+ condFatalf(t, c.LocalAddr().Network(), "%v", err)
+ }
+ if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
+ condFatalf(t, c.LocalAddr().Network(), "%v", err)
+ }
+
+ if f, err := c.File(); err != nil {
+ condFatalf(t, "file+net", "%v", err)
+ } else {
+ f.Close()
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ t.Fatalf("panicked: %v", p)
+ }
+ }()
+
+ c.WriteToUDP(wb, nil)
+ c.WriteMsgUDP(wb, nil, nil)
+}
+
+func TestIPConnSpecificMethods(t *testing.T) {
+ la, err := ResolveIPAddr("ip4", "127.0.0.1")
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenIP("ip4:icmp", la)
+ if testenv.SyscallIsNotSupported(err) {
+ // May be inside a container that disallows creating a socket or
+ // not running as root.
+ t.Skipf("skipping: %v", err)
+ } else if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ c.LocalAddr()
+ c.RemoteAddr()
+ c.SetDeadline(time.Now().Add(someTimeout))
+ c.SetReadDeadline(time.Now().Add(someTimeout))
+ c.SetWriteDeadline(time.Now().Add(someTimeout))
+ c.SetReadBuffer(2048)
+ c.SetWriteBuffer(2048)
+
+ if f, err := c.File(); err != nil {
+ condFatalf(t, "file+net", "%v", err)
+ } else {
+ f.Close()
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ t.Fatalf("panicked: %v", p)
+ }
+ }()
+
+ wb := []byte("IPCONN TEST")
+ c.WriteToIP(wb, nil)
+ c.WriteMsgIP(wb, nil, nil)
+}
+
+func TestUnixListenerSpecificMethods(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("unix test")
+ }
+
+ addr := testUnixAddr(t)
+ la, err := ResolveUnixAddr("unix", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln, err := ListenUnix("unix", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ defer os.Remove(addr)
+ ln.Addr()
+ ln.SetDeadline(time.Now().Add(30 * time.Nanosecond))
+
+ if c, err := ln.Accept(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatal(err)
+ }
+ } else {
+ c.Close()
+ }
+ if c, err := ln.AcceptUnix(); err != nil {
+ if !err.(Error).Timeout() {
+ t.Fatal(err)
+ }
+ } else {
+ c.Close()
+ }
+
+ if f, err := ln.File(); err != nil {
+ t.Fatal(err)
+ } else {
+ f.Close()
+ }
+}
+
+func TestUnixConnSpecificMethods(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+
+ addr1, addr2, addr3 := testUnixAddr(t), testUnixAddr(t), testUnixAddr(t)
+
+ a1, err := ResolveUnixAddr("unixgram", addr1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1, err := DialUnix("unixgram", a1, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+ defer os.Remove(addr1)
+ c1.LocalAddr()
+ c1.RemoteAddr()
+ c1.SetDeadline(time.Now().Add(someTimeout))
+ c1.SetReadDeadline(time.Now().Add(someTimeout))
+ c1.SetWriteDeadline(time.Now().Add(someTimeout))
+ c1.SetReadBuffer(2048)
+ c1.SetWriteBuffer(2048)
+
+ a2, err := ResolveUnixAddr("unixgram", addr2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2, err := DialUnix("unixgram", a2, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ defer os.Remove(addr2)
+ c2.LocalAddr()
+ c2.RemoteAddr()
+ c2.SetDeadline(time.Now().Add(someTimeout))
+ c2.SetReadDeadline(time.Now().Add(someTimeout))
+ c2.SetWriteDeadline(time.Now().Add(someTimeout))
+ c2.SetReadBuffer(2048)
+ c2.SetWriteBuffer(2048)
+
+ a3, err := ResolveUnixAddr("unixgram", addr3)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c3, err := ListenUnixgram("unixgram", a3)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c3.Close()
+ defer os.Remove(addr3)
+ c3.LocalAddr()
+ c3.RemoteAddr()
+ c3.SetDeadline(time.Now().Add(someTimeout))
+ c3.SetReadDeadline(time.Now().Add(someTimeout))
+ c3.SetWriteDeadline(time.Now().Add(someTimeout))
+ c3.SetReadBuffer(2048)
+ c3.SetWriteBuffer(2048)
+
+ wb := []byte("UNIXCONN TEST")
+ rb1 := make([]byte, 128)
+ rb2 := make([]byte, 128)
+ rb3 := make([]byte, 128)
+ if _, _, err := c1.WriteMsgUnix(wb, nil, a2); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, _, _, err := c2.ReadMsgUnix(rb2, nil); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c2.WriteToUnix(wb, a1); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c1.ReadFromUnix(rb1); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c3.WriteToUnix(wb, a1); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c1.ReadFromUnix(rb1); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c2.WriteToUnix(wb, a3); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c3.ReadFromUnix(rb3); err != nil {
+ t.Fatal(err)
+ }
+
+ if f, err := c1.File(); err != nil {
+ t.Fatal(err)
+ } else {
+ f.Close()
+ }
+
+ defer func() {
+ if p := recover(); p != nil {
+ t.Fatalf("panicked: %v", p)
+ }
+ }()
+
+ c1.WriteToUnix(wb, nil)
+ c1.WriteMsgUnix(wb, nil, nil)
+ c3.WriteToUnix(wb, nil)
+ c3.WriteMsgUnix(wb, nil, nil)
+}
diff --git a/src/net/rawconn.go b/src/net/rawconn.go
new file mode 100644
index 0000000..974320c
--- /dev/null
+++ b/src/net/rawconn.go
@@ -0,0 +1,96 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "runtime"
+ "syscall"
+)
+
+// BUG(tmm1): On Windows, the Write method of syscall.RawConn
+// does not integrate with the runtime's network poller. It cannot
+// wait for the connection to become writeable, and does not respect
+// deadlines. If the user-provided callback returns false, the Write
+// method will fail immediately.
+
+// BUG(mikio): On JS and Plan 9, the Control, Read and Write
+// methods of syscall.RawConn are not implemented.
+
+type rawConn struct {
+ fd *netFD
+}
+
+func (c *rawConn) ok() bool { return c != nil && c.fd != nil }
+
+func (c *rawConn) Control(f func(uintptr)) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ err := c.fd.pfd.RawControl(f)
+ runtime.KeepAlive(c.fd)
+ if err != nil {
+ err = &OpError{Op: "raw-control", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
+ }
+ return err
+}
+
+func (c *rawConn) Read(f func(uintptr) bool) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ err := c.fd.pfd.RawRead(f)
+ runtime.KeepAlive(c.fd)
+ if err != nil {
+ err = &OpError{Op: "raw-read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return err
+}
+
+func (c *rawConn) Write(f func(uintptr) bool) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ err := c.fd.pfd.RawWrite(f)
+ runtime.KeepAlive(c.fd)
+ if err != nil {
+ err = &OpError{Op: "raw-write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return err
+}
+
+// PollFD returns the poll.FD of the underlying connection.
+//
+// Other packages in std that also import internal/poll (such as os)
+// can use a type assertion to access this extension method so that
+// they can pass the *poll.FD to functions like poll.Splice.
+//
+// PollFD is not intended for use outside the standard library.
+func (c *rawConn) PollFD() *poll.FD {
+ if !c.ok() {
+ return nil
+ }
+ return &c.fd.pfd
+}
+
+func newRawConn(fd *netFD) (*rawConn, error) {
+ return &rawConn{fd: fd}, nil
+}
+
+type rawListener struct {
+ rawConn
+}
+
+func (l *rawListener) Read(func(uintptr) bool) error {
+ return syscall.EINVAL
+}
+
+func (l *rawListener) Write(func(uintptr) bool) error {
+ return syscall.EINVAL
+}
+
+func newRawListener(fd *netFD) (*rawListener, error) {
+ return &rawListener{rawConn{fd: fd}}, nil
+}
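
A hedged, Unix-flavored sketch of how the rawConn above is reached from user code via SyscallConn; it mirrors what the test helpers later in this patch do and is not part of the change itself.

    package main

    import (
        "fmt"
        "net"
        "syscall"
    )

    func main() {
        ln, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            fmt.Println(err)
            return
        }
        defer ln.Close()

        rc, err := ln.(*net.TCPListener).SyscallConn()
        if err != nil {
            fmt.Println(err)
            return
        }
        var reuse int
        var operr error
        // Control borrows the fd without taking ownership of it.
        if err := rc.Control(func(fd uintptr) {
            reuse, operr = syscall.GetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
        }); err != nil {
            fmt.Println(err)
            return
        }
        fmt.Println("SO_REUSEADDR =", reuse, operr)
    }
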
diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go
new file mode 100644
index 0000000..c8ad80c
--- /dev/null
+++ b/src/net/rawconn_stub_test.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || plan9 || wasip1
+
+package net
+
+import (
+ "errors"
+ "syscall"
+)
+
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+ return 0, errors.New("not supported")
+}
+
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ return errors.New("not supported")
+}
+
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+ return errors.New("not supported")
+}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ return nil
+}
diff --git a/src/net/rawconn_test.go b/src/net/rawconn_test.go
new file mode 100644
index 0000000..06d5856
--- /dev/null
+++ b/src/net/rawconn_test.go
@@ -0,0 +1,211 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "bytes"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func TestRawConnReadWrite(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("TCP", func(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+
+ cc, err := ln.(*TCPListener).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ called := false
+ op := func(uintptr) bool {
+ called = true
+ return true
+ }
+ err = cc.Write(op)
+ if err == nil {
+ t.Error("Write should return an error")
+ }
+ if called {
+ t.Error("Write shouldn't call op")
+ }
+ called = false
+ err = cc.Read(op)
+ if err == nil {
+ t.Error("Read should return an error")
+ }
+ if called {
+ t.Error("Read shouldn't call op")
+ }
+
+ var b [32]byte
+ n, err := c.Read(b[:])
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := c.Write(b[:n]); err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ data := []byte("HELLO-R-U-THERE")
+ if err := writeRawConn(cc, data); err != nil {
+ t.Fatal(err)
+ }
+ var b [32]byte
+ n, err := readRawConn(cc, b[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Compare(b[:n], data) != 0 {
+ t.Fatalf("got %q; want %q", b[:n], data)
+ }
+ })
+ t.Run("Deadline", func(t *testing.T) {
+ switch runtime.GOOS {
+ case "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ var b [1]byte
+
+ c.SetDeadline(noDeadline)
+ if err := c.SetDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if err = writeRawConn(cc, b[:]); err == nil {
+ t.Fatal("Write should fail")
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("got %v; want timeout", err)
+ }
+ if _, err = readRawConn(cc, b[:]); err == nil {
+ t.Fatal("Read should fail")
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("got %v; want timeout", err)
+ }
+
+ c.SetReadDeadline(noDeadline)
+ if err := c.SetReadDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if _, err = readRawConn(cc, b[:]); err == nil {
+ t.Fatal("Read should fail")
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("got %v; want timeout", err)
+ }
+
+ c.SetWriteDeadline(noDeadline)
+ if err := c.SetWriteDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if err = writeRawConn(cc, b[:]); err == nil {
+ t.Fatal("Write should fail")
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("got %v; want timeout", err)
+ }
+ })
+}
+
+func TestRawConnControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("TCP", func(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ cc1, err := ln.(*TCPListener).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := controlRawConn(cc1, ln.Addr()); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc2, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := controlRawConn(cc2, c.LocalAddr()); err != nil {
+ t.Fatal(err)
+ }
+
+ ln.Close()
+ if err := controlRawConn(cc1, ln.Addr()); err == nil {
+ t.Fatal("Control after Close should fail")
+ }
+ c.Close()
+ if err := controlRawConn(cc2, c.LocalAddr()); err == nil {
+ t.Fatal("Control after Close should fail")
+ }
+ })
+}
diff --git a/src/net/rawconn_unix_test.go b/src/net/rawconn_unix_test.go
new file mode 100644
index 0000000..f11119e
--- /dev/null
+++ b/src/net/rawconn_unix_test.go
@@ -0,0 +1,115 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "errors"
+ "syscall"
+)
+
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+ var operr error
+ var n int
+ err := c.Read(func(s uintptr) bool {
+ n, operr = syscall.Read(int(s), b)
+ if operr == syscall.EAGAIN {
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return n, err
+ }
+ return n, operr
+}
+
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ var operr error
+ err := c.Write(func(s uintptr) bool {
+ _, operr = syscall.Write(int(s), b)
+ if operr == syscall.EAGAIN {
+ return false
+ }
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ return operr
+}
+
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+ var operr error
+ fn := func(s uintptr) {
+ _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
+ if operr != nil {
+ return
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ // There's no guarantee that IP-level socket
+ // options work well with dual stack sockets.
+		// A simple solution is to look at the address bound
+		// to the raw connection and to classify the address
+		// family of the underlying socket by that bound
+		// address:
+ //
+ // - When IP.To16() != nil and IP.To4() == nil,
+ // we can assume that the raw connection
+ // consists of an IPv6 socket using only
+ // IPv6 addresses.
+ //
+ // - When IP.To16() == nil and IP.To4() != nil,
+ // the raw connection consists of an IPv4
+ // socket using only IPv4 addresses.
+ //
+ // - Otherwise, the raw connection is a dual
+ // stack socket, an IPv6 socket using IPv6
+ // addresses including IPv4-mapped or
+ // IPv4-embedded IPv6 addresses.
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ return operr
+}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ var operr error
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ case "unix", "unixpacket", "unixgram":
+ fn = func(s uintptr) {
+ _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR)
+ }
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ return operr
+}
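
For readers following the helpers above, here is a minimal standalone sketch of the syscall.RawConn pattern they wrap: obtaining a raw connection from an exported net type and reading a socket option inside Control. It assumes a Unix-like system, uses a loopback listener purely for illustration, and is not part of the change itself.

	package main

	import (
		"fmt"
		"log"
		"net"
		"syscall"
	)

	func main() {
		// Listen on an ephemeral loopback port and dial it so we have a real *TCPConn.
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			log.Fatal(err)
		}
		defer ln.Close()

		c, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			log.Fatal(err)
		}
		defer c.Close()

		// SyscallConn exposes the underlying socket without transferring ownership.
		raw, err := c.(*net.TCPConn).SyscallConn()
		if err != nil {
			log.Fatal(err)
		}

		var operr error
		var ttl int
		// Control runs the callback while the descriptor is guaranteed valid.
		if err := raw.Control(func(fd uintptr) {
			ttl, operr = syscall.GetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL)
		}); err != nil {
			log.Fatal(err)
		}
		if operr != nil {
			log.Fatal(operr)
		}
		fmt.Println("IP_TTL:", ttl)
	}
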
diff --git a/src/net/rawconn_windows_test.go b/src/net/rawconn_windows_test.go
new file mode 100644
index 0000000..5febf08
--- /dev/null
+++ b/src/net/rawconn_windows_test.go
@@ -0,0 +1,116 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "syscall"
+ "unsafe"
+)
+
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+ var operr error
+ var n int
+ err := c.Read(func(s uintptr) bool {
+ var read uint32
+ var flags uint32
+ var buf syscall.WSABuf
+ buf.Buf = &b[0]
+ buf.Len = uint32(len(b))
+ operr = syscall.WSARecv(syscall.Handle(s), &buf, 1, &read, &flags, nil, nil)
+ n = int(read)
+ return true
+ })
+ if err != nil {
+ return n, err
+ }
+ return n, operr
+}
+
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ var operr error
+ err := c.Write(func(s uintptr) bool {
+ var written uint32
+ var buf syscall.WSABuf
+ buf.Buf = &b[0]
+ buf.Len = uint32(len(b))
+ operr = syscall.WSASend(syscall.Handle(s), &buf, 1, &written, 0, nil, nil)
+ return true
+ })
+ if err != nil {
+ return err
+ }
+ return operr
+}
+
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+ var operr error
+ fn := func(s uintptr) {
+ var v, l int32
+ l = int32(unsafe.Sizeof(v))
+ operr = syscall.Getsockopt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, (*byte)(unsafe.Pointer(&v)), &l)
+ if operr != nil {
+ return
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ // There's no guarantee that IP-level socket
+ // options work well with dual stack sockets.
+		// A simple solution is to look at the address bound
+		// to the raw connection and to classify the address
+		// family of the underlying socket by that bound
+		// address:
+ //
+ // - When IP.To16() != nil and IP.To4() == nil,
+ // we can assume that the raw connection
+ // consists of an IPv6 socket using only
+ // IPv6 addresses.
+ //
+ // - When IP.To16() == nil and IP.To4() != nil,
+ // the raw connection consists of an IPv4
+ // socket using only IPv4 addresses.
+ //
+ // - Otherwise, the raw connection is a dual
+ // stack socket, an IPv6 socket using IPv6
+ // addresses including IPv4-mapped or
+ // IPv4-embedded IPv6 addresses.
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ return operr
+}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ var operr error
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ return operr
+}
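
The controlOnConnSetup helpers on both platforms are shaped to plug into net.Dialer.Control (or net.ListenConfig.Control), which runs after the socket is created but before it is connected. A hedged sketch of that pattern for a Unix-like system follows; the target address and TTL value are placeholders, not anything taken from the change.

	package main

	import (
		"log"
		"net"
		"syscall"
	)

	func main() {
		d := net.Dialer{
			// Control runs before the connect, which is the right place
			// for per-socket options such as IP_TTL.
			Control: func(network, address string, c syscall.RawConn) error {
				var operr error
				if err := c.Control(func(fd uintptr) {
					operr = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, 64)
				}); err != nil {
					return err
				}
				return operr
			},
		}
		conn, err := d.Dial("tcp4", "example.com:80") // placeholder address
		if err != nil {
			log.Fatal(err)
		}
		defer conn.Close()
		log.Println("connected with custom TTL from", conn.LocalAddr())
	}
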
diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go
new file mode 100644
index 0000000..1de0402
--- /dev/null
+++ b/src/net/resolverdialfunc_test.go
@@ -0,0 +1,327 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+// Test that Resolver.Dial can be a func returning an in-memory net.Conn
+// speaking DNS.
+
+package net
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "reflect"
+ "sort"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+func TestResolverDialFunc(t *testing.T) {
+ r := &Resolver{
+ PreferGo: true,
+ Dial: newResolverDialFunc(&resolverDialHandler{
+ StartDial: func(network, address string) error {
+ t.Logf("StartDial(%q, %q) ...", network, address)
+ return nil
+ },
+ Question: func(h dnsmessage.Header, q dnsmessage.Question) {
+ t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
+ q.Name.String(), q.Type, q.Class)
+ },
+			// TODO: add a test with no HandleA* hooks specified at all,
+			// verifying that Go doesn't issue retries; map such queries to
+			// something terminal.
+ HandleA: func(w AWriter, name string) error {
+ w.AddIP([4]byte{1, 2, 3, 4})
+ w.AddIP([4]byte{5, 6, 7, 8})
+ return nil
+ },
+ HandleAAAA: func(w AAAAWriter, name string) error {
+ w.AddIP([16]byte{1: 1, 15: 15})
+ w.AddIP([16]byte{2: 2, 14: 14})
+ return nil
+ },
+ HandleSRV: func(w SRVWriter, name string) error {
+ w.AddSRV(1, 2, 80, "foo.bar.")
+ w.AddSRV(2, 3, 81, "bar.baz.")
+ return nil
+ },
+ }),
+ }
+ ctx := context.Background()
+ const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
+
+ t.Run("LookupIP", func(t *testing.T) {
+ ips, err := r.LookupIP(ctx, "ip", fakeDomain)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
+ t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
+ }
+ })
+
+ t.Run("LookupSRV", func(t *testing.T) {
+ _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := []*SRV{
+ {
+ Target: "foo.bar.",
+ Port: 80,
+ Priority: 1,
+ Weight: 2,
+ },
+ {
+ Target: "bar.baz.",
+ Port: 81,
+ Priority: 2,
+ Weight: 3,
+ },
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("wrong result. got:")
+ for _, r := range got {
+ t.Logf(" - %+v", r)
+ }
+ }
+ })
+}
+
+func sortedIPStrings(ips []IP) []string {
+ ret := make([]string, len(ips))
+ for i, ip := range ips {
+ ret[i] = ip.String()
+ }
+ sort.Strings(ret)
+ return ret
+}
+
+func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
+ return func(ctx context.Context, network, address string) (Conn, error) {
+ a := &resolverFuncConn{
+ h: h,
+ network: network,
+ address: address,
+ ttl: 10, // 10 second default if unset
+ }
+ if h.StartDial != nil {
+ if err := h.StartDial(network, address); err != nil {
+ return nil, err
+ }
+ }
+ return a, nil
+ }
+}
+
+type resolverDialHandler struct {
+ // StartDial, if non-nil, is called when Go first calls Resolver.Dial.
+ // Any error returned aborts the dial and is returned unwrapped.
+ StartDial func(network, address string) error
+
+ Question func(dnsmessage.Header, dnsmessage.Question)
+
+ // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
+ // A nil error means success.
+ HandleA func(w AWriter, name string) error
+ HandleAAAA func(w AAAAWriter, name string) error
+ HandleSRV func(w SRVWriter, name string) error
+}
+
+type ResponseWriter struct{ a *resolverFuncConn }
+
+func (w ResponseWriter) header() dnsmessage.ResourceHeader {
+ q := w.a.q
+ return dnsmessage.ResourceHeader{
+ Name: q.Name,
+ Type: q.Type,
+ Class: q.Class,
+ TTL: w.a.ttl,
+ }
+}
+
+// SetTTL sets the TTL for subsequently written resources.
+// Once a resource has been written, further SetTTL calls are no-ops;
+// in effect, the TTL must be chosen before anything is written.
+func (w ResponseWriter) SetTTL(seconds uint32) {
+	// The alternatives are all somewhat awkward: letting the last
+	// SetTTL win and retroactively rewrite previously written records,
+	// requiring SetTTL to be called last, or allowing a different TTL
+	// per answer. For now, writing any resource freezes the TTL.
+	if w.a.wrote {
+		return
+	}
+	w.a.ttl = seconds
+}
+
+type AWriter struct{ ResponseWriter }
+
+func (w AWriter) AddIP(v4 [4]byte) {
+ w.a.wrote = true
+ err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
+ if err != nil {
+ panic(err)
+ }
+}
+
+type AAAAWriter struct{ ResponseWriter }
+
+func (w AAAAWriter) AddIP(v6 [16]byte) {
+ w.a.wrote = true
+ err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
+ if err != nil {
+ panic(err)
+ }
+}
+
+type SRVWriter struct{ ResponseWriter }
+
+// AddSRV adds a SRV record. The target name must end in a period and
+// be 63 bytes or fewer.
+func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
+ targetName, err := dnsmessage.NewName(target)
+ if err != nil {
+ return err
+ }
+ w.a.wrote = true
+ err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
+ Priority: priority,
+ Weight: weight,
+ Port: port,
+ Target: targetName,
+ })
+ if err != nil {
+ panic(err) // internal fault, not user
+ }
+ return nil
+}
+
+var (
+ ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
+ ErrRefused = errors.New("refused") // maps to RCode5, REFUSED
+)
+
+type resolverFuncConn struct {
+ h *resolverDialHandler
+ network string
+ address string
+ builder *dnsmessage.Builder
+ q dnsmessage.Question
+ ttl uint32
+ wrote bool
+
+ rbuf bytes.Buffer
+}
+
+func (*resolverFuncConn) Close() error { return nil }
+func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
+func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
+func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
+func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
+func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
+ return a.rbuf.Read(p)
+}
+
+func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
+ if len(packet) < 2 {
+ return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
+ }
+ reqLen := int(packet[0])<<8 | int(packet[1])
+ req := packet[2:]
+ if len(req) != reqLen {
+ return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
+ }
+
+ var parser dnsmessage.Parser
+ h, err := parser.Start(req)
+ if err != nil {
+ // TODO: hook
+ return 0, err
+ }
+ q, err := parser.Question()
+ hadQ := (err == nil)
+ if err == nil && a.h.Question != nil {
+ a.h.Question(h, q)
+ }
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return 0, err
+ }
+
+ resh := h
+ resh.Response = true
+ resh.Authoritative = true
+ if hadQ {
+ resh.RCode = dnsmessage.RCodeSuccess
+ } else {
+ resh.RCode = dnsmessage.RCodeNotImplemented
+ }
+ a.rbuf.Grow(514)
+	a.rbuf.WriteByte('X') // reserved for the big-endian uint16 length prefix
+	a.rbuf.WriteByte('Y') // reserved for the big-endian uint16 length prefix
+ builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
+ a.builder = &builder
+ if hadQ {
+ a.q = q
+ a.builder.StartQuestions()
+ err := a.builder.Question(q)
+ if err != nil {
+ return 0, fmt.Errorf("Question: %w", err)
+ }
+ a.builder.StartAnswers()
+ switch q.Type {
+ case dnsmessage.TypeA:
+ if a.h.HandleA != nil {
+ resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ case dnsmessage.TypeAAAA:
+ if a.h.HandleAAAA != nil {
+ resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ case dnsmessage.TypeSRV:
+ if a.h.HandleSRV != nil {
+ resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ }
+ }
+ tcpRes, err := builder.Finish()
+ if err != nil {
+ return 0, fmt.Errorf("Finish: %w", err)
+ }
+
+ n = len(tcpRes) - 2
+ tcpRes[0] = byte(n >> 8)
+ tcpRes[1] = byte(n)
+ a.rbuf.Write(tcpRes[2:])
+
+ return len(packet), nil
+}
+
+type someaddr struct{}
+
+func (someaddr) Network() string { return "unused" }
+func (someaddr) String() string { return "unused-someaddr" }
+
+func mapRCode(err error) dnsmessage.RCode {
+ switch err {
+ case nil:
+ return dnsmessage.RCodeSuccess
+ case ErrNotExist:
+ return dnsmessage.RCodeNameError
+ case ErrRefused:
+ return dnsmessage.RCodeRefused
+ default:
+ return dnsmessage.RCodeServerFailure
+ }
+}
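
The test above feeds Resolver.Dial an in-memory DNS responder; in ordinary application code the same hook is more often used to steer queries to a specific server. The following sketch illustrates that simpler use; the 9.9.9.9 resolver address and the example.com name are arbitrary placeholders.

	package main

	import (
		"context"
		"fmt"
		"log"
		"net"
		"time"
	)

	func main() {
		r := &net.Resolver{
			PreferGo: true, // use Go's built-in resolver so Dial is honored
			Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
				// Ignore the address chosen from the system configuration and
				// send every query to a fixed server instead.
				d := net.Dialer{Timeout: 5 * time.Second}
				return d.DialContext(ctx, network, "9.9.9.9:53")
			},
		}
		addrs, err := r.LookupHost(context.Background(), "example.com")
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(addrs)
	}
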
diff --git a/src/net/rpc/client.go b/src/net/rpc/client.go
new file mode 100644
index 0000000..42d1351
--- /dev/null
+++ b/src/net/rpc/client.go
@@ -0,0 +1,323 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "bufio"
+ "encoding/gob"
+ "errors"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "sync"
+)
+
+// ServerError represents an error that has been returned from
+// the remote side of the RPC connection.
+type ServerError string
+
+func (e ServerError) Error() string {
+ return string(e)
+}
+
+var ErrShutdown = errors.New("connection is shut down")
+
+// Call represents an active RPC.
+type Call struct {
+ ServiceMethod string // The name of the service and method to call.
+ Args any // The argument to the function (*struct).
+ Reply any // The reply from the function (*struct).
+ Error error // After completion, the error status.
+ Done chan *Call // Receives *Call when Go is complete.
+}
+
+// Client represents an RPC Client.
+// There may be multiple outstanding Calls associated
+// with a single Client, and a Client may be used by
+// multiple goroutines simultaneously.
+type Client struct {
+ codec ClientCodec
+
+ reqMutex sync.Mutex // protects following
+ request Request
+
+ mutex sync.Mutex // protects following
+ seq uint64
+ pending map[uint64]*Call
+ closing bool // user has called Close
+ shutdown bool // server has told us to stop
+}
+
+// A ClientCodec implements writing of RPC requests and
+// reading of RPC responses for the client side of an RPC session.
+// The client calls WriteRequest to write a request to the connection
+// and calls ReadResponseHeader and ReadResponseBody in pairs
+// to read responses. The client calls Close when finished with the
+// connection. ReadResponseBody may be called with a nil
+// argument to force the body of the response to be read and then
+// discarded.
+// See NewClient's comment for information about concurrent access.
+type ClientCodec interface {
+ WriteRequest(*Request, any) error
+ ReadResponseHeader(*Response) error
+ ReadResponseBody(any) error
+
+ Close() error
+}
+
+func (client *Client) send(call *Call) {
+ client.reqMutex.Lock()
+ defer client.reqMutex.Unlock()
+
+ // Register this call.
+ client.mutex.Lock()
+ if client.shutdown || client.closing {
+ client.mutex.Unlock()
+ call.Error = ErrShutdown
+ call.done()
+ return
+ }
+ seq := client.seq
+ client.seq++
+ client.pending[seq] = call
+ client.mutex.Unlock()
+
+ // Encode and send the request.
+ client.request.Seq = seq
+ client.request.ServiceMethod = call.ServiceMethod
+ err := client.codec.WriteRequest(&client.request, call.Args)
+ if err != nil {
+ client.mutex.Lock()
+ call = client.pending[seq]
+ delete(client.pending, seq)
+ client.mutex.Unlock()
+ if call != nil {
+ call.Error = err
+ call.done()
+ }
+ }
+}
+
+func (client *Client) input() {
+ var err error
+ var response Response
+ for err == nil {
+ response = Response{}
+ err = client.codec.ReadResponseHeader(&response)
+ if err != nil {
+ break
+ }
+ seq := response.Seq
+ client.mutex.Lock()
+ call := client.pending[seq]
+ delete(client.pending, seq)
+ client.mutex.Unlock()
+
+ switch {
+ case call == nil:
+ // We've got no pending call. That usually means that
+ // WriteRequest partially failed, and call was already
+ // removed; response is a server telling us about an
+ // error reading request body. We should still attempt
+ // to read error body, but there's no one to give it to.
+ err = client.codec.ReadResponseBody(nil)
+ if err != nil {
+ err = errors.New("reading error body: " + err.Error())
+ }
+ case response.Error != "":
+ // We've got an error response. Give this to the request;
+ // any subsequent requests will get the ReadResponseBody
+ // error if there is one.
+ call.Error = ServerError(response.Error)
+ err = client.codec.ReadResponseBody(nil)
+ if err != nil {
+ err = errors.New("reading error body: " + err.Error())
+ }
+ call.done()
+ default:
+ err = client.codec.ReadResponseBody(call.Reply)
+ if err != nil {
+ call.Error = errors.New("reading body " + err.Error())
+ }
+ call.done()
+ }
+ }
+ // Terminate pending calls.
+ client.reqMutex.Lock()
+ client.mutex.Lock()
+ client.shutdown = true
+ closing := client.closing
+ if err == io.EOF {
+ if closing {
+ err = ErrShutdown
+ } else {
+ err = io.ErrUnexpectedEOF
+ }
+ }
+ for _, call := range client.pending {
+ call.Error = err
+ call.done()
+ }
+ client.mutex.Unlock()
+ client.reqMutex.Unlock()
+ if debugLog && err != io.EOF && !closing {
+ log.Println("rpc: client protocol error:", err)
+ }
+}
+
+func (call *Call) done() {
+ select {
+ case call.Done <- call:
+ // ok
+ default:
+ // We don't want to block here. It is the caller's responsibility to make
+ // sure the channel has enough buffer space. See comment in Go().
+ if debugLog {
+ log.Println("rpc: discarding Call reply due to insufficient Done chan capacity")
+ }
+ }
+}
+
+// NewClient returns a new Client to handle requests to the
+// set of services at the other end of the connection.
+// It adds a buffer to the write side of the connection so
+// the header and payload are sent as a unit.
+//
+// The read and write halves of the connection are serialized independently,
+// so no interlocking is required. However each half may be accessed
+// concurrently so the implementation of conn should protect against
+// concurrent reads or concurrent writes.
+func NewClient(conn io.ReadWriteCloser) *Client {
+ encBuf := bufio.NewWriter(conn)
+ client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
+ return NewClientWithCodec(client)
+}
+
+// NewClientWithCodec is like NewClient but uses the specified
+// codec to encode requests and decode responses.
+func NewClientWithCodec(codec ClientCodec) *Client {
+ client := &Client{
+ codec: codec,
+ pending: make(map[uint64]*Call),
+ }
+ go client.input()
+ return client
+}
+
+type gobClientCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+}
+
+func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) {
+ if err = c.enc.Encode(r); err != nil {
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobClientCodec) ReadResponseHeader(r *Response) error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobClientCodec) ReadResponseBody(body any) error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobClientCodec) Close() error {
+ return c.rwc.Close()
+}
+
+// DialHTTP connects to an HTTP RPC server at the specified network address
+// listening on the default HTTP RPC path.
+func DialHTTP(network, address string) (*Client, error) {
+ return DialHTTPPath(network, address, DefaultRPCPath)
+}
+
+// DialHTTPPath connects to an HTTP RPC server
+// at the specified network address and path.
+func DialHTTPPath(network, address, path string) (*Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")
+
+ // Require successful HTTP response
+ // before switching to RPC protocol.
+ resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
+ if err == nil && resp.Status == connected {
+ return NewClient(conn), nil
+ }
+ if err == nil {
+ err = errors.New("unexpected HTTP response: " + resp.Status)
+ }
+ conn.Close()
+ return nil, &net.OpError{
+ Op: "dial-http",
+ Net: network + " " + address,
+ Addr: nil,
+ Err: err,
+ }
+}
+
+// Dial connects to an RPC server at the specified network address.
+func Dial(network, address string) (*Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), nil
+}
+
+// Close calls the underlying codec's Close method. If the connection is already
+// shutting down, ErrShutdown is returned.
+func (client *Client) Close() error {
+ client.mutex.Lock()
+ if client.closing {
+ client.mutex.Unlock()
+ return ErrShutdown
+ }
+ client.closing = true
+ client.mutex.Unlock()
+ return client.codec.Close()
+}
+
+// Go invokes the function asynchronously. It returns the Call structure representing
+// the invocation. The done channel will signal when the call is complete by returning
+// the same Call object. If done is nil, Go will allocate a new channel.
+// If non-nil, done must be buffered or Go will deliberately crash.
+func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call {
+ call := new(Call)
+ call.ServiceMethod = serviceMethod
+ call.Args = args
+ call.Reply = reply
+ if done == nil {
+ done = make(chan *Call, 10) // buffered.
+ } else {
+ // If caller passes done != nil, it must arrange that
+ // done has enough buffer for the number of simultaneous
+ // RPCs that will be using that channel. If the channel
+ // is totally unbuffered, it's best not to run at all.
+ if cap(done) == 0 {
+ log.Panic("rpc: done channel is unbuffered")
+ }
+ }
+ call.Done = done
+ client.send(call)
+ return call
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (client *Client) Call(serviceMethod string, args any, reply any) error {
+ call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
+ return call.Error
+}
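
As a quick illustration of the Call/Go split implemented above, a caller can issue an asynchronous request and collect the result from the Done channel. The sketch below assumes an Arith.Multiply service like the one in the package documentation and uses a placeholder dial address.

	package main

	import (
		"fmt"
		"log"
		"net/rpc"
	)

	type Args struct{ A, B int }

	func main() {
		client, err := rpc.Dial("tcp", "localhost:1234") // placeholder address
		if err != nil {
			log.Fatal("dialing:", err)
		}
		defer client.Close()

		// Go returns immediately; passing a nil done channel lets the
		// client allocate a buffered one.
		var product int
		call := client.Go("Arith.Multiply", &Args{7, 8}, &product, nil)

		// ... do other work here ...

		<-call.Done // wait for completion
		if call.Error != nil {
			log.Fatal("arith error:", call.Error)
		}
		fmt.Println("7*8 =", product)
	}
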
diff --git a/src/net/rpc/client_test.go b/src/net/rpc/client_test.go
new file mode 100644
index 0000000..ffc12fa
--- /dev/null
+++ b/src/net/rpc/client_test.go
@@ -0,0 +1,87 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "strings"
+ "testing"
+)
+
+type shutdownCodec struct {
+ responded chan int
+ closed bool
+}
+
+func (c *shutdownCodec) WriteRequest(*Request, any) error { return nil }
+func (c *shutdownCodec) ReadResponseBody(any) error { return nil }
+func (c *shutdownCodec) ReadResponseHeader(*Response) error {
+ c.responded <- 1
+ return errors.New("shutdownCodec ReadResponseHeader")
+}
+func (c *shutdownCodec) Close() error {
+ c.closed = true
+ return nil
+}
+
+func TestCloseCodec(t *testing.T) {
+ codec := &shutdownCodec{responded: make(chan int)}
+ client := NewClientWithCodec(codec)
+ <-codec.responded
+ client.Close()
+ if !codec.closed {
+ t.Error("client.Close did not close codec")
+ }
+}
+
+// Test that errors in gob shut down the connection. Issue 7689.
+
+type R struct {
+ msg []byte // Not exported, so R does not work with gob.
+}
+
+type S struct{}
+
+func (s *S) Recv(nul *struct{}, reply *R) error {
+ *reply = R{[]byte("foo")}
+ return nil
+}
+
+func TestGobError(t *testing.T) {
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatal("no error")
+ }
+ if !strings.Contains(err.(error).Error(), "reading body unexpected EOF") {
+ t.Fatal("expected `reading body unexpected EOF', got", err)
+ }
+ }()
+ Register(new(S))
+
+ listen, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ panic(err)
+ }
+ go Accept(listen)
+
+ client, err := Dial("tcp", listen.Addr().String())
+ if err != nil {
+ panic(err)
+ }
+
+ var reply Reply
+ err = client.Call("S.Recv", &struct{}{}, &reply)
+ if err != nil {
+ panic(err)
+ }
+
+ fmt.Printf("%#v\n", reply)
+ client.Close()
+
+ listen.Close()
+}
diff --git a/src/net/rpc/debug.go b/src/net/rpc/debug.go
new file mode 100644
index 0000000..9e499fd
--- /dev/null
+++ b/src/net/rpc/debug.go
@@ -0,0 +1,90 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+/*
+	Serves HTML at http://machine:port/debug/rpc listing the registered
+	services, their methods, and some still-rudimentary call statistics.
+*/
+
+import (
+ "fmt"
+ "html/template"
+ "net/http"
+ "sort"
+)
+
+const debugText = `<html>
+ <body>
+ <title>Services</title>
+ {{range .}}
+ <hr>
+ Service {{.Name}}
+ <hr>
+ <table>
+ <th align=center>Method</th><th align=center>Calls</th>
+ {{range .Method}}
+ <tr>
+ <td align=left font=fixed>{{.Name}}({{.Type.ArgType}}, {{.Type.ReplyType}}) error</td>
+ <td align=center>{{.Type.NumCalls}}</td>
+ </tr>
+ {{end}}
+ </table>
+ {{end}}
+ </body>
+ </html>`
+
+var debug = template.Must(template.New("RPC debug").Parse(debugText))
+
+// If set, print log statements for internal and I/O errors.
+var debugLog = false
+
+type debugMethod struct {
+ Type *methodType
+ Name string
+}
+
+type methodArray []debugMethod
+
+type debugService struct {
+ Service *service
+ Name string
+ Method methodArray
+}
+
+type serviceArray []debugService
+
+func (s serviceArray) Len() int { return len(s) }
+func (s serviceArray) Less(i, j int) bool { return s[i].Name < s[j].Name }
+func (s serviceArray) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+func (m methodArray) Len() int { return len(m) }
+func (m methodArray) Less(i, j int) bool { return m[i].Name < m[j].Name }
+func (m methodArray) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+type debugHTTP struct {
+ *Server
+}
+
+// Runs at /debug/rpc
+func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ // Build a sorted version of the data.
+ var services serviceArray
+ server.serviceMap.Range(func(snamei, svci any) bool {
+ svc := svci.(*service)
+ ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))}
+ for mname, method := range svc.method {
+ ds.Method = append(ds.Method, debugMethod{method, mname})
+ }
+ sort.Sort(ds.Method)
+ services = append(services, ds)
+ return true
+ })
+ sort.Sort(services)
+ err := debug.Execute(w, services)
+ if err != nil {
+ fmt.Fprintln(w, "rpc: error executing template:", err.Error())
+ }
+}
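
The debug handler above is mounted at DefaultDebugPath by HandleHTTP (defined later in server.go). A minimal sketch of a server exposing that page follows; the Arith type is the same illustrative service used in the package documentation, and the listen address is a placeholder.

	package main

	import (
		"log"
		"net"
		"net/http"
		"net/rpc"
	)

	type Args struct{ A, B int }

	type Arith int

	func (t *Arith) Multiply(args *Args, reply *int) error {
		*reply = args.A * args.B
		return nil
	}

	func main() {
		rpc.Register(new(Arith))
		rpc.HandleHTTP() // registers the RPC path and the /debug/rpc page

		l, err := net.Listen("tcp", "127.0.0.1:1234")
		if err != nil {
			log.Fatal("listen error:", err)
		}
		// Browse to http://127.0.0.1:1234/debug/rpc to see the table
		// rendered by the template above.
		log.Fatal(http.Serve(l, nil))
	}
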
diff --git a/src/net/rpc/jsonrpc/all_test.go b/src/net/rpc/jsonrpc/all_test.go
new file mode 100644
index 0000000..e2ccdfc
--- /dev/null
+++ b/src/net/rpc/jsonrpc/all_test.go
@@ -0,0 +1,352 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/rpc"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+type ArithAddResp struct {
+ Id any `json:"id"`
+ Result Reply `json:"result"`
+ Error any `json:"error"`
+}
+
+func (t *Arith) Add(args *Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args *Args, reply *Reply) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) error {
+ panic("ERROR")
+}
+
+type BuiltinTypes struct{}
+
+func (BuiltinTypes) Map(i int, reply *map[int]int) error {
+ (*reply)[i] = i
+ return nil
+}
+
+func (BuiltinTypes) Slice(i int, reply *[]int) error {
+ *reply = append(*reply, i)
+ return nil
+}
+
+func (BuiltinTypes) Array(i int, reply *[1]int) error {
+ (*reply)[0] = i
+ return nil
+}
+
+func init() {
+ rpc.Register(new(Arith))
+ rpc.Register(BuiltinTypes{})
+}
+
+func TestServerNoParams(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "123"}`)
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after no params: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestServerEmptyMessage(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ fmt.Fprintf(cli, "{}")
+ var resp ArithAddResp
+ if err := dec.Decode(&resp); err != nil {
+ t.Fatalf("Decode after empty: %s", err)
+ }
+ if resp.Error == nil {
+ t.Fatalf("Expected error, got nil")
+ }
+}
+
+func TestServer(t *testing.T) {
+ cli, srv := net.Pipe()
+ defer cli.Close()
+ go ServeConn(srv)
+ dec := json.NewDecoder(cli)
+
+ // Send hand-coded requests to server, parse responses.
+ for i := 0; i < 10; i++ {
+ fmt.Fprintf(cli, `{"method": "Arith.Add", "id": "\u%04d", "params": [{"A": %d, "B": %d}]}`, i, i, i+1)
+ var resp ArithAddResp
+ err := dec.Decode(&resp)
+ if err != nil {
+ t.Fatalf("Decode: %s", err)
+ }
+ if resp.Error != nil {
+ t.Fatalf("resp.Error: %s", resp.Error)
+ }
+ if resp.Id.(string) != string(rune(i)) {
+ t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(rune(i)))
+ }
+ if resp.Result.C != 2*i+1 {
+ t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C)
+ }
+ }
+}
+
+func TestClient(t *testing.T) {
+ // Assume server is okay (TestServer is above).
+ // Test client against server.
+ cli, srv := net.Pipe()
+ go ServeConn(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.Error() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+}
+
+func TestBuiltinTypes(t *testing.T) {
+ cli, srv := net.Pipe()
+ go ServeConn(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ // Map
+ arg := 7
+ replyMap := map[int]int{}
+ err := client.Call("BuiltinTypes.Map", arg, &replyMap)
+ if err != nil {
+ t.Errorf("Map: expected no error but got string %q", err.Error())
+ }
+ if replyMap[arg] != arg {
+ t.Errorf("Map: expected %d got %d", arg, replyMap[arg])
+ }
+
+ // Slice
+ replySlice := []int{}
+ err = client.Call("BuiltinTypes.Slice", arg, &replySlice)
+ if err != nil {
+ t.Errorf("Slice: expected no error but got string %q", err.Error())
+ }
+ if e := []int{arg}; !reflect.DeepEqual(replySlice, e) {
+ t.Errorf("Slice: expected %v got %v", e, replySlice)
+ }
+
+ // Array
+ replyArray := [1]int{}
+ err = client.Call("BuiltinTypes.Array", arg, &replyArray)
+ if err != nil {
+ t.Errorf("Array: expected no error but got string %q", err.Error())
+ }
+ if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) {
+ t.Errorf("Array: expected %v got %v", e, replyArray)
+ }
+}
+
+func TestMalformedInput(t *testing.T) {
+ cli, srv := net.Pipe()
+ go cli.Write([]byte(`{id:1}`)) // invalid json
+ ServeConn(srv) // must return, not loop
+}
+
+func TestMalformedOutput(t *testing.T) {
+ cli, srv := net.Pipe()
+ go srv.Write([]byte(`{"id":0,"result":null,"error":null}`))
+ go io.ReadAll(srv)
+
+ client := NewClient(cli)
+ defer client.Close()
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err == nil {
+ t.Error("expected error")
+ }
+}
+
+func TestServerErrorHasNullResult(t *testing.T) {
+ var out strings.Builder
+ sc := NewServerCodec(struct {
+ io.Reader
+ io.Writer
+ io.Closer
+ }{
+ Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`),
+ Writer: &out,
+ Closer: io.NopCloser(nil),
+ })
+ r := new(rpc.Request)
+ if err := sc.ReadRequestHeader(r); err != nil {
+ t.Fatal(err)
+ }
+ const valueText = "the value we don't want to see"
+ const errorText = "some error"
+ err := sc.WriteResponse(&rpc.Response{
+ ServiceMethod: "Method",
+ Seq: 1,
+ Error: errorText,
+ }, valueText)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.Contains(out.String(), errorText) {
+ t.Fatalf("Response didn't contain expected error %q: %s", errorText, &out)
+ }
+ if strings.Contains(out.String(), valueText) {
+ t.Errorf("Response contains both an error and value: %s", &out)
+ }
+}
+
+func TestUnexpectedError(t *testing.T) {
+ cli, srv := myPipe()
+ go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error
+ ServeConn(srv) // must return, not loop
+}
+
+// Copied from package net.
+func myPipe() (*pipe, *pipe) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ return &pipe{r1, w2}, &pipe{r2, w1}
+}
+
+type pipe struct {
+ *io.PipeReader
+ *io.PipeWriter
+}
+
+type pipeAddr int
+
+func (pipeAddr) Network() string {
+ return "pipe"
+}
+
+func (pipeAddr) String() string {
+ return "pipe"
+}
+
+func (p *pipe) Close() error {
+ err := p.PipeReader.Close()
+ err1 := p.PipeWriter.Close()
+ if err == nil {
+ err = err1
+ }
+ return err
+}
+
+func (p *pipe) LocalAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) RemoteAddr() net.Addr {
+ return pipeAddr(0)
+}
+
+func (p *pipe) SetTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetReadTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetWriteTimeout(nsec int64) error {
+ return errors.New("net.Pipe does not support timeouts")
+}
diff --git a/src/net/rpc/jsonrpc/client.go b/src/net/rpc/jsonrpc/client.go
new file mode 100644
index 0000000..c473017
--- /dev/null
+++ b/src/net/rpc/jsonrpc/client.go
@@ -0,0 +1,124 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package jsonrpc implements a JSON-RPC 1.0 ClientCodec and ServerCodec
+// for the rpc package.
+// For JSON-RPC 2.0 support, see https://godoc.org/?q=json-rpc+2.0
+package jsonrpc
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "net/rpc"
+ "sync"
+)
+
+type clientCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req clientRequest
+ resp clientResponse
+
+ // JSON-RPC responses include the request id but not the request method.
+ // Package rpc expects both.
+ // We save the request method in pending when sending a request
+ // and then look it up by request ID when filling out the rpc Response.
+ mutex sync.Mutex // protects pending
+ pending map[uint64]string // map request id to method name
+}
+
+// NewClientCodec returns a new rpc.ClientCodec using JSON-RPC on conn.
+func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
+ return &clientCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]string),
+ }
+}
+
+type clientRequest struct {
+ Method string `json:"method"`
+ Params [1]any `json:"params"`
+ Id uint64 `json:"id"`
+}
+
+func (c *clientCodec) WriteRequest(r *rpc.Request, param any) error {
+ c.mutex.Lock()
+ c.pending[r.Seq] = r.ServiceMethod
+ c.mutex.Unlock()
+ c.req.Method = r.ServiceMethod
+ c.req.Params[0] = param
+ c.req.Id = r.Seq
+ return c.enc.Encode(&c.req)
+}
+
+type clientResponse struct {
+ Id uint64 `json:"id"`
+ Result *json.RawMessage `json:"result"`
+ Error any `json:"error"`
+}
+
+func (r *clientResponse) reset() {
+ r.Id = 0
+ r.Result = nil
+ r.Error = nil
+}
+
+func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error {
+ c.resp.reset()
+ if err := c.dec.Decode(&c.resp); err != nil {
+ return err
+ }
+
+ c.mutex.Lock()
+ r.ServiceMethod = c.pending[c.resp.Id]
+ delete(c.pending, c.resp.Id)
+ c.mutex.Unlock()
+
+ r.Error = ""
+ r.Seq = c.resp.Id
+ if c.resp.Error != nil || c.resp.Result == nil {
+ x, ok := c.resp.Error.(string)
+ if !ok {
+ return fmt.Errorf("invalid error %v", c.resp.Error)
+ }
+ if x == "" {
+ x = "unspecified error"
+ }
+ r.Error = x
+ }
+ return nil
+}
+
+func (c *clientCodec) ReadResponseBody(x any) error {
+ if x == nil {
+ return nil
+ }
+ return json.Unmarshal(*c.resp.Result, x)
+}
+
+func (c *clientCodec) Close() error {
+ return c.c.Close()
+}
+
+// NewClient returns a new rpc.Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *rpc.Client {
+ return rpc.NewClientWithCodec(NewClientCodec(conn))
+}
+
+// Dial connects to a JSON-RPC server at the specified network address.
+func Dial(network, address string) (*rpc.Client, error) {
+ conn, err := net.Dial(network, address)
+ if err != nil {
+ return nil, err
+ }
+ return NewClient(conn), err
+}
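
To show the client codec above in context, here is a hedged sketch of a JSON-RPC client call; the Arith.Multiply service and the dial address are assumed placeholders, mirroring the gob-based examples elsewhere in the package.

	package main

	import (
		"fmt"
		"log"
		"net/rpc/jsonrpc"
	)

	type Args struct{ A, B int }

	func main() {
		client, err := jsonrpc.Dial("tcp", "localhost:5678") // placeholder address
		if err != nil {
			log.Fatal("dialing:", err)
		}
		defer client.Close()

		// Calls look exactly like gob-based net/rpc calls; only the
		// wire format differs.
		var reply int
		if err := client.Call("Arith.Multiply", &Args{7, 8}, &reply); err != nil {
			log.Fatal("call error:", err)
		}
		fmt.Println("7*8 =", reply)
	}
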
diff --git a/src/net/rpc/jsonrpc/server.go b/src/net/rpc/jsonrpc/server.go
new file mode 100644
index 0000000..3ee4ddf
--- /dev/null
+++ b/src/net/rpc/jsonrpc/server.go
@@ -0,0 +1,134 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package jsonrpc
+
+import (
+ "encoding/json"
+ "errors"
+ "io"
+ "net/rpc"
+ "sync"
+)
+
+var errMissingParams = errors.New("jsonrpc: request body missing params")
+
+type serverCodec struct {
+ dec *json.Decoder // for reading JSON values
+ enc *json.Encoder // for writing JSON values
+ c io.Closer
+
+ // temporary work space
+ req serverRequest
+
+ // JSON-RPC clients can use arbitrary json values as request IDs.
+ // Package rpc expects uint64 request IDs.
+ // We assign uint64 sequence numbers to incoming requests
+ // but save the original request ID in the pending map.
+ // When rpc responds, we use the sequence number in
+ // the response to find the original request ID.
+ mutex sync.Mutex // protects seq, pending
+ seq uint64
+ pending map[uint64]*json.RawMessage
+}
+
+// NewServerCodec returns a new rpc.ServerCodec using JSON-RPC on conn.
+func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
+ return &serverCodec{
+ dec: json.NewDecoder(conn),
+ enc: json.NewEncoder(conn),
+ c: conn,
+ pending: make(map[uint64]*json.RawMessage),
+ }
+}
+
+type serverRequest struct {
+ Method string `json:"method"`
+ Params *json.RawMessage `json:"params"`
+ Id *json.RawMessage `json:"id"`
+}
+
+func (r *serverRequest) reset() {
+ r.Method = ""
+ r.Params = nil
+ r.Id = nil
+}
+
+type serverResponse struct {
+ Id *json.RawMessage `json:"id"`
+ Result any `json:"result"`
+ Error any `json:"error"`
+}
+
+func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error {
+ c.req.reset()
+ if err := c.dec.Decode(&c.req); err != nil {
+ return err
+ }
+ r.ServiceMethod = c.req.Method
+
+ // JSON request id can be any JSON value;
+ // RPC package expects uint64. Translate to
+ // internal uint64 and save JSON on the side.
+ c.mutex.Lock()
+ c.seq++
+ c.pending[c.seq] = c.req.Id
+ c.req.Id = nil
+ r.Seq = c.seq
+ c.mutex.Unlock()
+
+ return nil
+}
+
+func (c *serverCodec) ReadRequestBody(x any) error {
+ if x == nil {
+ return nil
+ }
+ if c.req.Params == nil {
+ return errMissingParams
+ }
+ // JSON params is array value.
+ // RPC params is struct.
+ // Unmarshal into array containing struct for now.
+ // Should think about making RPC more general.
+ var params [1]any
+ params[0] = x
+ return json.Unmarshal(*c.req.Params, &params)
+}
+
+var null = json.RawMessage([]byte("null"))
+
+func (c *serverCodec) WriteResponse(r *rpc.Response, x any) error {
+ c.mutex.Lock()
+ b, ok := c.pending[r.Seq]
+ if !ok {
+ c.mutex.Unlock()
+ return errors.New("invalid sequence number in response")
+ }
+ delete(c.pending, r.Seq)
+ c.mutex.Unlock()
+
+ if b == nil {
+ // Invalid request so no id. Use JSON null.
+ b = &null
+ }
+ resp := serverResponse{Id: b}
+ if r.Error == "" {
+ resp.Result = x
+ } else {
+ resp.Error = r.Error
+ }
+ return c.enc.Encode(resp)
+}
+
+func (c *serverCodec) Close() error {
+ return c.c.Close()
+}
+
+// ServeConn runs the JSON-RPC server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+func ServeConn(conn io.ReadWriteCloser) {
+ rpc.ServeCodec(NewServerCodec(conn))
+}
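
And the matching server side: ServeConn blocks for the life of a connection, so a listener loop typically hands each accepted connection to its own goroutine. The sketch below registers the same illustrative Arith service used above; the listen address is a placeholder.

	package main

	import (
		"log"
		"net"
		"net/rpc"
		"net/rpc/jsonrpc"
	)

	type Args struct{ A, B int }

	type Arith int

	func (t *Arith) Multiply(args *Args, reply *int) error {
		*reply = args.A * args.B
		return nil
	}

	func main() {
		rpc.Register(new(Arith))

		l, err := net.Listen("tcp", "127.0.0.1:5678")
		if err != nil {
			log.Fatal("listen error:", err)
		}
		for {
			conn, err := l.Accept()
			if err != nil {
				log.Fatal(err)
			}
			// Serve each connection concurrently.
			go jsonrpc.ServeConn(conn)
		}
	}
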
diff --git a/src/net/rpc/server.go b/src/net/rpc/server.go
new file mode 100644
index 0000000..5cea2cc
--- /dev/null
+++ b/src/net/rpc/server.go
@@ -0,0 +1,725 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package rpc provides access to the exported methods of an object across a
+network or other I/O connection. A server registers an object, making it visible
+as a service with the name of the type of the object. After registration, exported
+methods of the object will be accessible remotely. A server may register multiple
+objects (services) of different types but it is an error to register multiple
+objects of the same type.
+
+Only methods that satisfy these criteria will be made available for remote access;
+other methods will be ignored:
+
+ - the method's type is exported.
+ - the method is exported.
+ - the method has two arguments, both exported (or builtin) types.
+ - the method's second argument is a pointer.
+ - the method has return type error.
+
+In effect, the method must look schematically like
+
+ func (t *T) MethodName(argType T1, replyType *T2) error
+
+where T1 and T2 can be marshaled by encoding/gob.
+These requirements apply even if a different codec is used.
+(In the future, these requirements may soften for custom codecs.)
+
+The method's first argument represents the arguments provided by the caller; the
+second argument represents the result parameters to be returned to the caller.
+The method's return value, if non-nil, is passed back as a string that the client
+sees as if created by errors.New. If an error is returned, the reply parameter
+will not be sent back to the client.
+
+The server may handle requests on a single connection by calling ServeConn. More
+typically it will create a network listener and call Accept or, for an HTTP
+listener, HandleHTTP and http.Serve.
+
+A client wishing to use the service establishes a connection and then invokes
+NewClient on the connection. The convenience function Dial (DialHTTP) performs
+both steps for a raw network connection (an HTTP connection). The resulting
+Client object has two methods, Call and Go, that specify the service and method to
+call, a pointer containing the arguments, and a pointer to receive the result
+parameters.
+
+The Call method waits for the remote call to complete while the Go method
+launches the call asynchronously and signals completion using the Call
+structure's Done channel.
+
+Unless an explicit codec is set up, package encoding/gob is used to
+transport the data.
+
+Here is a simple example. A server wishes to export an object of type Arith:
+
+ package server
+
+ import "errors"
+
+ type Args struct {
+ A, B int
+ }
+
+ type Quotient struct {
+ Quo, Rem int
+ }
+
+ type Arith int
+
+ func (t *Arith) Multiply(args *Args, reply *int) error {
+ *reply = args.A * args.B
+ return nil
+ }
+
+ func (t *Arith) Divide(args *Args, quo *Quotient) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ quo.Quo = args.A / args.B
+ quo.Rem = args.A % args.B
+ return nil
+ }
+
+The server calls (for HTTP service):
+
+ arith := new(Arith)
+ rpc.Register(arith)
+ rpc.HandleHTTP()
+ l, err := net.Listen("tcp", ":1234")
+ if err != nil {
+ log.Fatal("listen error:", err)
+ }
+ go http.Serve(l, nil)
+
+At this point, clients can see a service "Arith" with methods "Arith.Multiply" and
+"Arith.Divide". To invoke one, a client first dials the server:
+
+ client, err := rpc.DialHTTP("tcp", serverAddress + ":1234")
+ if err != nil {
+ log.Fatal("dialing:", err)
+ }
+
+Then it can make a remote call:
+
+ // Synchronous call
+ args := &server.Args{7,8}
+ var reply int
+ err = client.Call("Arith.Multiply", args, &reply)
+ if err != nil {
+ log.Fatal("arith error:", err)
+ }
+ fmt.Printf("Arith: %d*%d=%d", args.A, args.B, reply)
+
+or
+
+ // Asynchronous call
+ quotient := new(Quotient)
+ divCall := client.Go("Arith.Divide", args, quotient, nil)
+ replyCall := <-divCall.Done // will be equal to divCall
+ // check errors, print, etc.
+
+A server implementation will often provide a simple, type-safe wrapper for the
+client.
+
+The net/rpc package is frozen and is not accepting new features.
+*/
+package rpc
+
+import (
+ "bufio"
+ "encoding/gob"
+ "errors"
+ "go/token"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "reflect"
+ "strings"
+ "sync"
+)
+
+const (
+ // Defaults used by HandleHTTP
+ DefaultRPCPath = "/_goRPC_"
+ DefaultDebugPath = "/debug/rpc"
+)
+
+// Precompute the reflect type for error. Can't use error directly
+// because Typeof takes an empty interface value. This is annoying.
+var typeOfError = reflect.TypeOf((*error)(nil)).Elem()
+
+type methodType struct {
+ sync.Mutex // protects counters
+ method reflect.Method
+ ArgType reflect.Type
+ ReplyType reflect.Type
+ numCalls uint
+}
+
+type service struct {
+ name string // name of service
+ rcvr reflect.Value // receiver of methods for the service
+ typ reflect.Type // type of the receiver
+ method map[string]*methodType // registered methods
+}
+
+// Request is a header written before every RPC call. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Request struct {
+ ServiceMethod string // format: "Service.Method"
+ Seq uint64 // sequence number chosen by client
+ next *Request // for free list in Server
+}
+
+// Response is a header written before every RPC return. It is used internally
+// but documented here as an aid to debugging, such as when analyzing
+// network traffic.
+type Response struct {
+ ServiceMethod string // echoes that of the Request
+ Seq uint64 // echoes that of the request
+ Error string // error, if any.
+ next *Response // for free list in Server
+}
+
+// Server represents an RPC Server.
+type Server struct {
+ serviceMap sync.Map // map[string]*service
+ reqLock sync.Mutex // protects freeReq
+ freeReq *Request
+ respLock sync.Mutex // protects freeResp
+ freeResp *Response
+}
+
+// NewServer returns a new Server.
+func NewServer() *Server {
+ return &Server{}
+}
+
+// DefaultServer is the default instance of *Server.
+var DefaultServer = NewServer()
+
+// Is this type exported or a builtin?
+func isExportedOrBuiltinType(t reflect.Type) bool {
+ for t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+ // PkgPath will be non-empty even for an exported type,
+ // so we need to check the type name as well.
+ return token.IsExported(t.Name()) || t.PkgPath() == ""
+}
+
+// Register publishes in the server the set of methods of the
+// receiver value that satisfy the following conditions:
+// - exported method of exported type
+// - two arguments, both of exported type
+// - the second argument is a pointer
+// - one return value, of type error
+//
+// It returns an error if the receiver is not an exported type or has
+// no suitable methods. It also logs the error using package log.
+// The client accesses each method using a string of the form "Type.Method",
+// where Type is the receiver's concrete type.
+func (server *Server) Register(rcvr any) error {
+ return server.register(rcvr, "", false)
+}
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func (server *Server) RegisterName(name string, rcvr any) error {
+ return server.register(rcvr, name, true)
+}
+
+// logRegisterError specifies whether to log problems during method registration.
+// To debug registration, recompile the package with this set to true.
+const logRegisterError = false
+
+func (server *Server) register(rcvr any, name string, useName bool) error {
+ s := new(service)
+ s.typ = reflect.TypeOf(rcvr)
+ s.rcvr = reflect.ValueOf(rcvr)
+ sname := name
+ if !useName {
+ sname = reflect.Indirect(s.rcvr).Type().Name()
+ }
+ if sname == "" {
+ s := "rpc.Register: no service name for type " + s.typ.String()
+ log.Print(s)
+ return errors.New(s)
+ }
+ if !useName && !token.IsExported(sname) {
+ s := "rpc.Register: type " + sname + " is not exported"
+ log.Print(s)
+ return errors.New(s)
+ }
+ s.name = sname
+
+ // Install the methods
+ s.method = suitableMethods(s.typ, logRegisterError)
+
+ if len(s.method) == 0 {
+ str := ""
+
+ // To help the user, see if a pointer receiver would work.
+ method := suitableMethods(reflect.PointerTo(s.typ), false)
+ if len(method) != 0 {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
+ } else {
+ str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
+ }
+ log.Print(str)
+ return errors.New(str)
+ }
+
+ if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
+ return errors.New("rpc: service already defined: " + sname)
+ }
+ return nil
+}
+
+// suitableMethods returns suitable Rpc methods of typ. It will log
+// errors if logErr is true.
+func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
+ methods := make(map[string]*methodType)
+ for m := 0; m < typ.NumMethod(); m++ {
+ method := typ.Method(m)
+ mtype := method.Type
+ mname := method.Name
+ // Method must be exported.
+ if !method.IsExported() {
+ continue
+ }
+ // Method needs three ins: receiver, *args, *reply.
+ if mtype.NumIn() != 3 {
+ if logErr {
+ log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn())
+ }
+ continue
+ }
+ // First arg need not be a pointer.
+ argType := mtype.In(1)
+ if !isExportedOrBuiltinType(argType) {
+ if logErr {
+ log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
+ }
+ continue
+ }
+ // Second arg must be a pointer.
+ replyType := mtype.In(2)
+ if replyType.Kind() != reflect.Pointer {
+ if logErr {
+ log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
+ }
+ continue
+ }
+ // Reply type must be exported.
+ if !isExportedOrBuiltinType(replyType) {
+ if logErr {
+ log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
+ }
+ continue
+ }
+ // Method needs one out.
+ if mtype.NumOut() != 1 {
+ if logErr {
+ log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
+ }
+ continue
+ }
+ // The return type of the method must be error.
+ if returnType := mtype.Out(0); returnType != typeOfError {
+ if logErr {
+ log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
+ }
+ continue
+ }
+ methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
+ }
+ return methods
+}
+
+// A value sent as a placeholder for the server's response value when the server
+// receives an invalid request. It is never decoded by the client since the Response
+// contains an error when it is used.
+var invalidRequest = struct{}{}
+
+func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) {
+ resp := server.getResponse()
+ // Encode the response header
+ resp.ServiceMethod = req.ServiceMethod
+ if errmsg != "" {
+ resp.Error = errmsg
+ reply = invalidRequest
+ }
+ resp.Seq = req.Seq
+ sending.Lock()
+ err := codec.WriteResponse(resp, reply)
+ if debugLog && err != nil {
+ log.Println("rpc: writing response:", err)
+ }
+ sending.Unlock()
+ server.freeResponse(resp)
+}
+
+func (m *methodType) NumCalls() (n uint) {
+ m.Lock()
+ n = m.numCalls
+ m.Unlock()
+ return n
+}
+
+func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+ if wg != nil {
+ defer wg.Done()
+ }
+ mtype.Lock()
+ mtype.numCalls++
+ mtype.Unlock()
+ function := mtype.method.Func
+ // Invoke the method, providing a new value for the reply.
+ returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
+ // The return value for the method is an error.
+ errInter := returnValues[0].Interface()
+ errmsg := ""
+ if errInter != nil {
+ errmsg = errInter.(error).Error()
+ }
+ server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+ server.freeRequest(req)
+}
+
+type gobServerCodec struct {
+ rwc io.ReadWriteCloser
+ dec *gob.Decoder
+ enc *gob.Encoder
+ encBuf *bufio.Writer
+ closed bool
+}
+
+func (c *gobServerCodec) ReadRequestHeader(r *Request) error {
+ return c.dec.Decode(r)
+}
+
+func (c *gobServerCodec) ReadRequestBody(body any) error {
+ return c.dec.Decode(body)
+}
+
+func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) {
+ if err = c.enc.Encode(r); err != nil {
+ if c.encBuf.Flush() == nil {
+ // Gob couldn't encode the header. Should not happen, so if it does,
+ // shut down the connection to signal that the connection is broken.
+ log.Println("rpc: gob error encoding response:", err)
+ c.Close()
+ }
+ return
+ }
+ if err = c.enc.Encode(body); err != nil {
+ if c.encBuf.Flush() == nil {
+ // Was a gob problem encoding the body but the header has been written.
+ // Shut down the connection to signal that the connection is broken.
+ log.Println("rpc: gob error encoding body:", err)
+ c.Close()
+ }
+ return
+ }
+ return c.encBuf.Flush()
+}
+
+func (c *gobServerCodec) Close() error {
+ if c.closed {
+ // Only call c.rwc.Close once; otherwise the semantics are undefined.
+ return nil
+ }
+ c.closed = true
+ return c.rwc.Close()
+}
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+// See NewClient's comment for information about concurrent access.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
+ buf := bufio.NewWriter(conn)
+ srv := &gobServerCodec{
+ rwc: conn,
+ dec: gob.NewDecoder(conn),
+ enc: gob.NewEncoder(buf),
+ encBuf: buf,
+ }
+ server.ServeCodec(srv)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func (server *Server) ServeCodec(codec ServerCodec) {
+ sending := new(sync.Mutex)
+ wg := new(sync.WaitGroup)
+ for {
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
+ if err != nil {
+ if debugLog && err != io.EOF {
+ log.Println("rpc:", err)
+ }
+ if !keepReading {
+ break
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.Error())
+ server.freeRequest(req)
+ }
+ continue
+ }
+ wg.Add(1)
+ go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
+ }
+ // We've seen that there are no more requests.
+ // Wait for responses to be sent before closing codec.
+ wg.Wait()
+ codec.Close()
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func (server *Server) ServeRequest(codec ServerCodec) error {
+ sending := new(sync.Mutex)
+ service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
+ if err != nil {
+ if !keepReading {
+ return err
+ }
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ server.sendResponse(sending, req, invalidRequest, codec, err.Error())
+ server.freeRequest(req)
+ }
+ return err
+ }
+ service.call(server, sending, nil, mtype, req, argv, replyv, codec)
+ return nil
+}
+
+func (server *Server) getRequest() *Request {
+ server.reqLock.Lock()
+ req := server.freeReq
+ if req == nil {
+ req = new(Request)
+ } else {
+ server.freeReq = req.next
+ *req = Request{}
+ }
+ server.reqLock.Unlock()
+ return req
+}
+
+func (server *Server) freeRequest(req *Request) {
+ server.reqLock.Lock()
+ req.next = server.freeReq
+ server.freeReq = req
+ server.reqLock.Unlock()
+}
+
+func (server *Server) getResponse() *Response {
+ server.respLock.Lock()
+ resp := server.freeResp
+ if resp == nil {
+ resp = new(Response)
+ } else {
+ server.freeResp = resp.next
+ *resp = Response{}
+ }
+ server.respLock.Unlock()
+ return resp
+}
+
+func (server *Server) freeResponse(resp *Response) {
+ server.respLock.Lock()
+ resp.next = server.freeResp
+ server.freeResp = resp
+ server.respLock.Unlock()
+}
+
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
+ service, mtype, req, keepReading, err = server.readRequestHeader(codec)
+ if err != nil {
+ if !keepReading {
+ return
+ }
+ // discard body
+ codec.ReadRequestBody(nil)
+ return
+ }
+
+ // Decode the argument value.
+ argIsValue := false // if true, need to indirect before calling.
+ if mtype.ArgType.Kind() == reflect.Pointer {
+ argv = reflect.New(mtype.ArgType.Elem())
+ } else {
+ argv = reflect.New(mtype.ArgType)
+ argIsValue = true
+ }
+ // argv guaranteed to be a pointer now.
+ if err = codec.ReadRequestBody(argv.Interface()); err != nil {
+ return
+ }
+ if argIsValue {
+ argv = argv.Elem()
+ }
+
+ replyv = reflect.New(mtype.ReplyType.Elem())
+
+ switch mtype.ReplyType.Elem().Kind() {
+ case reflect.Map:
+ replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
+ case reflect.Slice:
+ replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
+ }
+ return
+}
+
+func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
+ // Grab the request header.
+ req = server.getRequest()
+ err = codec.ReadRequestHeader(req)
+ if err != nil {
+ req = nil
+ if err == io.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ err = errors.New("rpc: server cannot decode request: " + err.Error())
+ return
+ }
+
+ // We read the header successfully. If we see an error now,
+ // we can still recover and move on to the next request.
+ keepReading = true
+
+ dot := strings.LastIndex(req.ServiceMethod, ".")
+ if dot < 0 {
+ err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
+ return
+ }
+ serviceName := req.ServiceMethod[:dot]
+ methodName := req.ServiceMethod[dot+1:]
+
+ // Look up the request.
+ svci, ok := server.serviceMap.Load(serviceName)
+ if !ok {
+ err = errors.New("rpc: can't find service " + req.ServiceMethod)
+ return
+ }
+ svc = svci.(*service)
+ mtype = svc.method[methodName]
+ if mtype == nil {
+ err = errors.New("rpc: can't find method " + req.ServiceMethod)
+ }
+ return
+}
+
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection. Accept blocks until the listener
+// returns a non-nil error. The caller typically invokes Accept in a
+// go statement.
+func (server *Server) Accept(lis net.Listener) {
+ for {
+ conn, err := lis.Accept()
+ if err != nil {
+ log.Print("rpc.Serve: accept:", err.Error())
+ return
+ }
+ go server.ServeConn(conn)
+ }
+}
+
+// Register publishes the receiver's methods in the DefaultServer.
+func Register(rcvr any) error { return DefaultServer.Register(rcvr) }
+
+// RegisterName is like Register but uses the provided name for the type
+// instead of the receiver's concrete type.
+func RegisterName(name string, rcvr any) error {
+ return DefaultServer.RegisterName(name, rcvr)
+}
+
+// A ServerCodec implements reading of RPC requests and writing of
+// RPC responses for the server side of an RPC session.
+// The server calls ReadRequestHeader and ReadRequestBody in pairs
+// to read requests from the connection, and it calls WriteResponse to
+// write a response back. The server calls Close when finished with the
+// connection. ReadRequestBody may be called with a nil
+// argument to force the body of the request to be read and discarded.
+// See NewClient's comment for information about concurrent access.
+type ServerCodec interface {
+ ReadRequestHeader(*Request) error
+ ReadRequestBody(any) error
+ WriteResponse(*Response, any) error
+
+ // Close can be called multiple times and must be idempotent.
+ Close() error
+}
+
+// ServeConn runs the DefaultServer on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection. To use an alternate codec, use ServeCodec.
+// See NewClient's comment for information about concurrent access.
+func ServeConn(conn io.ReadWriteCloser) {
+ DefaultServer.ServeConn(conn)
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func ServeCodec(codec ServerCodec) {
+ DefaultServer.ServeCodec(codec)
+}
+
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func ServeRequest(codec ServerCodec) error {
+ return DefaultServer.ServeRequest(codec)
+}
+
+// Accept accepts connections on the listener and serves requests
+// to DefaultServer for each incoming connection.
+// Accept blocks; the caller typically invokes it in a go statement.
+func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
+
+ // Clients can connect to the RPC service using HTTP CONNECT to rpcPath.
+var connected = "200 Connected to Go RPC"
+
+// ServeHTTP implements an http.Handler that answers RPC requests.
+func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ if req.Method != "CONNECT" {
+ w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ io.WriteString(w, "405 must CONNECT\n")
+ return
+ }
+ conn, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
+ return
+ }
+ io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
+ server.ServeConn(conn)
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
+// and a debugging handler on debugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func (server *Server) HandleHTTP(rpcPath, debugPath string) {
+ http.Handle(rpcPath, server)
+ http.Handle(debugPath, debugHTTP{server})
+}
+
+// HandleHTTP registers an HTTP handler for RPC messages to DefaultServer
+// on DefaultRPCPath and a debugging handler on DefaultDebugPath.
+// It is still necessary to invoke http.Serve(), typically in a go statement.
+func HandleHTTP() {
+ DefaultServer.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
+}
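
For orientation, the server pieces above (Register, HandleHTTP, Accept and ServeConn) are typically wired together along the following lines. This is a minimal sketch; the Arith type, its methods and the :1234 address are illustrative only.

package main

import (
	"errors"
	"log"
	"net"
	"net/http"
	"net/rpc"
)

type Args struct{ A, B int }

type Arith int

// Multiply has the method shape the server accepts: exported, two
// arguments (the second a pointer), and an error result.
func (t *Arith) Multiply(args *Args, reply *int) error {
	*reply = args.A * args.B
	return nil
}

func (t *Arith) Divide(args *Args, reply *int) error {
	if args.B == 0 {
		return errors.New("divide by zero")
	}
	*reply = args.A / args.B
	return nil
}

func main() {
	if err := rpc.Register(new(Arith)); err != nil {
		log.Fatal("register:", err)
	}
	rpc.HandleHTTP() // mounts the RPC and debug handlers on http.DefaultServeMux
	l, err := net.Listen("tcp", ":1234")
	if err != nil {
		log.Fatal("listen:", err)
	}
	log.Fatal(http.Serve(l, nil))
}
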
diff --git a/src/net/rpc/server_test.go b/src/net/rpc/server_test.go
new file mode 100644
index 0000000..6a94d6e
--- /dev/null
+++ b/src/net/rpc/server_test.go
@@ -0,0 +1,839 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rpc
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http/httptest"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+var (
+ newServer *Server
+ serverAddr, newServerAddr string
+ httpServerAddr string
+ once, newOnce, httpOnce sync.Once
+)
+
+const (
+ newHttpPath = "/foo"
+)
+
+type Args struct {
+ A, B int
+}
+
+type Reply struct {
+ C int
+}
+
+type Arith int
+
+// Some of Arith's methods have value args, some have pointer args. That's deliberate.
+
+func (t *Arith) Add(args Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+func (t *Arith) Mul(args *Args, reply *Reply) error {
+ reply.C = args.A * args.B
+ return nil
+}
+
+func (t *Arith) Div(args Args, reply *Reply) error {
+ if args.B == 0 {
+ return errors.New("divide by zero")
+ }
+ reply.C = args.A / args.B
+ return nil
+}
+
+func (t *Arith) String(args *Args, reply *string) error {
+ *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ return nil
+}
+
+func (t *Arith) Scan(args string, reply *Reply) (err error) {
+ _, err = fmt.Sscan(args, &reply.C)
+ return
+}
+
+func (t *Arith) Error(args *Args, reply *Reply) error {
+ panic("ERROR")
+}
+
+func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
+ time.Sleep(time.Duration(args.A) * time.Millisecond)
+ return nil
+}
+
+type hidden int
+
+func (t *hidden) Exported(args Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+type Embed struct {
+ hidden
+}
+
+type BuiltinTypes struct{}
+
+func (BuiltinTypes) Map(args *Args, reply *map[int]int) error {
+ (*reply)[args.A] = args.B
+ return nil
+}
+
+func (BuiltinTypes) Slice(args *Args, reply *[]int) error {
+ *reply = append(*reply, args.A, args.B)
+ return nil
+}
+
+func (BuiltinTypes) Array(args *Args, reply *[2]int) error {
+ (*reply)[0] = args.A
+ (*reply)[1] = args.B
+ return nil
+}
+
+func listenTCP() (net.Listener, string) {
+ l, err := net.Listen("tcp", "127.0.0.1:0") // any available address
+ if err != nil {
+ log.Fatalf("net.Listen tcp :0: %v", err)
+ }
+ return l, l.Addr().String()
+}
+
+func startServer() {
+ Register(new(Arith))
+ Register(new(Embed))
+ RegisterName("net.rpc.Arith", new(Arith))
+ Register(BuiltinTypes{})
+
+ var l net.Listener
+ l, serverAddr = listenTCP()
+ log.Println("Test RPC server listening on", serverAddr)
+ go Accept(l)
+
+ HandleHTTP()
+ httpOnce.Do(startHttpServer)
+}
+
+func startNewServer() {
+ newServer = NewServer()
+ newServer.Register(new(Arith))
+ newServer.Register(new(Embed))
+ newServer.RegisterName("net.rpc.Arith", new(Arith))
+ newServer.RegisterName("newServer.Arith", new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ log.Println("NewServer test RPC server listening on", newServerAddr)
+ go newServer.Accept(l)
+
+ newServer.HandleHTTP(newHttpPath, "/bar")
+ httpOnce.Do(startHttpServer)
+}
+
+func startHttpServer() {
+ server := httptest.NewServer(nil)
+ httpServerAddr = server.Listener.Addr().String()
+ log.Println("Test HTTP RPC server listening on", httpServerAddr)
+}
+
+func TestRPC(t *testing.T) {
+ once.Do(startServer)
+ testRPC(t, serverAddr)
+ newOnce.Do(startNewServer)
+ testRPC(t, newServerAddr)
+ testNewServerRPC(t, newServerAddr)
+}
+
+func testRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+
+ // Methods exported from unexported embedded structs
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Embed.Exported", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+
+ // Nonexistent method
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.BadOperation", args, reply)
+ // expect an error
+ if err == nil {
+ t.Error("BadOperation: expected error")
+ } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
+ t.Errorf("BadOperation: expected can't find method error; got %q", err)
+ }
+
+ // Unknown service
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Unknown", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if !strings.Contains(err.Error(), "method") {
+ t.Error("expected error about method; got", err)
+ }
+
+ // Out of order.
+ args = &Args{7, 8}
+ mulReply := new(Reply)
+ mulCall := client.Go("Arith.Mul", args, mulReply, nil)
+ addReply := new(Reply)
+ addCall := client.Go("Arith.Add", args, addReply, nil)
+
+ addCall = <-addCall.Done
+ if addCall.Error != nil {
+ t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
+ }
+ if addReply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, addReply.C)
+ }
+
+ mulCall = <-mulCall.Done
+ if mulCall.Error != nil {
+ t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
+ }
+ if mulReply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", args.A*args.B, mulReply.C)
+ }
+
+ // Error test
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Arith.Div", args, reply)
+ // expect an error: zero divide
+ if err == nil {
+ t.Error("Div: expected error")
+ } else if err.Error() != "divide by zero" {
+ t.Error("Div: expected divide by zero error; got", err)
+ }
+
+ // Bad type.
+ reply = new(Reply)
+ err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+ if err == nil {
+ t.Error("expected error calling Arith.Add with wrong arg type")
+ } else if !strings.Contains(err.Error(), "type") {
+ t.Error("expected error about type; got", err)
+ }
+
+ // Non-struct argument
+ const Val = 12345
+ str := fmt.Sprint(Val)
+ reply = new(Reply)
+ err = client.Call("Arith.Scan", &str, reply)
+ if err != nil {
+ t.Errorf("Scan: expected no error but got string %q", err.Error())
+ } else if reply.C != Val {
+ t.Errorf("Scan: expected %d got %d", Val, reply.C)
+ }
+
+ // Non-struct reply
+ args = &Args{27, 35}
+ str = ""
+ err = client.Call("Arith.String", args, &str)
+ if err != nil {
+ t.Errorf("String: expected no error but got string %q", err.Error())
+ }
+ expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
+ if str != expect {
+ t.Errorf("String: expected %s got %s", expect, str)
+ }
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", args.A*args.B, reply.C)
+ }
+
+ // ServiceName contains a "." character.
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("net.rpc.Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+}
+
+func testNewServerRPC(t *testing.T, addr string) {
+ client, err := Dial("tcp", addr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("newServer.Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+}
+
+func TestHTTP(t *testing.T) {
+ once.Do(startServer)
+ testHTTPRPC(t, "")
+ newOnce.Do(startNewServer)
+ testHTTPRPC(t, newHttpPath)
+}
+
+func testHTTPRPC(t *testing.T, path string) {
+ var client *Client
+ var err error
+ if path == "" {
+ client, err = DialHTTP("tcp", httpServerAddr)
+ } else {
+ client, err = DialHTTPPath("tcp", httpServerAddr, path)
+ }
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+}
+
+func TestBuiltinTypes(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := DialHTTP("tcp", httpServerAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ defer client.Close()
+
+ // Map
+ args := &Args{7, 8}
+ replyMap := map[int]int{}
+ err = client.Call("BuiltinTypes.Map", args, &replyMap)
+ if err != nil {
+ t.Errorf("Map: expected no error but got string %q", err.Error())
+ }
+ if replyMap[args.A] != args.B {
+ t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A])
+ }
+
+ // Slice
+ args = &Args{7, 8}
+ replySlice := []int{}
+ err = client.Call("BuiltinTypes.Slice", args, &replySlice)
+ if err != nil {
+ t.Errorf("Slice: expected no error but got string %q", err.Error())
+ }
+ if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) {
+ t.Errorf("Slice: expected %v got %v", e, replySlice)
+ }
+
+ // Array
+ args = &Args{7, 8}
+ replyArray := [2]int{}
+ err = client.Call("BuiltinTypes.Array", args, &replyArray)
+ if err != nil {
+ t.Errorf("Array: expected no error but got string %q", err.Error())
+ }
+ if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) {
+ t.Errorf("Array: expected %v got %v", e, replyArray)
+ }
+}
+
+ // CodecEmulator provides a client-like API and a ServerCodec interface.
+ // It can be used to test ServeRequest.
+type CodecEmulator struct {
+ server *Server
+ serviceMethod string
+ args *Args
+ reply *Reply
+ err error
+}
+
+func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
+ codec.serviceMethod = serviceMethod
+ codec.args = args
+ codec.reply = reply
+ codec.err = nil
+ var serverError error
+ if codec.server == nil {
+ serverError = ServeRequest(codec)
+ } else {
+ serverError = codec.server.ServeRequest(codec)
+ }
+ if codec.err == nil && serverError != nil {
+ codec.err = serverError
+ }
+ return codec.err
+}
+
+func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
+ req.ServiceMethod = codec.serviceMethod
+ req.Seq = 0
+ return nil
+}
+
+func (codec *CodecEmulator) ReadRequestBody(argv any) error {
+ if codec.args == nil {
+ return io.ErrUnexpectedEOF
+ }
+ *(argv.(*Args)) = *codec.args
+ return nil
+}
+
+func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error {
+ if resp.Error != "" {
+ codec.err = errors.New(resp.Error)
+ } else {
+ *codec.reply = *(reply.(*Reply))
+ }
+ return nil
+}
+
+func (codec *CodecEmulator) Close() error {
+ return nil
+}
+
+func TestServeRequest(t *testing.T) {
+ once.Do(startServer)
+ testServeRequest(t, nil)
+ newOnce.Do(startNewServer)
+ testServeRequest(t, newServer)
+}
+
+func testServeRequest(t *testing.T, server *Server) {
+ client := CodecEmulator{server: server}
+ defer client.Close()
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+
+ err = client.Call("Arith.Add", nil, reply)
+ if err == nil {
+ t.Errorf("expected error calling Arith.Add with nil arg")
+ }
+}
+
+type ReplyNotPointer int
+type ArgNotPublic int
+type ReplyNotPublic int
+type NeedsPtrType int
+type local struct{}
+
+func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
+ return nil
+}
+
+func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
+ return nil
+}
+
+func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
+ return nil
+}
+
+func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
+ return nil
+}
+
+// Check that registration handles lots of bad methods and a type with no suitable methods.
+func TestRegistrationError(t *testing.T) {
+ err := Register(new(ReplyNotPointer))
+ if err == nil {
+ t.Error("expected error registering ReplyNotPointer")
+ }
+ err = Register(new(ArgNotPublic))
+ if err == nil {
+ t.Error("expected error registering ArgNotPublic")
+ }
+ err = Register(new(ReplyNotPublic))
+ if err == nil {
+ t.Error("expected error registering ReplyNotPublic")
+ }
+ err = Register(NeedsPtrType(0))
+ if err == nil {
+ t.Error("expected error registering NeedsPtrType")
+ } else if !strings.Contains(err.Error(), "pointer") {
+ t.Error("expected hint when registering NeedsPtrType")
+ }
+}
+
+type WriteFailCodec int
+
+func (WriteFailCodec) WriteRequest(*Request, any) error {
+ // The panic caused by this error used to leave a lock held.
+ return errors.New("fail")
+}
+
+func (WriteFailCodec) ReadResponseHeader(*Response) error {
+ select {}
+}
+
+func (WriteFailCodec) ReadResponseBody(any) error {
+ select {}
+}
+
+func (WriteFailCodec) Close() error {
+ return nil
+}
+
+func TestSendDeadlock(t *testing.T) {
+ client := NewClientWithCodec(WriteFailCodec(0))
+ defer client.Close()
+
+ done := make(chan bool)
+ go func() {
+ testSendDeadlock(client)
+ testSendDeadlock(client)
+ done <- true
+ }()
+ select {
+ case <-done:
+ return
+ case <-time.After(5 * time.Second):
+ t.Fatal("deadlock")
+ }
+}
+
+func testSendDeadlock(client *Client) {
+ defer func() {
+ recover()
+ }()
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client.Call("Arith.Add", args, reply)
+}
+
+func dialDirect() (*Client, error) {
+ return Dial("tcp", serverAddr)
+}
+
+func dialHTTP() (*Client, error) {
+ return DialHTTP("tcp", httpServerAddr)
+}
+
+func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ t.Fatal("error dialing", err)
+ }
+ defer client.Close()
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ return testing.AllocsPerRun(100, func() {
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+ })
+}
+
+func TestCountMallocs(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping malloc count in short mode")
+ }
+ if runtime.GOMAXPROCS(0) > 1 {
+ t.Skip("skipping; GOMAXPROCS>1")
+ }
+ fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
+}
+
+func TestCountMallocsOverHTTP(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping malloc count in short mode")
+ }
+ if runtime.GOMAXPROCS(0) > 1 {
+ t.Skip("skipping; GOMAXPROCS>1")
+ }
+ fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
+}
+
+type writeCrasher struct {
+ done chan bool
+}
+
+func (writeCrasher) Close() error {
+ return nil
+}
+
+func (w *writeCrasher) Read(p []byte) (int, error) {
+ <-w.done
+ return 0, io.EOF
+}
+
+func (writeCrasher) Write(p []byte) (int, error) {
+ return 0, errors.New("fake write failure")
+}
+
+func TestClientWriteError(t *testing.T) {
+ w := &writeCrasher{done: make(chan bool)}
+ c := NewClient(w)
+ defer c.Close()
+
+ res := false
+ err := c.Call("foo", 1, &res)
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if err.Error() != "fake write failure" {
+ t.Error("unexpected value of error:", err)
+ }
+ w.done <- true
+}
+
+func TestTCPClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ defer client.Close()
+
+ args := Args{17, 8}
+ var reply Reply
+ err = client.Call("Arith.Mul", args, &reply)
+ if err != nil {
+ t.Fatal("arith error:", err)
+ }
+ t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply.C)
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", args.A*args.B, reply.C)
+ }
+}
+
+func TestErrorAfterClientClose(t *testing.T) {
+ once.Do(startServer)
+
+ client, err := dialHTTP()
+ if err != nil {
+ t.Fatalf("dialing: %v", err)
+ }
+ err = client.Close()
+ if err != nil {
+ t.Fatal("close error:", err)
+ }
+ err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
+ if err != ErrShutdown {
+ t.Errorf("Call after Close: expected ErrShutdown got %v", err)
+ }
+}
+
+// Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
+func TestAcceptExitAfterListenerClose(t *testing.T) {
+ newServer := NewServer()
+ newServer.Register(new(Arith))
+ newServer.RegisterName("net.rpc.Arith", new(Arith))
+ newServer.RegisterName("newServer.Arith", new(Arith))
+
+ var l net.Listener
+ l, _ = listenTCP()
+ l.Close()
+ newServer.Accept(l)
+}
+
+func TestShutdown(t *testing.T) {
+ var l net.Listener
+ l, _ = listenTCP()
+ ch := make(chan net.Conn, 1)
+ go func() {
+ defer l.Close()
+ c, err := l.Accept()
+ if err != nil {
+ t.Error(err)
+ }
+ ch <- c
+ }()
+ c, err := net.Dial("tcp", l.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1 := <-ch
+ if c1 == nil {
+ t.Fatal(err)
+ }
+
+ newServer := NewServer()
+ newServer.Register(new(Arith))
+ go newServer.ServeConn(c1)
+
+ args := &Args{7, 8}
+ reply := new(Reply)
+ client := NewClient(c)
+ err = client.Call("Arith.Add", args, reply)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // On an unloaded system 10ms is usually enough to fail 100% of the time
+ // with a broken server. On a loaded system, a broken server might incorrectly
+ // be reported as passing, but we're OK with that kind of flakiness.
+ // If the code is correct, this test will never fail, regardless of timeout.
+ args.A = 10 // 10 ms
+ done := make(chan *Call, 1)
+ call := client.Go("Arith.SleepMilli", args, reply, done)
+ c.(*net.TCPConn).CloseWrite()
+ <-done
+ if call.Error != nil {
+ t.Fatal(call.Error)
+ }
+}
+
+func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ b.Fatal("error dialing:", err)
+ }
+ defer client.Close()
+
+ // Synchronous calls
+ args := &Args{7, 8}
+ b.ResetTimer()
+
+ b.RunParallel(func(pb *testing.PB) {
+ reply := new(Reply)
+ for pb.Next() {
+ err := client.Call("Arith.Add", args, reply)
+ if err != nil {
+ b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ b.Fatalf("rpc error: Add: expected %d got %d", args.A+args.B, reply.C)
+ }
+ }
+ })
+}
+
+func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
+ if b.N == 0 {
+ return
+ }
+ const MaxConcurrentCalls = 100
+ once.Do(startServer)
+ client, err := dial()
+ if err != nil {
+ b.Fatal("error dialing:", err)
+ }
+ defer client.Close()
+
+ // Asynchronous calls
+ args := &Args{7, 8}
+ procs := 4 * runtime.GOMAXPROCS(-1)
+ send := int32(b.N)
+ recv := int32(b.N)
+ var wg sync.WaitGroup
+ wg.Add(procs)
+ gate := make(chan bool, MaxConcurrentCalls)
+ res := make(chan *Call, MaxConcurrentCalls)
+ b.ResetTimer()
+
+ for p := 0; p < procs; p++ {
+ go func() {
+ for atomic.AddInt32(&send, -1) >= 0 {
+ gate <- true
+ reply := new(Reply)
+ client.Go("Arith.Add", args, reply, res)
+ }
+ }()
+ go func() {
+ for call := range res {
+ A := call.Args.(*Args).A
+ B := call.Args.(*Args).B
+ C := call.Reply.(*Reply).C
+ if A+B != C {
+ b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
+ return
+ }
+ <-gate
+ if atomic.AddInt32(&recv, -1) == 0 {
+ close(res)
+ }
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkEndToEnd(b *testing.B) {
+ benchmarkEndToEnd(dialDirect, b)
+}
+
+func BenchmarkEndToEndHTTP(b *testing.B) {
+ benchmarkEndToEnd(dialHTTP, b)
+}
+
+func BenchmarkEndToEndAsync(b *testing.B) {
+ benchmarkEndToEndAsync(dialDirect, b)
+}
+
+func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
+ benchmarkEndToEndAsync(dialHTTP, b)
+}
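
The out-of-order section of testRPC above leans on the client's Go/Done machinery. In application code the pattern looks roughly like this sketch; the address is a placeholder, and the Args/Reply types mirror the test's, assuming an Arith service like the one registered above.

package main

import (
	"fmt"
	"log"
	"net/rpc"
)

type Args struct{ A, B int }
type Reply struct{ C int }

func main() {
	client, err := rpc.Dial("tcp", "127.0.0.1:1234")
	if err != nil {
		log.Fatal("dialing:", err)
	}
	defer client.Close()

	args := &Args{7, 8}

	mulReply := new(Reply)
	// Passing nil lets the client allocate a buffered Done channel.
	mulCall := client.Go("Arith.Mul", args, mulReply, nil)

	addReply := new(Reply)
	addCall := client.Go("Arith.Add", args, addReply, nil)

	// The calls complete independently; wait on each Done channel.
	addCall = <-addCall.Done
	if addCall.Error != nil {
		log.Fatal(addCall.Error)
	}
	mulCall = <-mulCall.Done
	if mulCall.Error != nil {
		log.Fatal(mulCall.Error)
	}
	fmt.Println(addReply.C, mulReply.C) // 15 56
}
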
diff --git a/src/net/sendfile_linux.go b/src/net/sendfile_linux.go
new file mode 100644
index 0000000..9a7d005
--- /dev/null
+++ b/src/net/sendfile_linux.go
@@ -0,0 +1,53 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+ "os"
+)
+
+// sendFile copies the contents of r to c using the sendfile
+// system call to minimize copies.
+//
+// if handled == true, sendFile returns the number (potentially zero) of bytes
+// copied and any non-EOF error.
+//
+// if handled == false, sendFile performed no work.
+func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ var remain int64 = 1<<63 - 1 // by default, copy until EOF
+
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ remain, r = lr.N, lr.R
+ if remain <= 0 {
+ return 0, nil, true
+ }
+ }
+ f, ok := r.(*os.File)
+ if !ok {
+ return 0, nil, false
+ }
+
+ sc, err := f.SyscallConn()
+ if err != nil {
+ return 0, nil, false
+ }
+
+ var werr error
+ err = sc.Read(func(fd uintptr) bool {
+ written, werr, handled = poll.SendFile(&c.pfd, int(fd), remain)
+ return true
+ })
+ if err == nil {
+ err = werr
+ }
+
+ if lr != nil {
+ lr.N = remain - written
+ }
+ return written, wrapSyscallError("sendfile", err), handled
+}
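
User code reaches this path indirectly: io.Copy notices that *TCPConn implements io.ReaderFrom, and ReadFrom may in turn hand an *os.File source to sendFile. A rough sketch of a handler that can benefit; the listen address, file path and error handling are illustrative only.

package main

import (
	"io"
	"log"
	"net"
	"os"
)

// serveFile streams the file at path to conn. When conn is a *net.TCPConn
// and the source is an *os.File, the copy may be satisfied by sendfile(2)
// without moving the data through user space.
func serveFile(conn net.Conn, path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(conn, f)
	return err
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	for {
		conn, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		go func(c net.Conn) {
			defer c.Close()
			if err := serveFile(c, "/tmp/payload.bin"); err != nil {
				log.Print(err)
			}
		}(conn)
	}
}
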
diff --git a/src/net/sendfile_linux_test.go b/src/net/sendfile_linux_test.go
new file mode 100644
index 0000000..8cd6acc
--- /dev/null
+++ b/src/net/sendfile_linux_test.go
@@ -0,0 +1,77 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+// +build linux
+
+package net
+
+import (
+ "io"
+ "os"
+ "strconv"
+ "testing"
+)
+
+func BenchmarkSendFile(b *testing.B) {
+ for i := 0; i <= 10; i++ {
+ size := 1 << (i + 10)
+ bench := sendFileBench{chunkSize: size}
+ b.Run(strconv.Itoa(size), bench.benchSendFile)
+ }
+}
+
+type sendFileBench struct {
+ chunkSize int
+}
+
+func (bench sendFileBench) benchSendFile(b *testing.B) {
+ fileSize := b.N * bench.chunkSize
+ f := createTempFile(b, fileSize)
+ fileName := f.Name()
+ defer os.Remove(fileName)
+ defer f.Close()
+
+ client, server := spliceTestSocketPair(b, "tcp")
+ defer server.Close()
+
+ cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cleanUp()
+
+ b.ReportAllocs()
+ b.SetBytes(int64(bench.chunkSize))
+ b.ResetTimer()
+
+ // Data goes from the file to the socket via sendfile(2).
+ sent, err := io.Copy(server, f)
+ if err != nil {
+ b.Fatalf("failed to copy data with sendfile, error: %v", err)
+ }
+ if sent != int64(fileSize) {
+ b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
+ }
+}
+
+func createTempFile(b *testing.B, size int) *os.File {
+ f, err := os.CreateTemp("", "linux-sendfile-test")
+ if err != nil {
+ b.Fatalf("failed to create temporary file: %v", err)
+ }
+
+ data := make([]byte, size)
+ if _, err := f.Write(data); err != nil {
+ b.Fatalf("failed to create and feed the file: %v", err)
+ }
+ if err := f.Sync(); err != nil {
+ b.Fatalf("failed to save the file: %v", err)
+ }
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ b.Fatalf("failed to rewind the file: %v", err)
+ }
+
+ return f
+}
diff --git a/src/net/sendfile_stub.go b/src/net/sendfile_stub.go
new file mode 100644
index 0000000..c7a2e6a
--- /dev/null
+++ b/src/net/sendfile_stub.go
@@ -0,0 +1,13 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || (js && wasm) || netbsd || openbsd || ios || wasip1
+
+package net
+
+import "io"
+
+func sendFile(c *netFD, r io.Reader) (n int64, err error, handled bool) {
+ return 0, nil, false
+}
diff --git a/src/net/sendfile_test.go b/src/net/sendfile_test.go
new file mode 100644
index 0000000..44a87a1
--- /dev/null
+++ b/src/net/sendfile_test.go
@@ -0,0 +1,364 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+)
+
+const (
+ newton = "../testdata/Isaac.Newton-Opticks.txt"
+ newtonLen = 567198
+ newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
+)
+
+func TestSendfile(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ errc := make(chan error, 1)
+ go func(ln Listener) {
+ // Wait for a connection.
+ conn, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ close(errc)
+ return
+ }
+
+ go func() {
+ defer close(errc)
+ defer conn.Close()
+
+ f, err := os.Open(newton)
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer f.Close()
+
+ // Return file data using io.Copy, which should use
+ // sendFile if available.
+ sbytes, err := io.Copy(conn, f)
+ if err != nil {
+ errc <- err
+ return
+ }
+
+ if sbytes != newtonLen {
+ errc <- fmt.Errorf("sent %d bytes; expected %d", sbytes, newtonLen)
+ return
+ }
+ }()
+ }(ln)
+
+ // Connect to listener to retrieve file and verify digest matches
+ // expected.
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ h := sha256.New()
+ rbytes, err := io.Copy(h, c)
+ if err != nil {
+ t.Error(err)
+ }
+
+ if rbytes != newtonLen {
+ t.Errorf("received %d bytes; expected %d", rbytes, newtonLen)
+ }
+
+ if res := hex.EncodeToString(h.Sum(nil)); res != newtonSHA256 {
+ t.Error("retrieved data hash did not match")
+ }
+
+ for err := range errc {
+ t.Error(err)
+ }
+}
+
+func TestSendfileParts(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ errc := make(chan error, 1)
+ go func(ln Listener) {
+ // Wait for a connection.
+ conn, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ close(errc)
+ return
+ }
+
+ go func() {
+ defer close(errc)
+ defer conn.Close()
+
+ f, err := os.Open(newton)
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer f.Close()
+
+ for i := 0; i < 3; i++ {
+ // Return file data using io.CopyN, which should use
+ // sendFile if available.
+ _, err = io.CopyN(conn, f, 3)
+ if err != nil {
+ errc <- err
+ return
+ }
+ }
+ }()
+ }(ln)
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(c)
+
+ if want, have := "Produced ", buf.String(); have != want {
+ t.Errorf("unexpected server reply %q, want %q", have, want)
+ }
+
+ for err := range errc {
+ t.Error(err)
+ }
+}
+
+func TestSendfileSeeked(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ const seekTo = 65 << 10
+ const sendSize = 10 << 10
+
+ errc := make(chan error, 1)
+ go func(ln Listener) {
+ // Wait for a connection.
+ conn, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ close(errc)
+ return
+ }
+
+ go func() {
+ defer close(errc)
+ defer conn.Close()
+
+ f, err := os.Open(newton)
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer f.Close()
+ if _, err := f.Seek(seekTo, io.SeekStart); err != nil {
+ errc <- err
+ return
+ }
+
+ _, err = io.CopyN(conn, f, sendSize)
+ if err != nil {
+ errc <- err
+ return
+ }
+ }()
+ }(ln)
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(c)
+
+ if buf.Len() != sendSize {
+ t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
+ }
+
+ for err := range errc {
+ t.Error(err)
+ }
+}
+
+// Test that sendfile doesn't put a pipe into blocking mode.
+func TestSendfilePipe(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ // These systems don't support deadlines on pipes.
+ t.Skipf("skipping on %s", runtime.GOOS)
+ }
+
+ t.Parallel()
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ r, w, err := os.Pipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer w.Close()
+ defer r.Close()
+
+ copied := make(chan bool)
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ // Accept a connection and copy 1 byte from the read end of
+ // the pipe to the connection. This will call into sendfile.
+ defer wg.Done()
+ conn, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ _, err = io.CopyN(conn, r, 1)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ // Signal the main goroutine that we've copied the byte.
+ close(copied)
+ }()
+
+ wg.Add(1)
+ go func() {
+ // Write 1 byte to the write end of the pipe.
+ defer wg.Done()
+ _, err := w.Write([]byte{'a'})
+ if err != nil {
+ t.Error(err)
+ }
+ }()
+
+ wg.Add(1)
+ go func() {
+ // Connect to the server started two goroutines up and
+ // discard any data that it writes.
+ defer wg.Done()
+ conn, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ io.Copy(io.Discard, conn)
+ }()
+
+ // Wait for the byte to be copied, meaning that sendfile has
+ // been called on the pipe.
+ <-copied
+
+ // Set a very short deadline on the read end of the pipe.
+ if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Add(1)
+ go func() {
+ // Wait for much longer than the deadline and write a byte
+ // to the pipe.
+ defer wg.Done()
+ time.Sleep(50 * time.Millisecond)
+ w.Write([]byte{'b'})
+ }()
+
+ // If this read does not time out, the pipe was incorrectly
+ // put into blocking mode.
+ _, err = r.Read(make([]byte, 1))
+ if err == nil {
+ t.Error("Read did not time out")
+ } else if !os.IsTimeout(err) {
+ t.Errorf("got error %v, expected a time out", err)
+ }
+
+ wg.Wait()
+}
+
+ // Issue 43822: tests that the receiver sees EOF (no data) when the sender's write deadline has already been exceeded.
+func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ errc := make(chan error, 1)
+ go func(ln Listener) (retErr error) {
+ defer func() {
+ errc <- retErr
+ close(errc)
+ }()
+
+ conn, err := ln.Accept()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ // Set the write deadline in the past (1h ago) so that
+ // every write on this connection times out immediately.
+ if err := conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)); err != nil {
+ return err
+ }
+
+ f, err := os.Open(newton)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ _, err = io.Copy(conn, f)
+ if errors.Is(err, os.ErrDeadlineExceeded) {
+ return nil
+ }
+
+ if err == nil {
+ err = fmt.Errorf("expected ErrDeadlineExceeded, but got nil")
+ }
+ return err
+ }(ln)
+
+ conn, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ n, err := io.Copy(io.Discard, conn)
+ if err != nil {
+ t.Fatalf("expected nil error, but got %v", err)
+ }
+ if n != 0 {
+ t.Fatalf("expected receive zero, but got %d byte(s)", n)
+ }
+
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+}
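
TestSendfileParts and TestSendfileSeeked above exercise the *io.LimitedReader branch of sendFile. User code typically reaches that branch through io.CopyN, roughly as in this sketch; the package name and offsets are arbitrary.

package sendfiledemo // hypothetical helper, for illustration only

import (
	"io"
	"net"
	"os"
)

// SendSlice writes exactly n bytes of f, starting at offset, to conn.
// io.CopyN wraps f in an *io.LimitedReader, which the sendfile fast path
// unwraps to learn how many bytes it may transfer; afterwards the file
// offset sits just past the bytes that were sent.
func SendSlice(conn net.Conn, f *os.File, offset, n int64) error {
	if _, err := f.Seek(offset, io.SeekStart); err != nil {
		return err
	}
	_, err := io.CopyN(conn, f, n)
	return err
}
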
diff --git a/src/net/sendfile_unix_alt.go b/src/net/sendfile_unix_alt.go
new file mode 100644
index 0000000..b867717
--- /dev/null
+++ b/src/net/sendfile_unix_alt.go
@@ -0,0 +1,85 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (darwin && !ios) || dragonfly || freebsd || solaris
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+ "os"
+)
+
+// sendFile copies the contents of r to c using the sendfile
+// system call to minimize copies.
+//
+// if handled == true, sendFile returns the number of bytes copied and any
+// non-EOF error.
+//
+// if handled == false, sendFile performed no work.
+func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ // Darwin, FreeBSD, DragonFly and Solaris use 0 as the "until EOF" value.
+ // If you pass in more bytes than the file contains, it will
+ // loop back to the beginning ad nauseam until it's sent
+ // exactly the number of bytes told to. As such, we need to
+ // know exactly how many bytes to send.
+ var remain int64 = 0
+
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ remain, r = lr.N, lr.R
+ if remain <= 0 {
+ return 0, nil, true
+ }
+ }
+ f, ok := r.(*os.File)
+ if !ok {
+ return 0, nil, false
+ }
+
+ if remain == 0 {
+ fi, err := f.Stat()
+ if err != nil {
+ return 0, err, false
+ }
+
+ remain = fi.Size()
+ }
+
+ // The other quirk with Darwin/FreeBSD/DragonFly/Solaris's sendfile
+ // implementation is that it doesn't use the current position
+ // of the file -- if you pass it offset 0, it starts from
+ // offset 0. There's no way to tell it "start from current
+ // position", so we have to manage that explicitly.
+ pos, err := f.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return 0, err, false
+ }
+
+ sc, err := f.SyscallConn()
+ if err != nil {
+ return 0, nil, false
+ }
+
+ var werr error
+ err = sc.Read(func(fd uintptr) bool {
+ written, werr = poll.SendFile(&c.pfd, int(fd), pos, remain)
+ return true
+ })
+ if err == nil {
+ err = werr
+ }
+
+ if lr != nil {
+ lr.N = remain - written
+ }
+
+ _, err1 := f.Seek(written, io.SeekCurrent)
+ if err1 != nil && err == nil {
+ return written, err1, written > 0
+ }
+
+ return written, wrapSyscallError("sendfile", err), written > 0
+}
diff --git a/src/net/sendfile_windows.go b/src/net/sendfile_windows.go
new file mode 100644
index 0000000..59b1b0d
--- /dev/null
+++ b/src/net/sendfile_windows.go
@@ -0,0 +1,47 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+ "os"
+ "syscall"
+)
+
+// sendFile copies the contents of r to c using the TransmitFile
+// system call to minimize copies.
+//
+// if handled == true, sendFile returns the number of bytes copied and any
+// non-EOF error.
+//
+// if handled == false, sendFile performed no work.
+func sendFile(fd *netFD, r io.Reader) (written int64, err error, handled bool) {
+ var n int64 = 0 // by default, copy until EOF.
+
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ n, r = lr.N, lr.R
+ if n <= 0 {
+ return 0, nil, true
+ }
+ }
+
+ f, ok := r.(*os.File)
+ if !ok {
+ return 0, nil, false
+ }
+
+ written, err = poll.SendFile(&fd.pfd, syscall.Handle(f.Fd()), n)
+ if err != nil {
+ err = wrapSyscallError("transmitfile", err)
+ }
+
+ // If any byte was copied, regardless of any error
+ // encountered mid-way, handled must be set to true.
+ handled = written > 0
+
+ return
+}
diff --git a/src/net/server_test.go b/src/net/server_test.go
new file mode 100644
index 0000000..2ff0689
--- /dev/null
+++ b/src/net/server_test.go
@@ -0,0 +1,383 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "os"
+ "testing"
+)
+
+var tcpServerTests = []struct {
+ snet, saddr string // server endpoint
+ tnet, taddr string // target endpoint for client
+}{
+ {snet: "tcp", saddr: ":0", tnet: "tcp", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "0.0.0.0:0", tnet: "tcp", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::ffff:0.0.0.0]:0", tnet: "tcp", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::]:0", tnet: "tcp", taddr: "::1"},
+
+ {snet: "tcp", saddr: ":0", tnet: "tcp", taddr: "::1"},
+ {snet: "tcp", saddr: "0.0.0.0:0", tnet: "tcp", taddr: "::1"},
+ {snet: "tcp", saddr: "[::ffff:0.0.0.0]:0", tnet: "tcp", taddr: "::1"},
+ {snet: "tcp", saddr: "[::]:0", tnet: "tcp", taddr: "127.0.0.1"},
+
+ {snet: "tcp", saddr: ":0", tnet: "tcp4", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "0.0.0.0:0", tnet: "tcp4", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::ffff:0.0.0.0]:0", tnet: "tcp4", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::]:0", tnet: "tcp6", taddr: "::1"},
+
+ {snet: "tcp", saddr: ":0", tnet: "tcp6", taddr: "::1"},
+ {snet: "tcp", saddr: "0.0.0.0:0", tnet: "tcp6", taddr: "::1"},
+ {snet: "tcp", saddr: "[::ffff:0.0.0.0]:0", tnet: "tcp6", taddr: "::1"},
+ {snet: "tcp", saddr: "[::]:0", tnet: "tcp4", taddr: "127.0.0.1"},
+
+ {snet: "tcp", saddr: "127.0.0.1:0", tnet: "tcp", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::ffff:127.0.0.1]:0", tnet: "tcp", taddr: "127.0.0.1"},
+ {snet: "tcp", saddr: "[::1]:0", tnet: "tcp", taddr: "::1"},
+
+ {snet: "tcp4", saddr: ":0", tnet: "tcp4", taddr: "127.0.0.1"},
+ {snet: "tcp4", saddr: "0.0.0.0:0", tnet: "tcp4", taddr: "127.0.0.1"},
+ {snet: "tcp4", saddr: "[::ffff:0.0.0.0]:0", tnet: "tcp4", taddr: "127.0.0.1"},
+
+ {snet: "tcp4", saddr: "127.0.0.1:0", tnet: "tcp4", taddr: "127.0.0.1"},
+
+ {snet: "tcp6", saddr: ":0", tnet: "tcp6", taddr: "::1"},
+ {snet: "tcp6", saddr: "[::]:0", tnet: "tcp6", taddr: "::1"},
+
+ {snet: "tcp6", saddr: "[::1]:0", tnet: "tcp6", taddr: "::1"},
+}
+
+// TestTCPServer tests concurrent accept-read-write servers.
+func TestTCPServer(t *testing.T) {
+ const N = 3
+
+ for i, tt := range tcpServerTests {
+ t.Run(tt.snet+" "+tt.saddr+"<-"+tt.taddr, func(t *testing.T) {
+ if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
+ t.Skip("not testable")
+ }
+
+ ln, err := Listen(tt.snet, tt.saddr)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ var lss []*localServer
+ var tpchs []chan error
+ defer func() {
+ for _, ls := range lss {
+ ls.teardown()
+ }
+ }()
+ for i := 0; i < N; i++ {
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ lss = append(lss, ls)
+ tpchs = append(tpchs, make(chan error, 1))
+ }
+ for i := 0; i < N; i++ {
+ ch := tpchs[i]
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
+ if err := lss[i].buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ var trchs []chan error
+ for i := 0; i < N; i++ {
+ _, port, err := SplitHostPort(lss[i].Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ d := Dialer{Timeout: someTimeout}
+ c, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port))
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c.Close()
+ trchs = append(trchs, make(chan error, 1))
+ go transceiver(c, []byte("TCP SERVER TEST"), trchs[i])
+ }
+
+ for _, ch := range trchs {
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+ for _, ch := range tpchs {
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+ })
+ }
+}
+
+// TestUnixAndUnixpacketServer tests concurrent accept-read-write
+// servers
+func TestUnixAndUnixpacketServer(t *testing.T) {
+ var unixAndUnixpacketServerTests = []struct {
+ network, address string
+ }{
+ {"unix", testUnixAddr(t)},
+ {"unix", "@nettest/go/unix"},
+
+ {"unixpacket", testUnixAddr(t)},
+ {"unixpacket", "@nettest/go/unixpacket"},
+ }
+
+ const N = 3
+
+ for i, tt := range unixAndUnixpacketServerTests {
+ if !testableListenArgs(tt.network, tt.address, "") {
+ t.Logf("skipping %s test", tt.network+" "+tt.address)
+ continue
+ }
+
+ ln, err := Listen(tt.network, tt.address)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ var lss []*localServer
+ var tpchs []chan error
+ defer func() {
+ for _, ls := range lss {
+ ls.teardown()
+ }
+ }()
+ for i := 0; i < N; i++ {
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ lss = append(lss, ls)
+ tpchs = append(tpchs, make(chan error, 1))
+ }
+ for i := 0; i < N; i++ {
+ ch := tpchs[i]
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
+ if err := lss[i].buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ var trchs []chan error
+ for i := 0; i < N; i++ {
+ d := Dialer{Timeout: someTimeout}
+ c, err := d.Dial(lss[i].Listener.Addr().Network(), lss[i].Listener.Addr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ if addr := c.LocalAddr(); addr != nil {
+ t.Logf("connected %s->%s", addr, lss[i].Listener.Addr())
+ }
+
+ defer c.Close()
+ trchs = append(trchs, make(chan error, 1))
+ go transceiver(c, []byte("UNIX AND UNIXPACKET SERVER TEST"), trchs[i])
+ }
+
+ for _, ch := range trchs {
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+ for _, ch := range tpchs {
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+ }
+}
+
+var udpServerTests = []struct {
+ snet, saddr string // server endpoint
+ tnet, taddr string // target endpoint for client
+ dial bool // test with Dial
+}{
+ {snet: "udp", saddr: ":0", tnet: "udp", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "0.0.0.0:0", tnet: "udp", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::ffff:0.0.0.0]:0", tnet: "udp", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::]:0", tnet: "udp", taddr: "::1"},
+
+ {snet: "udp", saddr: ":0", tnet: "udp", taddr: "::1"},
+ {snet: "udp", saddr: "0.0.0.0:0", tnet: "udp", taddr: "::1"},
+ {snet: "udp", saddr: "[::ffff:0.0.0.0]:0", tnet: "udp", taddr: "::1"},
+ {snet: "udp", saddr: "[::]:0", tnet: "udp", taddr: "127.0.0.1"},
+
+ {snet: "udp", saddr: ":0", tnet: "udp4", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "0.0.0.0:0", tnet: "udp4", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::ffff:0.0.0.0]:0", tnet: "udp4", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::]:0", tnet: "udp6", taddr: "::1"},
+
+ {snet: "udp", saddr: ":0", tnet: "udp6", taddr: "::1"},
+ {snet: "udp", saddr: "0.0.0.0:0", tnet: "udp6", taddr: "::1"},
+ {snet: "udp", saddr: "[::ffff:0.0.0.0]:0", tnet: "udp6", taddr: "::1"},
+ {snet: "udp", saddr: "[::]:0", tnet: "udp4", taddr: "127.0.0.1"},
+
+ {snet: "udp", saddr: "127.0.0.1:0", tnet: "udp", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::ffff:127.0.0.1]:0", tnet: "udp", taddr: "127.0.0.1"},
+ {snet: "udp", saddr: "[::1]:0", tnet: "udp", taddr: "::1"},
+
+ {snet: "udp4", saddr: ":0", tnet: "udp4", taddr: "127.0.0.1"},
+ {snet: "udp4", saddr: "0.0.0.0:0", tnet: "udp4", taddr: "127.0.0.1"},
+ {snet: "udp4", saddr: "[::ffff:0.0.0.0]:0", tnet: "udp4", taddr: "127.0.0.1"},
+
+ {snet: "udp4", saddr: "127.0.0.1:0", tnet: "udp4", taddr: "127.0.0.1"},
+
+ {snet: "udp6", saddr: ":0", tnet: "udp6", taddr: "::1"},
+ {snet: "udp6", saddr: "[::]:0", tnet: "udp6", taddr: "::1"},
+
+ {snet: "udp6", saddr: "[::1]:0", tnet: "udp6", taddr: "::1"},
+
+ {snet: "udp", saddr: "127.0.0.1:0", tnet: "udp", taddr: "127.0.0.1", dial: true},
+
+ {snet: "udp", saddr: "[::1]:0", tnet: "udp", taddr: "::1", dial: true},
+}
+
+func TestUDPServer(t *testing.T) {
+ for i, tt := range udpServerTests {
+ if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
+ t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"<-"+tt.taddr)
+ continue
+ }
+
+ c1, err := ListenPacket(tt.snet, tt.saddr)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
+ defer ls.teardown()
+ tpch := make(chan error, 1)
+ handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ trch := make(chan error, 1)
+ _, port, err := SplitHostPort(ls.PacketConn.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if tt.dial {
+ d := Dialer{Timeout: someTimeout}
+ c2, err := d.Dial(tt.tnet, JoinHostPort(tt.taddr, port))
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ go transceiver(c2, []byte("UDP SERVER TEST"), trch)
+ } else {
+ c2, err := ListenPacket(tt.tnet, JoinHostPort(tt.taddr, "0"))
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ dst, err := ResolveUDPAddr(tt.tnet, JoinHostPort(tt.taddr, port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ go packetTransceiver(c2, []byte("UDP SERVER TEST"), dst, trch)
+ }
+
+ for err := range trch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ for err := range tpch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
+
+func TestUnixgramServer(t *testing.T) {
+ var unixgramServerTests = []struct {
+ saddr string // server endpoint
+ caddr string // client endpoint
+ dial bool // test with Dial
+ }{
+ {saddr: testUnixAddr(t), caddr: testUnixAddr(t)},
+ {saddr: testUnixAddr(t), caddr: testUnixAddr(t), dial: true},
+
+ {saddr: "@nettest/go/unixgram/server", caddr: "@nettest/go/unixgram/client"},
+ }
+
+ for i, tt := range unixgramServerTests {
+ if !testableListenArgs("unixgram", tt.saddr, "") {
+ t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr)
+ continue
+ }
+
+ c1, err := ListenPacket("unixgram", tt.saddr)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
+ defer ls.teardown()
+ tpch := make(chan error, 1)
+ handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ trch := make(chan error, 1)
+ if tt.dial {
+ d := Dialer{Timeout: someTimeout, LocalAddr: &UnixAddr{Net: "unixgram", Name: tt.caddr}}
+ c2, err := d.Dial("unixgram", ls.PacketConn.LocalAddr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+ go transceiver(c2, []byte(c2.LocalAddr().String()), trch)
+ } else {
+ c2, err := ListenPacket("unixgram", tt.caddr)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+ go packetTransceiver(c2, []byte("UNIXGRAM SERVER TEST"), ls.PacketConn.LocalAddr(), trch)
+ }
+
+ for err := range trch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ for err := range tpch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
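
The transponder and transceiver helpers used above are defined elsewhere in the test suite; the concurrent accept-read-write behavior they exercise boils down to an echo server along these lines. A simplified sketch, not the actual helpers.

package main

import (
	"io"
	"log"
	"net"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	log.Println("echo server on", ln.Addr())
	for {
		c, err := ln.Accept()
		if err != nil {
			log.Print(err)
			return
		}
		go func(c net.Conn) {
			defer c.Close()
			// Write back whatever the peer sends, until it closes.
			io.Copy(c, c)
		}(c)
	}
}
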
diff --git a/src/net/smtp/auth.go b/src/net/smtp/auth.go
new file mode 100644
index 0000000..72eb166
--- /dev/null
+++ b/src/net/smtp/auth.go
@@ -0,0 +1,109 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package smtp
+
+import (
+ "crypto/hmac"
+ "crypto/md5"
+ "errors"
+ "fmt"
+)
+
+// Auth is implemented by an SMTP authentication mechanism.
+type Auth interface {
+ // Start begins an authentication with a server.
+ // It returns the name of the authentication protocol
+ // and optionally data to include in the initial AUTH message
+ // sent to the server.
+ // If it returns a non-nil error, the SMTP client aborts
+ // the authentication attempt and closes the connection.
+ Start(server *ServerInfo) (proto string, toServer []byte, err error)
+
+ // Next continues the authentication. The server has just sent
+ // the fromServer data. If more is true, the server expects a
+ // response, which Next should return as toServer; otherwise
+ // Next should return toServer == nil.
+ // If Next returns a non-nil error, the SMTP client aborts
+ // the authentication attempt and closes the connection.
+ Next(fromServer []byte, more bool) (toServer []byte, err error)
+}
+
+// ServerInfo records information about an SMTP server.
+type ServerInfo struct {
+ Name string // SMTP server name
+ TLS bool // using TLS, with valid certificate for Name
+ Auth []string // advertised authentication mechanisms
+}
+
+type plainAuth struct {
+ identity, username, password string
+ host string
+}
+
+// PlainAuth returns an Auth that implements the PLAIN authentication
+// mechanism as defined in RFC 4616. The returned Auth uses the given
+// username and password to authenticate to host and act as identity.
+// Usually identity should be the empty string, to act as username.
+//
+// PlainAuth will only send the credentials if the connection is using TLS
+// or is connected to localhost. Otherwise authentication will fail with an
+// error, without sending the credentials.
+func PlainAuth(identity, username, password, host string) Auth {
+ return &plainAuth{identity, username, password, host}
+}
+
+func isLocalhost(name string) bool {
+ return name == "localhost" || name == "127.0.0.1" || name == "::1"
+}
+
+func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) {
+ // Must have TLS, or else localhost server.
+ // Note: If TLS is not true, then we can't trust ANYTHING in ServerInfo.
+ // In particular, it doesn't matter if the server advertises PLAIN auth.
+ // That might just be the attacker saying
+ // "it's ok, you can trust me with your password."
+ if !server.TLS && !isLocalhost(server.Name) {
+ return "", nil, errors.New("unencrypted connection")
+ }
+ if server.Name != a.host {
+ return "", nil, errors.New("wrong host name")
+ }
+ resp := []byte(a.identity + "\x00" + a.username + "\x00" + a.password)
+ return "PLAIN", resp, nil
+}
+
+func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ // We've already sent everything.
+ return nil, errors.New("unexpected server challenge")
+ }
+ return nil, nil
+}
+
+type cramMD5Auth struct {
+ username, secret string
+}
+
+// CRAMMD5Auth returns an Auth that implements the CRAM-MD5 authentication
+// mechanism as defined in RFC 2195.
+// The returned Auth uses the given username and secret to authenticate
+// to the server using the challenge-response mechanism.
+func CRAMMD5Auth(username, secret string) Auth {
+ return &cramMD5Auth{username, secret}
+}
+
+func (a *cramMD5Auth) Start(server *ServerInfo) (string, []byte, error) {
+ return "CRAM-MD5", nil, nil
+}
+
+func (a *cramMD5Auth) Next(fromServer []byte, more bool) ([]byte, error) {
+ if more {
+ d := hmac.New(md5.New, []byte(a.secret))
+ d.Write(fromServer)
+ s := make([]byte, 0, d.Size())
+ return fmt.Appendf(nil, "%s %x", a.username, d.Sum(s)), nil
+ }
+ return nil, nil
+}
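
The Auth interface above is the extension point for mechanisms the package does not ship. As a hedged sketch only (nothing below is part of net/smtp), a LOGIN-style mechanism could be wired up like this; the type name and the prompt strings are assumptions about a typical LOGIN server.

    package main

    import (
        "errors"
        "net/smtp"
    )

    // loginAuth is a hypothetical implementation of smtp.Auth for the
    // LOGIN mechanism; it is illustrative, not part of the package.
    type loginAuth struct {
        username, password string
        host               string
    }

    func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
        // Mirror plainAuth's caution: require TLS (or localhost) before
        // agreeing to send credentials.
        if !server.TLS && server.Name != "localhost" {
            return "", nil, errors.New("unencrypted connection")
        }
        if server.Name != a.host {
            return "", nil, errors.New("wrong host name")
        }
        return "LOGIN", nil, nil
    }

    func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
        if !more {
            return nil, nil
        }
        // The client has already base64-decoded the challenge; LOGIN
        // servers typically prompt with "Username:" and "Password:".
        switch string(fromServer) {
        case "Username:":
            return []byte(a.username), nil
        case "Password:":
            return []byte(a.password), nil
        }
        return nil, errors.New("unexpected server challenge")
    }

A value of this type can then be passed to Client.Auth or SendMail like any of the built-in mechanisms.
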
diff --git a/src/net/smtp/example_test.go b/src/net/smtp/example_test.go
new file mode 100644
index 0000000..16419f4
--- /dev/null
+++ b/src/net/smtp/example_test.go
@@ -0,0 +1,83 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package smtp_test
+
+import (
+ "fmt"
+ "log"
+ "net/smtp"
+)
+
+func Example() {
+ // Connect to the remote SMTP server.
+ c, err := smtp.Dial("mail.example.com:25")
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Set the sender and recipient first
+ if err := c.Mail("sender@example.org"); err != nil {
+ log.Fatal(err)
+ }
+ if err := c.Rcpt("recipient@example.net"); err != nil {
+ log.Fatal(err)
+ }
+
+ // Send the email body.
+ wc, err := c.Data()
+ if err != nil {
+ log.Fatal(err)
+ }
+ _, err = fmt.Fprintf(wc, "This is the email body")
+ if err != nil {
+ log.Fatal(err)
+ }
+ err = wc.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Send the QUIT command and close the connection.
+ err = c.Quit()
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+
+// variables to make ExamplePlainAuth compile, without adding
+// unnecessary noise there.
+var (
+ from = "gopher@example.net"
+ msg = []byte("dummy message")
+ recipients = []string{"foo@example.com"}
+)
+
+func ExamplePlainAuth() {
+ // hostname is used by PlainAuth to validate the TLS certificate.
+ hostname := "mail.example.com"
+ auth := smtp.PlainAuth("", "user@example.com", "password", hostname)
+
+ err := smtp.SendMail(hostname+":25", auth, from, recipients, msg)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleSendMail() {
+ // Set up authentication information.
+ auth := smtp.PlainAuth("", "user@example.com", "password", "mail.example.com")
+
+ // Connect to the server, authenticate, set the sender and recipient,
+ // and send the email all in one step.
+ to := []string{"recipient@example.net"}
+ msg := []byte("To: recipient@example.net\r\n" +
+ "Subject: discount Gophers!\r\n" +
+ "\r\n" +
+ "This is the email body.\r\n")
+ err := smtp.SendMail("mail.example.com:25", auth, "sender@example.org", to, msg)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
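
CRAMMD5Auth gets no example in this file; a minimal hedged sketch of using it with SendMail follows. The host, credentials, and addresses are placeholders.

    package main

    import (
        "log"
        "net/smtp"
    )

    func main() {
        // Placeholder credentials and server; CRAM-MD5 never sends the
        // secret itself, only an HMAC of the server's challenge.
        auth := smtp.CRAMMD5Auth("user@example.com", "secret")
        msg := []byte("To: recipient@example.net\r\n" +
            "Subject: CRAM-MD5 test\r\n" +
            "\r\n" +
            "Hello over CRAM-MD5.\r\n")
        err := smtp.SendMail("mail.example.com:25", auth, "sender@example.org",
            []string{"recipient@example.net"}, msg)
        if err != nil {
            log.Fatal(err)
        }
    }
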
diff --git a/src/net/smtp/smtp.go b/src/net/smtp/smtp.go
new file mode 100644
index 0000000..b5a025e
--- /dev/null
+++ b/src/net/smtp/smtp.go
@@ -0,0 +1,432 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC 5321.
+// It also implements the following extensions:
+//
+// 8BITMIME RFC 1652
+// AUTH RFC 2554
+// STARTTLS RFC 3207
+//
+// Additional extensions may be handled by clients.
+//
+// The smtp package is frozen and is not accepting new features.
+// Some external packages provide more functionality. See:
+//
+// https://godoc.org/?q=smtp
+package smtp
+
+import (
+ "crypto/tls"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/textproto"
+ "strings"
+)
+
+// A Client represents a client connection to an SMTP server.
+type Client struct {
+ // Text is the textproto.Conn used by the Client. It is exported to allow for
+ // clients to add extensions.
+ Text *textproto.Conn
+ // keep a reference to the connection so it can be used to create a TLS
+ // connection later
+ conn net.Conn
+ // whether the Client is using TLS
+ tls bool
+ serverName string
+ // map of supported extensions
+ ext map[string]string
+ // supported auth mechanisms
+ auth []string
+ localName string // the name to use in HELO/EHLO
+ didHello bool // whether we've said HELO/EHLO
+ helloError error // the error from the hello
+}
+
+// Dial returns a new Client connected to an SMTP server at addr.
+// The addr must include a port, as in "mail.example.com:smtp".
+func Dial(addr string) (*Client, error) {
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ host, _, _ := net.SplitHostPort(addr)
+ return NewClient(conn, host)
+}
+
+// NewClient returns a new Client using an existing connection and host as a
+// server name to be used when authenticating.
+func NewClient(conn net.Conn, host string) (*Client, error) {
+ text := textproto.NewConn(conn)
+ _, _, err := text.ReadResponse(220)
+ if err != nil {
+ text.Close()
+ return nil, err
+ }
+ c := &Client{Text: text, conn: conn, serverName: host, localName: "localhost"}
+ _, c.tls = conn.(*tls.Conn)
+ return c, nil
+}
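
Dial only establishes a plaintext TCP connection. For a server that expects implicit TLS (commonly port 465), one hedged approach is to build the TLS connection first and hand it to NewClient; the host name below is a placeholder.

    package main

    import (
        "crypto/tls"
        "log"
        "net/smtp"
    )

    func main() {
        host := "mail.example.com" // placeholder
        conn, err := tls.Dial("tcp", host+":465", &tls.Config{ServerName: host})
        if err != nil {
            log.Fatal(err)
        }
        c, err := smtp.NewClient(conn, host)
        if err != nil {
            log.Fatal(err)
        }
        defer c.Quit()
        // c is now ready for Auth, Mail, Rcpt, and Data as usual, and
        // TLSConnectionState reports the negotiated session.
    }
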
+
+// Close closes the connection.
+func (c *Client) Close() error {
+ return c.Text.Close()
+}
+
+// hello runs a hello exchange if needed.
+func (c *Client) hello() error {
+ if !c.didHello {
+ c.didHello = true
+ err := c.ehlo()
+ if err != nil {
+ c.helloError = c.helo()
+ }
+ }
+ return c.helloError
+}
+
+// Hello sends a HELO or EHLO to the server as the given host name.
+// Calling this method is only necessary if the client needs control
+// over the host name used. The client will introduce itself as "localhost"
+// automatically otherwise. If Hello is called, it must be called before
+// any of the other methods.
+func (c *Client) Hello(localName string) error {
+ if err := validateLine(localName); err != nil {
+ return err
+ }
+ if c.didHello {
+ return errors.New("smtp: Hello called after other methods")
+ }
+ c.localName = localName
+ return c.hello()
+}
+
+// cmd is a convenience function that sends a command and returns the response.
+func (c *Client) cmd(expectCode int, format string, args ...any) (int, string, error) {
+ id, err := c.Text.Cmd(format, args...)
+ if err != nil {
+ return 0, "", err
+ }
+ c.Text.StartResponse(id)
+ defer c.Text.EndResponse(id)
+ code, msg, err := c.Text.ReadResponse(expectCode)
+ return code, msg, err
+}
+
+// helo sends the HELO greeting to the server. It should be used only when the
+// server does not support ehlo.
+func (c *Client) helo() error {
+ c.ext = nil
+ _, _, err := c.cmd(250, "HELO %s", c.localName)
+ return err
+}
+
+// ehlo sends the EHLO (extended hello) greeting to the server. It
+// should be the preferred greeting for servers that support it.
+func (c *Client) ehlo() error {
+ _, msg, err := c.cmd(250, "EHLO %s", c.localName)
+ if err != nil {
+ return err
+ }
+ ext := make(map[string]string)
+ extList := strings.Split(msg, "\n")
+ if len(extList) > 1 {
+ extList = extList[1:]
+ for _, line := range extList {
+ k, v, _ := strings.Cut(line, " ")
+ ext[k] = v
+ }
+ }
+ if mechs, ok := ext["AUTH"]; ok {
+ c.auth = strings.Split(mechs, " ")
+ }
+ c.ext = ext
+ return err
+}
+
+// StartTLS sends the STARTTLS command and encrypts all further communication.
+// Only servers that advertise the STARTTLS extension support this function.
+func (c *Client) StartTLS(config *tls.Config) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(220, "STARTTLS")
+ if err != nil {
+ return err
+ }
+ c.conn = tls.Client(c.conn, config)
+ c.Text = textproto.NewConn(c.conn)
+ c.tls = true
+ return c.ehlo()
+}
+
+// TLSConnectionState returns the client's TLS connection state.
+// The return values are their zero values if StartTLS did
+// not succeed.
+func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
+ tc, ok := c.conn.(*tls.Conn)
+ if !ok {
+ return
+ }
+ return tc.ConnectionState(), true
+}
+
+// Verify checks the validity of an email address on the server.
+// If Verify returns nil, the address is valid. A non-nil return
+// does not necessarily indicate an invalid address. Many servers
+// will not verify addresses for security reasons.
+func (c *Client) Verify(addr string) error {
+ if err := validateLine(addr); err != nil {
+ return err
+ }
+ if err := c.hello(); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(250, "VRFY %s", addr)
+ return err
+}
+
+// Auth authenticates a client using the provided authentication mechanism.
+// A failed authentication closes the connection.
+// Only servers that advertise the AUTH extension support this function.
+func (c *Client) Auth(a Auth) error {
+ if err := c.hello(); err != nil {
+ return err
+ }
+ encoding := base64.StdEncoding
+ mech, resp, err := a.Start(&ServerInfo{c.serverName, c.tls, c.auth})
+ if err != nil {
+ c.Quit()
+ return err
+ }
+ resp64 := make([]byte, encoding.EncodedLen(len(resp)))
+ encoding.Encode(resp64, resp)
+ code, msg64, err := c.cmd(0, strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64)))
+ for err == nil {
+ var msg []byte
+ switch code {
+ case 334:
+ msg, err = encoding.DecodeString(msg64)
+ case 235:
+ // the last message isn't base64 because it isn't a challenge
+ msg = []byte(msg64)
+ default:
+ err = &textproto.Error{Code: code, Msg: msg64}
+ }
+ if err == nil {
+ resp, err = a.Next(msg, code == 334)
+ }
+ if err != nil {
+ // abort the AUTH
+ c.cmd(501, "*")
+ c.Quit()
+ break
+ }
+ if resp == nil {
+ break
+ }
+ resp64 = make([]byte, encoding.EncodedLen(len(resp)))
+ encoding.Encode(resp64, resp)
+ code, msg64, err = c.cmd(0, string(resp64))
+ }
+ return err
+}
+
+// Mail issues a MAIL command to the server using the provided email address.
+// If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME
+// parameter. If the server supports the SMTPUTF8 extension, Mail adds the
+// SMTPUTF8 parameter.
+// This initiates a mail transaction and is followed by one or more Rcpt calls.
+func (c *Client) Mail(from string) error {
+ if err := validateLine(from); err != nil {
+ return err
+ }
+ if err := c.hello(); err != nil {
+ return err
+ }
+ cmdStr := "MAIL FROM:<%s>"
+ if c.ext != nil {
+ if _, ok := c.ext["8BITMIME"]; ok {
+ cmdStr += " BODY=8BITMIME"
+ }
+ if _, ok := c.ext["SMTPUTF8"]; ok {
+ cmdStr += " SMTPUTF8"
+ }
+ }
+ _, _, err := c.cmd(250, cmdStr, from)
+ return err
+}
+
+// Rcpt issues a RCPT command to the server using the provided email address.
+// A call to Rcpt must be preceded by a call to Mail and may be followed by
+// a Data call or another Rcpt call.
+func (c *Client) Rcpt(to string) error {
+ if err := validateLine(to); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(25, "RCPT TO:<%s>", to)
+ return err
+}
+
+type dataCloser struct {
+ c *Client
+ io.WriteCloser
+}
+
+func (d *dataCloser) Close() error {
+ d.WriteCloser.Close()
+ _, _, err := d.c.Text.ReadResponse(250)
+ return err
+}
+
+// Data issues a DATA command to the server and returns a writer that
+// can be used to write the mail headers and body. The caller should
+// close the writer before calling any more methods on c. A call to
+// Data must be preceded by one or more calls to Rcpt.
+func (c *Client) Data() (io.WriteCloser, error) {
+ _, _, err := c.cmd(354, "DATA")
+ if err != nil {
+ return nil, err
+ }
+ return &dataCloser{c, c.Text.DotWriter()}, nil
+}
+
+var testHookStartTLS func(*tls.Config) // nil, except for tests
+
+// SendMail connects to the server at addr, switches to TLS if
+// possible, authenticates with the optional mechanism a if possible,
+// and then sends an email from address from, to addresses to, with
+// message msg.
+// The addr must include a port, as in "mail.example.com:smtp".
+//
+// The addresses in the to parameter are the SMTP RCPT addresses.
+//
+// The msg parameter should be an RFC 822-style email with headers
+// first, a blank line, and then the message body. The lines of msg
+// should be CRLF terminated. The msg headers should usually include
+// fields such as "From", "To", "Subject", and "Cc". Sending "Bcc"
+// messages is accomplished by including an email address in the to
+// parameter but not including it in the msg headers.
+//
+// The SendMail function and the net/smtp package are low-level
+// mechanisms and provide no support for DKIM signing, MIME
+// attachments (see the mime/multipart package), or other mail
+// functionality. Higher-level packages exist outside of the standard
+// library.
+func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
+ if err := validateLine(from); err != nil {
+ return err
+ }
+ for _, recp := range to {
+ if err := validateLine(recp); err != nil {
+ return err
+ }
+ }
+ c, err := Dial(addr)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+ if err = c.hello(); err != nil {
+ return err
+ }
+ if ok, _ := c.Extension("STARTTLS"); ok {
+ config := &tls.Config{ServerName: c.serverName}
+ if testHookStartTLS != nil {
+ testHookStartTLS(config)
+ }
+ if err = c.StartTLS(config); err != nil {
+ return err
+ }
+ }
+ if a != nil && c.ext != nil {
+ if _, ok := c.ext["AUTH"]; !ok {
+ return errors.New("smtp: server doesn't support AUTH")
+ }
+ if err = c.Auth(a); err != nil {
+ return err
+ }
+ }
+ if err = c.Mail(from); err != nil {
+ return err
+ }
+ for _, addr := range to {
+ if err = c.Rcpt(addr); err != nil {
+ return err
+ }
+ }
+ w, err := c.Data()
+ if err != nil {
+ return err
+ }
+ _, err = w.Write(msg)
+ if err != nil {
+ return err
+ }
+ err = w.Close()
+ if err != nil {
+ return err
+ }
+ return c.Quit()
+}
+
+// Extension reports whether an extension is supported by the server.
+// The extension name is case-insensitive. If the extension is supported,
+// Extension also returns a string that contains any parameters the
+// server specifies for the extension.
+func (c *Client) Extension(ext string) (bool, string) {
+ if err := c.hello(); err != nil {
+ return false, ""
+ }
+ if c.ext == nil {
+ return false, ""
+ }
+ ext = strings.ToUpper(ext)
+ param, ok := c.ext[ext]
+ return ok, param
+}
+
+// Reset sends the RSET command to the server, aborting the current mail
+// transaction.
+func (c *Client) Reset() error {
+ if err := c.hello(); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(250, "RSET")
+ return err
+}
+
+// Noop sends the NOOP command to the server. It does nothing but check
+// that the connection to the server is okay.
+func (c *Client) Noop() error {
+ if err := c.hello(); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(250, "NOOP")
+ return err
+}
+
+// Quit sends the QUIT command and closes the connection to the server.
+func (c *Client) Quit() error {
+ if err := c.hello(); err != nil {
+ return err
+ }
+ _, _, err := c.cmd(221, "QUIT")
+ if err != nil {
+ return err
+ }
+ return c.Text.Close()
+}
+
+// validateLine checks to see if a line has CR or LF as per RFC 5321.
+func validateLine(line string) error {
+ if strings.ContainsAny(line, "\n\r") {
+ return errors.New("smtp: A line must not contain CR or LF")
+ }
+ return nil
+}
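
SendMail above bundles the whole exchange into one call. When finer control is wanted, for instance to inspect advertised extensions before authenticating, the manual flow looks roughly like this hedged sketch; server name, port, and credentials are placeholders.

    package main

    import (
        "crypto/tls"
        "log"
        "net/smtp"
    )

    func main() {
        host := "mail.example.com" // placeholder
        c, err := smtp.Dial(host + ":587")
        if err != nil {
            log.Fatal(err)
        }
        defer c.Close()

        // Upgrade to TLS if the server advertises STARTTLS.
        if ok, _ := c.Extension("STARTTLS"); ok {
            if err := c.StartTLS(&tls.Config{ServerName: host}); err != nil {
                log.Fatal(err)
            }
        }

        // Authenticate only if AUTH is advertised.
        if ok, _ := c.Extension("AUTH"); ok {
            auth := smtp.PlainAuth("", "user@example.com", "password", host)
            if err := c.Auth(auth); err != nil {
                log.Fatal(err)
            }
        }

        if err := c.Mail("sender@example.org"); err != nil {
            log.Fatal(err)
        }
        if err := c.Rcpt("recipient@example.net"); err != nil {
            log.Fatal(err)
        }
        w, err := c.Data()
        if err != nil {
            log.Fatal(err)
        }
        if _, err := w.Write([]byte("Subject: manual flow\r\n\r\nhello\r\n")); err != nil {
            log.Fatal(err)
        }
        if err := w.Close(); err != nil {
            log.Fatal(err)
        }
        if err := c.Quit(); err != nil {
            log.Fatal(err)
        }
    }
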
diff --git a/src/net/smtp/smtp_test.go b/src/net/smtp/smtp_test.go
new file mode 100644
index 0000000..259b10b
--- /dev/null
+++ b/src/net/smtp/smtp_test.go
@@ -0,0 +1,1144 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package smtp
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "net"
+ "net/textproto"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+)
+
+type authTest struct {
+ auth Auth
+ challenges []string
+ name string
+ responses []string
+}
+
+var authTests = []authTest{
+ {PlainAuth("", "user", "pass", "testserver"), []string{}, "PLAIN", []string{"\x00user\x00pass"}},
+ {PlainAuth("foo", "bar", "baz", "testserver"), []string{}, "PLAIN", []string{"foo\x00bar\x00baz"}},
+ {CRAMMD5Auth("user", "pass"), []string{"<123456.1322876914@testserver>"}, "CRAM-MD5", []string{"", "user 287eb355114cf5c471c26a875f1ca4ae"}},
+}
+
+func TestAuth(t *testing.T) {
+testLoop:
+ for i, test := range authTests {
+ name, resp, err := test.auth.Start(&ServerInfo{"testserver", true, nil})
+ if name != test.name {
+ t.Errorf("#%d got name %s, expected %s", i, name, test.name)
+ }
+ if !bytes.Equal(resp, []byte(test.responses[0])) {
+ t.Errorf("#%d got response %s, expected %s", i, resp, test.responses[0])
+ }
+ if err != nil {
+ t.Errorf("#%d error: %s", i, err)
+ }
+ for j := range test.challenges {
+ challenge := []byte(test.challenges[j])
+ expected := []byte(test.responses[j+1])
+ resp, err := test.auth.Next(challenge, true)
+ if err != nil {
+ t.Errorf("#%d error: %s", i, err)
+ continue testLoop
+ }
+ if !bytes.Equal(resp, expected) {
+ t.Errorf("#%d got %s, expected %s", i, resp, expected)
+ continue testLoop
+ }
+ }
+ }
+}
+
+func TestAuthPlain(t *testing.T) {
+
+ tests := []struct {
+ authName string
+ server *ServerInfo
+ err string
+ }{
+ {
+ authName: "servername",
+ server: &ServerInfo{Name: "servername", TLS: true},
+ },
+ {
+ // OK to use PlainAuth on localhost without TLS
+ authName: "localhost",
+ server: &ServerInfo{Name: "localhost", TLS: false},
+ },
+ {
+ // NOT OK on non-localhost, even if server says PLAIN is OK.
+ // (We don't know that the server is the real server.)
+ authName: "servername",
+ server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}},
+ err: "unencrypted connection",
+ },
+ {
+ authName: "servername",
+ server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}},
+ err: "unencrypted connection",
+ },
+ {
+ authName: "servername",
+ server: &ServerInfo{Name: "attacker", TLS: true},
+ err: "wrong host name",
+ },
+ }
+ for i, tt := range tests {
+ auth := PlainAuth("foo", "bar", "baz", tt.authName)
+ _, _, err := auth.Start(tt.server)
+ got := ""
+ if err != nil {
+ got = err.Error()
+ }
+ if got != tt.err {
+ t.Errorf("%d. got error = %q; want %q", i, got, tt.err)
+ }
+ }
+}
+
+// Issue 17794: don't send a trailing space on AUTH command when there's no password.
+func TestClientAuthTrimSpace(t *testing.T) {
+ server := "220 hello world\r\n" +
+ "200 some more"
+ var wrote strings.Builder
+ var fake faker
+ fake.ReadWriter = struct {
+ io.Reader
+ io.Writer
+ }{
+ strings.NewReader(server),
+ &wrote,
+ }
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ c.tls = true
+ c.didHello = true
+ c.Auth(toServerEmptyAuth{})
+ c.Close()
+ if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want {
+ t.Errorf("wrote %q; want %q", got, want)
+ }
+}
+
+// toServerEmptyAuth is an implementation of Auth that only implements
+// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns
+// zero bytes for "toServer" so we can test that we don't send spaces at
+// the end of the line. See TestClientAuthTrimSpace.
+type toServerEmptyAuth struct{}
+
+func (toServerEmptyAuth) Start(server *ServerInfo) (proto string, toServer []byte, err error) {
+ return "FOOAUTH", nil, nil
+}
+
+func (toServerEmptyAuth) Next(fromServer []byte, more bool) (toServer []byte, err error) {
+ panic("unexpected call")
+}
+
+type faker struct {
+ io.ReadWriter
+}
+
+func (f faker) Close() error { return nil }
+func (f faker) LocalAddr() net.Addr { return nil }
+func (f faker) RemoteAddr() net.Addr { return nil }
+func (f faker) SetDeadline(time.Time) error { return nil }
+func (f faker) SetReadDeadline(time.Time) error { return nil }
+func (f faker) SetWriteDeadline(time.Time) error { return nil }
+
+func TestBasic(t *testing.T) {
+ server := strings.Join(strings.Split(basicServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c := &Client{Text: textproto.NewConn(fake), localName: "localhost"}
+
+ if err := c.helo(); err != nil {
+ t.Fatalf("HELO failed: %s", err)
+ }
+ if err := c.ehlo(); err == nil {
+ t.Fatalf("Expected first EHLO to fail")
+ }
+ if err := c.ehlo(); err != nil {
+ t.Fatalf("Second EHLO failed: %s", err)
+ }
+
+ c.didHello = true
+ if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" {
+ t.Fatalf("Expected AUTH supported")
+ }
+ if ok, _ := c.Extension("DSN"); ok {
+ t.Fatalf("Shouldn't support DSN")
+ }
+
+ if err := c.Mail("user@gmail.com"); err == nil {
+ t.Fatalf("MAIL should require authentication")
+ }
+
+ if err := c.Verify("user1@gmail.com"); err == nil {
+ t.Fatalf("First VRFY: expected no verification")
+ }
+ if err := c.Verify("user2@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil {
+ t.Fatalf("VRFY should have failed due to a message injection attempt")
+ }
+ if err := c.Verify("user2@gmail.com"); err != nil {
+ t.Fatalf("Second VRFY: expected verification, got %s", err)
+ }
+
+ // fake TLS so authentication won't complain
+ c.tls = true
+ c.serverName = "smtp.google.com"
+ if err := c.Auth(PlainAuth("", "user", "pass", "smtp.google.com")); err != nil {
+ t.Fatalf("AUTH failed: %s", err)
+ }
+
+ if err := c.Rcpt("golang-nuts@googlegroups.com>\r\nDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"); err == nil {
+ t.Fatalf("RCPT should have failed due to a message injection attempt")
+ }
+ if err := c.Mail("user@gmail.com>\r\nDATA\r\nAnother injected message body\r\n.\r\nQUIT\r\n"); err == nil {
+ t.Fatalf("MAIL should have failed due to a message injection attempt")
+ }
+ if err := c.Mail("user@gmail.com"); err != nil {
+ t.Fatalf("MAIL failed: %s", err)
+ }
+ if err := c.Rcpt("golang-nuts@googlegroups.com"); err != nil {
+ t.Fatalf("RCPT failed: %s", err)
+ }
+ msg := `From: user@gmail.com
+To: golang-nuts@googlegroups.com
+Subject: Hooray for Go
+
+Line 1
+.Leading dot line .
+Goodbye.`
+ w, err := c.Data()
+ if err != nil {
+ t.Fatalf("DATA failed: %s", err)
+ }
+ if _, err := w.Write([]byte(msg)); err != nil {
+ t.Fatalf("Data write failed: %s", err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatalf("Bad data response: %s", err)
+ }
+
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var basicServer = `250 mx.google.com at your service
+502 Unrecognized command.
+250-mx.google.com at your service
+250-SIZE 35651584
+250-AUTH LOGIN PLAIN
+250 8BITMIME
+530 Authentication required
+252 Send some mail, I'll try my best
+250 User is valid
+235 Accepted
+250 Sender OK
+250 Receiver OK
+354 Go ahead
+250 Data OK
+221 OK
+`
+
+var basicClient = `HELO localhost
+EHLO localhost
+EHLO localhost
+MAIL FROM:<user@gmail.com> BODY=8BITMIME
+VRFY user1@gmail.com
+VRFY user2@gmail.com
+AUTH PLAIN AHVzZXIAcGFzcw==
+MAIL FROM:<user@gmail.com> BODY=8BITMIME
+RCPT TO:<golang-nuts@googlegroups.com>
+DATA
+From: user@gmail.com
+To: golang-nuts@googlegroups.com
+Subject: Hooray for Go
+
+Line 1
+..Leading dot line .
+Goodbye.
+.
+QUIT
+`
+
+func TestExtensions(t *testing.T) {
+ fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) {
+ server = strings.Join(strings.Split(server, "\n"), "\r\n")
+
+ cmdbuf = &strings.Builder{}
+ bcmdbuf = bufio.NewWriter(cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c = &Client{Text: textproto.NewConn(fake), localName: "localhost"}
+
+ return c, bcmdbuf, cmdbuf
+ }
+
+ t.Run("helo", func(t *testing.T) {
+ const (
+ basicServer = `250 mx.google.com at your service
+250 Sender OK
+221 Goodbye
+`
+
+ basicClient = `HELO localhost
+MAIL FROM:<user@gmail.com>
+QUIT
+`
+ )
+
+ c, bcmdbuf, cmdbuf := fake(basicServer)
+
+ if err := c.helo(); err != nil {
+ t.Fatalf("HELO failed: %s", err)
+ }
+ c.didHello = true
+ if err := c.Mail("user@gmail.com"); err != nil {
+ t.Fatalf("MAIL FROM failed: %s", err)
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ })
+
+ t.Run("ehlo", func(t *testing.T) {
+ const (
+ basicServer = `250-mx.google.com at your service
+250 SIZE 35651584
+250 Sender OK
+221 Goodbye
+`
+
+ basicClient = `EHLO localhost
+MAIL FROM:<user@gmail.com>
+QUIT
+`
+ )
+
+ c, bcmdbuf, cmdbuf := fake(basicServer)
+
+ if err := c.Hello("localhost"); err != nil {
+ t.Fatalf("EHLO failed: %s", err)
+ }
+ if ok, _ := c.Extension("8BITMIME"); ok {
+ t.Fatalf("Shouldn't support 8BITMIME")
+ }
+ if ok, _ := c.Extension("SMTPUTF8"); ok {
+ t.Fatalf("Shouldn't support SMTPUTF8")
+ }
+ if err := c.Mail("user@gmail.com"); err != nil {
+ t.Fatalf("MAIL FROM failed: %s", err)
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ })
+
+ t.Run("ehlo 8bitmime", func(t *testing.T) {
+ const (
+ basicServer = `250-mx.google.com at your service
+250-SIZE 35651584
+250 8BITMIME
+250 Sender OK
+221 Goodbye
+`
+
+ basicClient = `EHLO localhost
+MAIL FROM:<user@gmail.com> BODY=8BITMIME
+QUIT
+`
+ )
+
+ c, bcmdbuf, cmdbuf := fake(basicServer)
+
+ if err := c.Hello("localhost"); err != nil {
+ t.Fatalf("EHLO failed: %s", err)
+ }
+ if ok, _ := c.Extension("8BITMIME"); !ok {
+ t.Fatalf("Should support 8BITMIME")
+ }
+ if ok, _ := c.Extension("SMTPUTF8"); ok {
+ t.Fatalf("Shouldn't support SMTPUTF8")
+ }
+ if err := c.Mail("user@gmail.com"); err != nil {
+ t.Fatalf("MAIL FROM failed: %s", err)
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ })
+
+ t.Run("ehlo smtputf8", func(t *testing.T) {
+ const (
+ basicServer = `250-mx.google.com at your service
+250-SIZE 35651584
+250 SMTPUTF8
+250 Sender OK
+221 Goodbye
+`
+
+ basicClient = `EHLO localhost
+MAIL FROM:<user+📧@gmail.com> SMTPUTF8
+QUIT
+`
+ )
+
+ c, bcmdbuf, cmdbuf := fake(basicServer)
+
+ if err := c.Hello("localhost"); err != nil {
+ t.Fatalf("EHLO failed: %s", err)
+ }
+ if ok, _ := c.Extension("8BITMIME"); ok {
+ t.Fatalf("Shouldn't support 8BITMIME")
+ }
+ if ok, _ := c.Extension("SMTPUTF8"); !ok {
+ t.Fatalf("Should support SMTPUTF8")
+ }
+ if err := c.Mail("user+📧@gmail.com"); err != nil {
+ t.Fatalf("MAIL FROM failed: %s", err)
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ })
+
+ t.Run("ehlo 8bitmime smtputf8", func(t *testing.T) {
+ const (
+ basicServer = `250-mx.google.com at your service
+250-SIZE 35651584
+250-8BITMIME
+250 SMTPUTF8
+250 Sender OK
+221 Goodbye
+`
+
+ basicClient = `EHLO localhost
+MAIL FROM:<user+📧@gmail.com> BODY=8BITMIME SMTPUTF8
+QUIT
+`
+ )
+
+ c, bcmdbuf, cmdbuf := fake(basicServer)
+
+ if err := c.Hello("localhost"); err != nil {
+ t.Fatalf("EHLO failed: %s", err)
+ }
+ c.didHello = true
+ if ok, _ := c.Extension("8BITMIME"); !ok {
+ t.Fatalf("Should support 8BITMIME")
+ }
+ if ok, _ := c.Extension("SMTPUTF8"); !ok {
+ t.Fatalf("Should support SMTPUTF8")
+ }
+ if err := c.Mail("user+📧@gmail.com"); err != nil {
+ t.Fatalf("MAIL FROM failed: %s", err)
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ client := strings.Join(strings.Split(basicClient, "\n"), "\r\n")
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ })
+}
+
+func TestNewClient(t *testing.T) {
+ server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n")
+
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ out := func() string {
+ bcmdbuf.Flush()
+ return cmdbuf.String()
+ }
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v\n(after %v)", err, out())
+ }
+ defer c.Close()
+ if ok, args := c.Extension("aUtH"); !ok || args != "LOGIN PLAIN" {
+ t.Fatalf("Expected AUTH supported")
+ }
+ if ok, _ := c.Extension("DSN"); ok {
+ t.Fatalf("Shouldn't support DSN")
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ actualcmds := out()
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var newClientServer = `220 hello world
+250-mx.google.com at your service
+250-SIZE 35651584
+250-AUTH LOGIN PLAIN
+250 8BITMIME
+221 OK
+`
+
+var newClientClient = `EHLO localhost
+QUIT
+`
+
+func TestNewClient2(t *testing.T) {
+ server := strings.Join(strings.Split(newClient2Server, "\n"), "\r\n")
+ client := strings.Join(strings.Split(newClient2Client, "\n"), "\r\n")
+
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ defer c.Close()
+ if ok, _ := c.Extension("DSN"); ok {
+ t.Fatalf("Shouldn't support DSN")
+ }
+ if err := c.Quit(); err != nil {
+ t.Fatalf("QUIT failed: %s", err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var newClient2Server = `220 hello world
+502 EH?
+250-mx.google.com at your service
+250-SIZE 35651584
+250-AUTH LOGIN PLAIN
+250 8BITMIME
+221 OK
+`
+
+var newClient2Client = `EHLO localhost
+HELO localhost
+QUIT
+`
+
+func TestNewClientWithTLS(t *testing.T) {
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ t.Fatalf("loadcert: %v", err)
+ }
+
+ config := tls.Config{Certificates: []tls.Certificate{cert}}
+
+ ln, err := tls.Listen("tcp", "127.0.0.1:0", &config)
+ if err != nil {
+ ln, err = tls.Listen("tcp", "[::1]:0", &config)
+ if err != nil {
+ t.Fatalf("server: listen: %v", err)
+ }
+ }
+
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ t.Errorf("server: accept: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ _, err = conn.Write([]byte("220 SIGNS\r\n"))
+ if err != nil {
+ t.Errorf("server: write: %v", err)
+ return
+ }
+ }()
+
+ config.InsecureSkipVerify = true
+ conn, err := tls.Dial("tcp", ln.Addr().String(), &config)
+ if err != nil {
+ t.Fatalf("client: dial: %v", err)
+ }
+ defer conn.Close()
+
+ client, err := NewClient(conn, ln.Addr().String())
+ if err != nil {
+ t.Fatalf("smtp: newclient: %v", err)
+ }
+ if !client.tls {
+ t.Errorf("client.tls Got: %t Expected: %t", client.tls, true)
+ }
+}
+
+func TestHello(t *testing.T) {
+
+ if len(helloServer) != len(helloClient) {
+ t.Fatalf("Hello server and client size mismatch")
+ }
+
+ for i := 0; i < len(helloServer); i++ {
+ server := strings.Join(strings.Split(baseHelloServer+helloServer[i], "\n"), "\r\n")
+ client := strings.Join(strings.Split(baseHelloClient+helloClient[i], "\n"), "\r\n")
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ defer c.Close()
+ c.localName = "customhost"
+ err = nil
+
+ switch i {
+ case 0:
+ err = c.Hello("hostinjection>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n")
+ if err == nil {
+ t.Errorf("Expected Hello to be rejected due to a message injection attempt")
+ }
+ err = c.Hello("customhost")
+ case 1:
+ err = c.StartTLS(nil)
+ if err.Error() == "502 Not implemented" {
+ err = nil
+ }
+ case 2:
+ err = c.Verify("test@example.com")
+ case 3:
+ c.tls = true
+ c.serverName = "smtp.google.com"
+ err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com"))
+ case 4:
+ err = c.Mail("test@example.com")
+ case 5:
+ ok, _ := c.Extension("feature")
+ if ok {
+ t.Errorf("Expected FEATURE not to be supported")
+ }
+ case 6:
+ err = c.Reset()
+ case 7:
+ err = c.Quit()
+ case 8:
+ err = c.Verify("test@example.com")
+ if err != nil {
+ err = c.Hello("customhost")
+ if err != nil {
+ t.Errorf("Want error, got none")
+ }
+ }
+ case 9:
+ err = c.Noop()
+ default:
+ t.Fatalf("Unhandled command")
+ }
+
+ if err != nil {
+ t.Errorf("Command %d failed: %v", i, err)
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+ }
+}
+
+var baseHelloServer = `220 hello world
+502 EH?
+250-mx.google.com at your service
+250 FEATURE
+`
+
+var helloServer = []string{
+ "",
+ "502 Not implemented\n",
+ "250 User is valid\n",
+ "235 Accepted\n",
+ "250 Sender ok\n",
+ "",
+ "250 Reset ok\n",
+ "221 Goodbye\n",
+ "250 Sender ok\n",
+ "250 ok\n",
+}
+
+var baseHelloClient = `EHLO customhost
+HELO customhost
+`
+
+var helloClient = []string{
+ "",
+ "STARTTLS\n",
+ "VRFY test@example.com\n",
+ "AUTH PLAIN AHVzZXIAcGFzcw==\n",
+ "MAIL FROM:<test@example.com>\n",
+ "",
+ "RSET\n",
+ "QUIT\n",
+ "VRFY test@example.com\n",
+ "NOOP\n",
+}
+
+func TestSendMail(t *testing.T) {
+ server := strings.Join(strings.Split(sendMailServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(sendMailClient, "\n"), "\r\n")
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Unable to create listener: %v", err)
+ }
+ defer l.Close()
+
+ // prevent data race on bcmdbuf
+ var done = make(chan struct{})
+ go func(data []string) {
+
+ defer close(done)
+
+ conn, err := l.Accept()
+ if err != nil {
+ t.Errorf("Accept error: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ tc := textproto.NewConn(conn)
+ for i := 0; i < len(data) && data[i] != ""; i++ {
+ tc.PrintfLine(data[i])
+ for len(data[i]) >= 4 && data[i][3] == '-' {
+ i++
+ tc.PrintfLine(data[i])
+ }
+ if data[i] == "221 Goodbye" {
+ return
+ }
+ read := false
+ for !read || data[i] == "354 Go ahead" {
+ msg, err := tc.ReadLine()
+ bcmdbuf.Write([]byte(msg + "\r\n"))
+ read = true
+ if err != nil {
+ t.Errorf("Read error: %v", err)
+ return
+ }
+ if data[i] == "354 Go ahead" && msg == "." {
+ break
+ }
+ }
+ }
+ }(strings.Split(server, "\r\n"))
+
+ err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com>\n\rDATA\r\nInjected message body\r\n.\r\nQUIT\r\n"}, []byte(strings.Replace(`From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+`, "\n", "\r\n", -1)))
+ if err == nil {
+ t.Errorf("Expected SendMail to be rejected due to a message injection attempt")
+ }
+
+ err = SendMail(l.Addr().String(), nil, "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+`, "\n", "\r\n", -1)))
+
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+
+ <-done
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var sendMailServer = `220 hello world
+502 EH?
+250 mx.google.com at your service
+250 Sender ok
+250 Receiver ok
+354 Go ahead
+250 Data ok
+221 Goodbye
+`
+
+var sendMailClient = `EHLO localhost
+HELO localhost
+MAIL FROM:<test@example.com>
+RCPT TO:<other@example.com>
+DATA
+From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+.
+QUIT
+`
+
+func TestSendMailWithAuth(t *testing.T) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Unable to create listener: %v", err)
+ }
+ defer l.Close()
+
+ errCh := make(chan error)
+ go func() {
+ defer close(errCh)
+ conn, err := l.Accept()
+ if err != nil {
+ errCh <- fmt.Errorf("Accept: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ tc := textproto.NewConn(conn)
+ tc.PrintfLine("220 hello world")
+ msg, err := tc.ReadLine()
+ if err != nil {
+ errCh <- fmt.Errorf("ReadLine error: %v", err)
+ return
+ }
+ const wantMsg = "EHLO localhost"
+ if msg != wantMsg {
+ errCh <- fmt.Errorf("unexpected response %q; want %q", msg, wantMsg)
+ return
+ }
+ err = tc.PrintfLine("250 mx.google.com at your service")
+ if err != nil {
+ errCh <- fmt.Errorf("PrintfLine: %v", err)
+ return
+ }
+ }()
+
+ err = SendMail(l.Addr().String(), PlainAuth("", "user", "pass", "smtp.google.com"), "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+`, "\n", "\r\n", -1)))
+ if err == nil {
+ t.Error("SendMail: Server doesn't support AUTH, expected to get an error, but got none")
+ return
+ }
+ if err.Error() != "smtp: server doesn't support AUTH" {
+ t.Errorf("Expected: smtp: server doesn't support AUTH, got: %s", err)
+ }
+ err = <-errCh
+ if err != nil {
+ t.Fatalf("server error: %v", err)
+ }
+}
+
+func TestAuthFailed(t *testing.T) {
+ server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n")
+ client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n")
+ var cmdbuf strings.Builder
+ bcmdbuf := bufio.NewWriter(&cmdbuf)
+ var fake faker
+ fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf)
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ defer c.Close()
+
+ c.tls = true
+ c.serverName = "smtp.google.com"
+ err = c.Auth(PlainAuth("", "user", "pass", "smtp.google.com"))
+
+ if err == nil {
+ t.Error("Auth: expected error; got none")
+ } else if err.Error() != "535 Invalid credentials\nplease see www.example.com" {
+ t.Errorf("Auth: got error: %v, want: %s", err, "535 Invalid credentials\nplease see www.example.com")
+ }
+
+ bcmdbuf.Flush()
+ actualcmds := cmdbuf.String()
+ if client != actualcmds {
+ t.Errorf("Got:\n%s\nExpected:\n%s", actualcmds, client)
+ }
+}
+
+var authFailedServer = `220 hello world
+250-mx.google.com at your service
+250 AUTH LOGIN PLAIN
+535-Invalid credentials
+535 please see www.example.com
+221 Goodbye
+`
+
+var authFailedClient = `EHLO localhost
+AUTH PLAIN AHVzZXIAcGFzcw==
+*
+QUIT
+`
+
+func TestTLSClient(t *testing.T) {
+ if runtime.GOOS == "freebsd" || runtime.GOOS == "js" || runtime.GOOS == "wasip1" {
+ testenv.SkipFlaky(t, 19229)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ errc := make(chan error)
+ go func() {
+ errc <- sendMail(ln.Addr().String())
+ }()
+ conn, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("failed to accept connection: %v", err)
+ }
+ defer conn.Close()
+ if err := serverHandle(conn, t); err != nil {
+ t.Fatalf("failed to handle connection: %v", err)
+ }
+ if err := <-errc; err != nil {
+ t.Fatalf("client error: %v", err)
+ }
+}
+
+func TestTLSConnState(t *testing.T) {
+ ln := newLocalListener(t)
+ defer ln.Close()
+ clientDone := make(chan bool)
+ serverDone := make(chan bool)
+ go func() {
+ defer close(serverDone)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Errorf("Server accept: %v", err)
+ return
+ }
+ defer c.Close()
+ if err := serverHandle(c, t); err != nil {
+ t.Errorf("server error: %v", err)
+ }
+ }()
+ go func() {
+ defer close(clientDone)
+ c, err := Dial(ln.Addr().String())
+ if err != nil {
+ t.Errorf("Client dial: %v", err)
+ return
+ }
+ defer c.Quit()
+ cfg := &tls.Config{ServerName: "example.com"}
+ testHookStartTLS(cfg) // set the RootCAs
+ if err := c.StartTLS(cfg); err != nil {
+ t.Errorf("StartTLS: %v", err)
+ return
+ }
+ cs, ok := c.TLSConnectionState()
+ if !ok {
+ t.Errorf("TLSConnectionState returned ok == false; want true")
+ return
+ }
+ if cs.Version == 0 || !cs.HandshakeComplete {
+ t.Errorf("ConnectionState = %#v; expect non-zero Version and HandshakeComplete", cs)
+ }
+ }()
+ <-clientDone
+ <-serverDone
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+type smtpSender struct {
+ w io.Writer
+}
+
+func (s smtpSender) send(f string) {
+ s.w.Write([]byte(f + "\r\n"))
+}
+
+// smtp server, finely tailored to deal with our own client only!
+func serverHandle(c net.Conn, t *testing.T) error {
+ send := smtpSender{c}.send
+ send("220 127.0.0.1 ESMTP service ready")
+ s := bufio.NewScanner(c)
+ for s.Scan() {
+ switch s.Text() {
+ case "EHLO localhost":
+ send("250-127.0.0.1 ESMTP offers a warm hug of welcome")
+ send("250-STARTTLS")
+ send("250 Ok")
+ case "STARTTLS":
+ send("220 Go ahead")
+ keypair, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ return err
+ }
+ config := &tls.Config{Certificates: []tls.Certificate{keypair}}
+ c = tls.Server(c, config)
+ defer c.Close()
+ return serverHandleTLS(c, t)
+ default:
+ t.Fatalf("unrecognized command: %q", s.Text())
+ }
+ }
+ return s.Err()
+}
+
+func serverHandleTLS(c net.Conn, t *testing.T) error {
+ send := smtpSender{c}.send
+ s := bufio.NewScanner(c)
+ for s.Scan() {
+ switch s.Text() {
+ case "EHLO localhost":
+ send("250 Ok")
+ case "MAIL FROM:<joe1@example.com>":
+ send("250 Ok")
+ case "RCPT TO:<joe2@example.com>":
+ send("250 Ok")
+ case "DATA":
+ send("354 send the mail data, end with .")
+ send("250 Ok")
+ case "Subject: test":
+ case "":
+ case "howdy!":
+ case ".":
+ case "QUIT":
+ send("221 127.0.0.1 Service closing transmission channel")
+ return nil
+ default:
+ t.Fatalf("unrecognized command during TLS: %q", s.Text())
+ }
+ }
+ return s.Err()
+}
+
+func init() {
+ testRootCAs := x509.NewCertPool()
+ testRootCAs.AppendCertsFromPEM(localhostCert)
+ testHookStartTLS = func(config *tls.Config) {
+ config.RootCAs = testRootCAs
+ }
+}
+
+func sendMail(hostPort string) error {
+ from := "joe1@example.com"
+ to := []string{"joe2@example.com"}
+ return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!"))
+}
+
+// localhostCert is a PEM-encoded TLS cert generated from src/crypto/tls:
+//
+// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com \
+// --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var localhostCert = []byte(`
+-----BEGIN CERTIFICATE-----
+MIICFDCCAX2gAwIBAgIRAK0xjnaPuNDSreeXb+z+0u4wDQYJKoZIhvcNAQELBQAw
+EjEQMA4GA1UEChMHQWNtZSBDbzAgFw03MDAxMDEwMDAwMDBaGA8yMDg0MDEyOTE2
+MDAwMFowEjEQMA4GA1UEChMHQWNtZSBDbzCBnzANBgkqhkiG9w0BAQEFAAOBjQAw
+gYkCgYEA0nFbQQuOWsjbGtejcpWz153OlziZM4bVjJ9jYruNw5n2Ry6uYQAffhqa
+JOInCmmcVe2siJglsyH9aRh6vKiobBbIUXXUU1ABd56ebAzlt0LobLlx7pZEMy30
+LqIi9E6zmL3YvdGzpYlkFRnRrqwEtWYbGBf3znO250S56CCWH2UCAwEAAaNoMGYw
+DgYDVR0PAQH/BAQDAgKkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQF
+MAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAA
+AAAAAAEwDQYJKoZIhvcNAQELBQADgYEAbZtDS2dVuBYvb+MnolWnCNqvw1w5Gtgi
+NmvQQPOMgM3m+oQSCPRTNGSg25e1Qbo7bgQDv8ZTnq8FgOJ/rbkyERw2JckkHpD4
+n4qcK27WkEDBtQFlPihIM8hLIuzWoi/9wygiElTy/tVL3y7fGCvY2/k1KBthtZGF
+tN8URjVmyEo=
+-----END CERTIFICATE-----`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(testingKey(`
+-----BEGIN RSA TESTING KEY-----
+MIICXgIBAAKBgQDScVtBC45ayNsa16NylbPXnc6XOJkzhtWMn2Niu43DmfZHLq5h
+AB9+Gpok4icKaZxV7ayImCWzIf1pGHq8qKhsFshRddRTUAF3np5sDOW3QuhsuXHu
+lkQzLfQuoiL0TrOYvdi90bOliWQVGdGurAS1ZhsYF/fOc7bnRLnoIJYfZQIDAQAB
+AoGBAMst7OgpKyFV6c3JwyI/jWqxDySL3caU+RuTTBaodKAUx2ZEmNJIlx9eudLA
+kucHvoxsM/eRxlxkhdFxdBcwU6J+zqooTnhu/FE3jhrT1lPrbhfGhyKnUrB0KKMM
+VY3IQZyiehpxaeXAwoAou6TbWoTpl9t8ImAqAMY8hlULCUqlAkEA+9+Ry5FSYK/m
+542LujIcCaIGoG1/Te6Sxr3hsPagKC2rH20rDLqXwEedSFOpSS0vpzlPAzy/6Rbb
+PHTJUhNdwwJBANXkA+TkMdbJI5do9/mn//U0LfrCR9NkcoYohxfKz8JuhgRQxzF2
+6jpo3q7CdTuuRixLWVfeJzcrAyNrVcBq87cCQFkTCtOMNC7fZnCTPUv+9q1tcJyB
+vNjJu3yvoEZeIeuzouX9TJE21/33FaeDdsXbRhQEj23cqR38qFHsF1qAYNMCQQDP
+QXLEiJoClkR2orAmqjPLVhR3t2oB3INcnEjLNSq8LHyQEfXyaFfu4U9l5+fRPL2i
+jiC0k/9L5dHUsF0XZothAkEA23ddgRs+Id/HxtojqqUT27B8MT/IGNrYsp4DvS/c
+qgkeluku4GjxRlDMBuXk94xOBEinUs+p/hwP1Alll80Tpg==
+-----END RSA TESTING KEY-----`))
+
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
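
The tests above script the server side through the unexported faker type. Outside the package, the same transcript-style testing idea can be sketched with net.Pipe; the scripted responses below are illustrative and far more permissive than a real SMTP server.

    package main

    import (
        "bufio"
        "fmt"
        "log"
        "net"
        "net/smtp"
    )

    func main() {
        cli, srv := net.Pipe()
        defer cli.Close()

        // Scripted "server": greet, then answer every command with 250.
        go func() {
            fmt.Fprintf(srv, "220 scripted server ready\r\n")
            r := bufio.NewReader(srv)
            for {
                if _, err := r.ReadString('\n'); err != nil {
                    return
                }
                fmt.Fprintf(srv, "250 Ok\r\n")
            }
        }()

        c, err := smtp.NewClient(cli, "scripted")
        if err != nil {
            log.Fatal(err)
        }
        if err := c.Hello("localhost"); err != nil {
            log.Fatal(err)
        }
        if err := c.Noop(); err != nil {
            log.Fatal(err)
        }
        fmt.Println("client talked to the scripted server")
    }
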
diff --git a/src/net/sock_bsd.go b/src/net/sock_bsd.go
new file mode 100644
index 0000000..27daf72
--- /dev/null
+++ b/src/net/sock_bsd.go
@@ -0,0 +1,39 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func maxListenerBacklog() int {
+ var (
+ n uint32
+ err error
+ )
+ switch runtime.GOOS {
+ case "darwin", "ios":
+ n, err = syscall.SysctlUint32("kern.ipc.somaxconn")
+ case "freebsd":
+ n, err = syscall.SysctlUint32("kern.ipc.soacceptqueue")
+ case "netbsd":
+ // NOTE: NetBSD has no somaxconn-like kernel state so far
+ case "openbsd":
+ n, err = syscall.SysctlUint32("kern.somaxconn")
+ }
+ if n == 0 || err != nil {
+ return syscall.SOMAXCONN
+ }
+ // FreeBSD stores the backlog in a uint16, as does Linux.
+ // Assume the other BSDs do too. Truncate number to avoid wrapping.
+ // See issue 5030.
+ if n > 1<<16-1 {
+ n = 1<<16 - 1
+ }
+ return int(n)
+}
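
For reference, the kernel limit consulted above can be read directly with syscall.SysctlUint32. A hedged, darwin-only sketch (the sysctl name differs per OS, as the switch above shows):

    //go:build darwin

    package main

    import (
        "fmt"
        "syscall"
    )

    func main() {
        // kern.ipc.somaxconn is the name used for darwin/ios above.
        n, err := syscall.SysctlUint32("kern.ipc.somaxconn")
        if err != nil {
            fmt.Println("sysctl:", err, "- falling back to", syscall.SOMAXCONN)
            return
        }
        fmt.Println("kern.ipc.somaxconn =", n)
    }
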
diff --git a/src/net/sock_cloexec.go b/src/net/sock_cloexec.go
new file mode 100644
index 0000000..9eeb897
--- /dev/null
+++ b/src/net/sock_cloexec.go
@@ -0,0 +1,48 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements sysSocket for platforms that provide a fast path for
+// setting SetNonblock and CloseOnExec.
+
+//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package net
+
+import (
+ "internal/poll"
+ "os"
+ "syscall"
+)
+
+// Wrapper around the socket system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func sysSocket(family, sotype, proto int) (int, error) {
+ s, err := socketFunc(family, sotype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, proto)
+ // TODO: We can remove the fallback on Linux and *BSD,
+ // as currently supported versions all support accept4
+ // with SOCK_CLOEXEC, but Solaris does not. See issue #59359.
+ switch err {
+ case nil:
+ return s, nil
+ default:
+ return -1, os.NewSyscallError("socket", err)
+ case syscall.EPROTONOSUPPORT, syscall.EINVAL:
+ }
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return -1, os.NewSyscallError("socket", err)
+ }
+ if err = syscall.SetNonblock(s, true); err != nil {
+ poll.CloseFunc(s)
+ return -1, os.NewSyscallError("setnonblock", err)
+ }
+ return s, nil
+}
diff --git a/src/net/sock_linux.go b/src/net/sock_linux.go
new file mode 100644
index 0000000..cffe9a2
--- /dev/null
+++ b/src/net/sock_linux.go
@@ -0,0 +1,54 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/unix"
+ "syscall"
+)
+
+// Linux stores the backlog as:
+//
+// - uint16 in kernel version < 4.1,
+// - uint32 in kernel version >= 4.1
+//
+// Truncate number to avoid wrapping.
+//
+// See issue 5030 and 41470.
+func maxAckBacklog(n int) int {
+ major, minor := unix.KernelVersion()
+ size := 16
+ if major > 4 || (major == 4 && minor >= 1) {
+ size = 32
+ }
+
+ var max uint = 1<<size - 1
+ if uint(n) > max {
+ n = int(max)
+ }
+ return n
+}
+
+func maxListenerBacklog() int {
+ fd, err := open("/proc/sys/net/core/somaxconn")
+ if err != nil {
+ return syscall.SOMAXCONN
+ }
+ defer fd.close()
+ l, ok := fd.readLine()
+ if !ok {
+ return syscall.SOMAXCONN
+ }
+ f := getFields(l)
+ n, _, ok := dtoi(f[0])
+ if n == 0 || !ok {
+ return syscall.SOMAXCONN
+ }
+
+ if n > 1<<16-1 {
+ return maxAckBacklog(n)
+ }
+ return n
+}
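
maxListenerBacklog above relies on net's internal file helpers (open, readLine, dtoi). A hedged equivalent using only exported standard-library calls, for illustration:

    //go:build linux

    package main

    import (
        "fmt"
        "os"
        "strconv"
        "strings"
        "syscall"
    )

    func main() {
        backlog := syscall.SOMAXCONN // fallback, as in maxListenerBacklog
        if b, err := os.ReadFile("/proc/sys/net/core/somaxconn"); err == nil {
            if fields := strings.Fields(string(b)); len(fields) > 0 {
                if n, err := strconv.Atoi(fields[0]); err == nil && n > 0 {
                    backlog = n
                }
            }
        }
        // A production version would also cap the value against the
        // kernel's sk_max_ack_backlog width, as maxAckBacklog does above.
        fmt.Println("listener backlog:", backlog)
    }
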
diff --git a/src/net/sock_linux_test.go b/src/net/sock_linux_test.go
new file mode 100644
index 0000000..11303cf
--- /dev/null
+++ b/src/net/sock_linux_test.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/unix"
+ "testing"
+)
+
+func TestMaxAckBacklog(t *testing.T) {
+ n := 196602
+ major, minor := unix.KernelVersion()
+ backlog := maxAckBacklog(n)
+ expected := 1<<16 - 1
+ if major > 4 || (major == 4 && minor >= 1) {
+ expected = n
+ }
+ if backlog != expected {
+ t.Fatalf(`Kernel version: "%d.%d", sk_max_ack_backlog mismatch, got %d, want %d`, major, minor, backlog, expected)
+ }
+}
diff --git a/src/net/sock_plan9.go b/src/net/sock_plan9.go
new file mode 100644
index 0000000..9367ad8
--- /dev/null
+++ b/src/net/sock_plan9.go
@@ -0,0 +1,10 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+func maxListenerBacklog() int {
+ // /sys/include/ape/sys/socket.h:/SOMAXCONN
+ return 5
+}
diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go
new file mode 100644
index 0000000..b3e1806
--- /dev/null
+++ b/src/net/sock_posix.go
@@ -0,0 +1,259 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || windows
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+ "os"
+ "syscall"
+)
+
+// socket returns a network file descriptor that is ready for
+// asynchronous I/O using the network poller.
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (fd *netFD, err error) {
+ s, err := sysSocket(family, sotype, proto)
+ if err != nil {
+ return nil, err
+ }
+ if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil {
+ poll.CloseFunc(s)
+ return nil, err
+ }
+ if fd, err = newFD(s, family, sotype, net); err != nil {
+ poll.CloseFunc(s)
+ return nil, err
+ }
+
+ // This function makes a network file descriptor for the
+ // following applications:
+ //
+ // - An endpoint holder that opens a passive stream
+ // connection, known as a stream listener
+ //
+ // - An endpoint holder that opens a destination-unspecific
+ // datagram connection, known as a datagram listener
+ //
+ // - An endpoint holder that opens an active stream or a
+ // destination-specific datagram connection, known as a
+ // dialer
+ //
+ // - An endpoint holder that opens the other connection, such
+ // as talking to the protocol stack inside the kernel
+ //
+ // For stream and datagram listeners, they will only require
+ // named sockets, so we can assume that it's just a request
+ // from stream or datagram listeners when laddr is not nil but
+ // raddr is nil. Otherwise we assume it's just for dialers or
+ // the other connection holders.
+
+ if laddr != nil && raddr == nil {
+ switch sotype {
+ case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
+ if err := fd.listenStream(ctx, laddr, listenerBacklog(), ctrlCtxFn); err != nil {
+ fd.Close()
+ return nil, err
+ }
+ return fd, nil
+ case syscall.SOCK_DGRAM:
+ if err := fd.listenDatagram(ctx, laddr, ctrlCtxFn); err != nil {
+ fd.Close()
+ return nil, err
+ }
+ return fd, nil
+ }
+ }
+ if err := fd.dial(ctx, laddr, raddr, ctrlCtxFn); err != nil {
+ fd.Close()
+ return nil, err
+ }
+ return fd, nil
+}
+
+func (fd *netFD) ctrlNetwork() string {
+ switch fd.net {
+ case "unix", "unixgram", "unixpacket":
+ return fd.net
+ }
+ switch fd.net[len(fd.net)-1] {
+ case '4', '6':
+ return fd.net
+ }
+ if fd.family == syscall.AF_INET {
+ return fd.net + "4"
+ }
+ return fd.net + "6"
+}
+
+func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
+ switch fd.family {
+ case syscall.AF_INET, syscall.AF_INET6:
+ switch fd.sotype {
+ case syscall.SOCK_STREAM:
+ return sockaddrToTCP
+ case syscall.SOCK_DGRAM:
+ return sockaddrToUDP
+ case syscall.SOCK_RAW:
+ return sockaddrToIP
+ }
+ case syscall.AF_UNIX:
+ switch fd.sotype {
+ case syscall.SOCK_STREAM:
+ return sockaddrToUnix
+ case syscall.SOCK_DGRAM:
+ return sockaddrToUnixgram
+ case syscall.SOCK_SEQPACKET:
+ return sockaddrToUnixpacket
+ }
+ }
+ return func(syscall.Sockaddr) Addr { return nil }
+}
+
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
+ var c *rawConn
+ var err error
+ if ctrlCtxFn != nil {
+ c, err = newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ var ctrlAddr string
+ if raddr != nil {
+ ctrlAddr = raddr.String()
+ } else if laddr != nil {
+ ctrlAddr = laddr.String()
+ }
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+ return err
+ }
+ }
+
+ var lsa syscall.Sockaddr
+ if laddr != nil {
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
+ return err
+ } else if lsa != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ }
+ }
+ var rsa syscall.Sockaddr // remote address from the user
+ var crsa syscall.Sockaddr // remote address we actually connected to
+ if raddr != nil {
+ if rsa, err = raddr.sockaddr(fd.family); err != nil {
+ return err
+ }
+ if crsa, err = fd.connect(ctx, lsa, rsa); err != nil {
+ return err
+ }
+ fd.isConnected = true
+ } else {
+ if err := fd.init(); err != nil {
+ return err
+ }
+ }
+ // Record the local and remote addresses from the actual socket.
+ // Get the local address by calling Getsockname.
+ // For the remote address, use
+ // 1) the one returned by the connect method, if any; or
+ // 2) the one from Getpeername, if it succeeds; or
+ // 3) the one passed to us as the raddr parameter.
+ lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
+ if crsa != nil {
+ fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(crsa))
+ } else if rsa, _ = syscall.Getpeername(fd.pfd.Sysfd); rsa != nil {
+ fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(rsa))
+ } else {
+ fd.setAddr(fd.addrFunc()(lsa), raddr)
+ }
+ return nil
+}
+
+func (fd *netFD) listenStream(ctx context.Context, laddr sockaddr, backlog int, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
+ var err error
+ if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
+ return err
+ }
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
+ return err
+ }
+
+ if ctrlCtxFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
+ }
+ }
+
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil {
+ return os.NewSyscallError("listen", err)
+ }
+ if err = fd.init(); err != nil {
+ return err
+ }
+ lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
+ fd.setAddr(fd.addrFunc()(lsa), nil)
+ return nil
+}
+
+func (fd *netFD) listenDatagram(ctx context.Context, laddr sockaddr, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) error {
+ switch addr := laddr.(type) {
+ case *UDPAddr:
+ // We provide a socket that listens to a wildcard
+ // address with reusable UDP port when the given laddr
+ // is an appropriate UDP multicast address prefix.
+ // This makes it possible for a single UDP listener to
+ // join multiple different group addresses, or for
+ // multiple UDP listeners that listen on the same UDP
+ // port to join the same group address.
+ if addr.IP != nil && addr.IP.IsMulticast() {
+ if err := setDefaultMulticastSockopts(fd.pfd.Sysfd); err != nil {
+ return err
+ }
+ addr := *addr
+ switch fd.family {
+ case syscall.AF_INET:
+ addr.IP = IPv4zero
+ case syscall.AF_INET6:
+ addr.IP = IPv6unspecified
+ }
+ laddr = &addr
+ }
+ }
+ var err error
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
+ return err
+ }
+
+ if ctrlCtxFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ if err := ctrlCtxFn(ctx, fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
+ }
+ }
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ if err = fd.init(); err != nil {
+ return err
+ }
+ lsa, _ = syscall.Getsockname(fd.pfd.Sysfd)
+ fd.setAddr(fd.addrFunc()(lsa), nil)
+ return nil
+}
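
The multicast branch in listenDatagram above (wildcard address plus reusable port) is what lets independent UDP listeners share a port and group. A minimal user-level sketch, assuming a host with a multicast-capable interface; the group address 224.0.0.250:9999 is an arbitrary value chosen for illustration:

package main

import (
	"log"
	"net"
)

func main() {
	group := &net.UDPAddr{IP: net.IPv4(224, 0, 0, 250), Port: 9999}
	// Both listeners bind the same UDP port; listenDatagram rewrites the
	// multicast laddr to the wildcard address and sets the reuse options,
	// so a second listener on the same port can join the same group.
	c1, err := net.ListenMulticastUDP("udp4", nil, group)
	if err != nil {
		log.Fatal(err)
	}
	defer c1.Close()
	c2, err := net.ListenMulticastUDP("udp4", nil, group)
	if err != nil {
		log.Fatal(err)
	}
	defer c2.Close()
	log.Println(c1.LocalAddr(), c2.LocalAddr())
}
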
diff --git a/src/net/sock_stub.go b/src/net/sock_stub.go
new file mode 100644
index 0000000..e163755
--- /dev/null
+++ b/src/net/sock_stub.go
@@ -0,0 +1,15 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || (js && wasm) || solaris || wasip1
+
+package net
+
+import "syscall"
+
+func maxListenerBacklog() int {
+ // TODO: Implement this
+ // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030.
+ return syscall.SOMAXCONN
+}
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
new file mode 100644
index 0000000..fa11c7a
--- /dev/null
+++ b/src/net/sock_windows.go
@@ -0,0 +1,41 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/syscall/windows"
+ "os"
+ "syscall"
+)
+
+func maxListenerBacklog() int {
+ // TODO: Implement this
+ // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030.
+ return syscall.SOMAXCONN
+}
+
+func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
+ s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
+ nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. If windows.WSASocket fails, just use syscall.Socket.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return syscall.InvalidHandle, os.NewSyscallError("socket", err)
+ }
+ return s, nil
+}
diff --git a/src/net/sockaddr_posix.go b/src/net/sockaddr_posix.go
new file mode 100644
index 0000000..e44fc76
--- /dev/null
+++ b/src/net/sockaddr_posix.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "syscall"
+)
+
+// A sockaddr represents a TCP, UDP, IP or Unix network endpoint
+// address that can be converted into a syscall.Sockaddr.
+type sockaddr interface {
+ Addr
+
+ // family returns the platform-dependent address family
+ // identifier.
+ family() int
+
+ // isWildcard reports whether the address is a wildcard
+ // address.
+ isWildcard() bool
+
+ // sockaddr returns the address converted into a syscall
+ // sockaddr type that implements syscall.Sockaddr
+ // interface. It returns a nil interface when the address is
+ // nil.
+ sockaddr(family int) (syscall.Sockaddr, error)
+
+ // toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
+ toLocal(net string) sockaddr
+}
diff --git a/src/net/sockopt_aix.go b/src/net/sockopt_aix.go
new file mode 100644
index 0000000..7729a44
--- /dev/null
+++ b/src/net/sockopt_aix.go
@@ -0,0 +1,39 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
+ if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW {
+ // Allow both IP versions even if the OS default
+ // is otherwise. Note that some operating systems
+ // never admit this option.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
+ }
+ if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX {
+ // Allow broadcast.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
+ }
+ return nil
+}
+
+func setDefaultListenerSockopts(s int) error {
+ // Allow reuse of recently-used addresses.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
+
+func setDefaultMulticastSockopts(s int) error {
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ if err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ // Allow reuse of recently-used ports.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1))
+}
diff --git a/src/net/sockopt_bsd.go b/src/net/sockopt_bsd.go
new file mode 100644
index 0000000..ff99811
--- /dev/null
+++ b/src/net/sockopt_bsd.go
@@ -0,0 +1,57 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "syscall"
+)
+
+func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
+ if runtime.GOOS == "dragonfly" && sotype != syscall.SOCK_RAW {
+ // On DragonFly BSD, we adjust the ephemeral port
+ // range because unlike other BSD systems its default
+ // port range doesn't conform to IANA recommendation
+ // as described in RFC 6056 and is pretty narrow.
+ switch family {
+ case syscall.AF_INET:
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IP, syscall.IP_PORTRANGE, syscall.IP_PORTRANGE_HIGH)
+ case syscall.AF_INET6:
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_PORTRANGE, syscall.IPV6_PORTRANGE_HIGH)
+ }
+ }
+ if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW && supportsIPv4map() {
+ // Allow both IP versions even if the OS default
+ // is otherwise. Note that some operating systems
+ // never admit this option.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
+ }
+ if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX {
+ // Allow broadcast.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
+ }
+ return nil
+}
+
+func setDefaultListenerSockopts(s int) error {
+ // Allow reuse of recently-used addresses.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
+
+func setDefaultMulticastSockopts(s int) error {
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ if err := syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ // Allow reuse of recently-used ports.
+ // This option is supported only in descendants of 4.4BSD,
+ // to make it possible to build an effective multicast
+ // application that requires quick responses.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEPORT, 1))
+}
diff --git a/src/net/sockopt_linux.go b/src/net/sockopt_linux.go
new file mode 100644
index 0000000..3d54429
--- /dev/null
+++ b/src/net/sockopt_linux.go
@@ -0,0 +1,35 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
+ if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW {
+ // Allow both IP versions even if the OS default
+ // is otherwise. Note that some operating systems
+ // never admit this option.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
+ }
+ if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX {
+ // Allow broadcast.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
+ }
+ return nil
+}
+
+func setDefaultListenerSockopts(s int) error {
+ // Allow reuse of recently-used addresses.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
+
+func setDefaultMulticastSockopts(s int) error {
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
diff --git a/src/net/sockopt_plan9.go b/src/net/sockopt_plan9.go
new file mode 100644
index 0000000..02468cd
--- /dev/null
+++ b/src/net/sockopt_plan9.go
@@ -0,0 +1,19 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "syscall"
+
+func setKeepAlive(fd *netFD, keepalive bool) error {
+ if keepalive {
+ _, e := fd.ctl.WriteAt([]byte("keepalive"), 0)
+ return e
+ }
+ return nil
+}
+
+func setLinger(fd *netFD, sec int) error {
+ return syscall.EPLAN9
+}
diff --git a/src/net/sockopt_posix.go b/src/net/sockopt_posix.go
new file mode 100644
index 0000000..32e8fcd
--- /dev/null
+++ b/src/net/sockopt_posix.go
@@ -0,0 +1,134 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || windows
+
+package net
+
+import (
+ "internal/bytealg"
+ "runtime"
+ "syscall"
+)
+
+// Boolean to int.
+func boolint(b bool) int {
+ if b {
+ return 1
+ }
+ return 0
+}
+
+func ipv4AddrToInterface(ip IP) (*Interface, error) {
+ ift, err := Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifi := range ift {
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if ip.Equal(v.IP) {
+ return &ifi, nil
+ }
+ case *IPNet:
+ if ip.Equal(v.IP) {
+ return &ifi, nil
+ }
+ }
+ }
+ }
+ if ip.Equal(IPv4zero) {
+ return nil, nil
+ }
+ return nil, errNoSuchInterface
+}
+
+func interfaceToIPv4Addr(ifi *Interface) (IP, error) {
+ if ifi == nil {
+ return IPv4zero, nil
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if v.IP.To4() != nil {
+ return v.IP, nil
+ }
+ case *IPNet:
+ if v.IP.To4() != nil {
+ return v.IP, nil
+ }
+ }
+ }
+ return nil, errNoSuchInterface
+}
+
+func setIPv4MreqToInterface(mreq *syscall.IPMreq, ifi *Interface) error {
+ if ifi == nil {
+ return nil
+ }
+ ifat, err := ifi.Addrs()
+ if err != nil {
+ return err
+ }
+ for _, ifa := range ifat {
+ switch v := ifa.(type) {
+ case *IPAddr:
+ if a := v.IP.To4(); a != nil {
+ copy(mreq.Interface[:], a)
+ goto done
+ }
+ case *IPNet:
+ if a := v.IP.To4(); a != nil {
+ copy(mreq.Interface[:], a)
+ goto done
+ }
+ }
+ }
+done:
+ if bytealg.Equal(mreq.Multiaddr[:], IPv4zero.To4()) {
+ return errNoSuchMulticastInterface
+ }
+ return nil
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+ err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+ err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setKeepAlive(fd *netFD, keepalive bool) error {
+ err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setLinger(fd *netFD, sec int) error {
+ var l syscall.Linger
+ if sec >= 0 {
+ l.Onoff = 1
+ l.Linger = int32(sec)
+ } else {
+ l.Onoff = 0
+ l.Linger = 0
+ }
+ err := fd.pfd.SetsockoptLinger(syscall.SOL_SOCKET, syscall.SO_LINGER, &l)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
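
These setters back the exported knobs on net.TCPConn (SetReadBuffer, SetWriteBuffer, SetKeepAlive, SetLinger). A minimal sketch of the user-facing calls; the host and buffer sizes are placeholder values:

package main

import (
	"log"
	"net"
	"time"
)

func main() {
	conn, err := net.Dial("tcp", "example.com:80") // placeholder host
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	tc := conn.(*net.TCPConn)
	// SetReadBuffer/SetWriteBuffer end up in setReadBuffer/setWriteBuffer
	// (SO_RCVBUF/SO_SNDBUF); SetKeepAlive and SetLinger use the setKeepAlive
	// and setLinger helpers defined above.
	if err := tc.SetReadBuffer(1 << 20); err != nil {
		log.Print(err)
	}
	if err := tc.SetWriteBuffer(1 << 20); err != nil {
		log.Print(err)
	}
	if err := tc.SetKeepAlive(true); err != nil {
		log.Print(err)
	}
	if err := tc.SetKeepAlivePeriod(30 * time.Second); err != nil {
		log.Print(err)
	}
	if err := tc.SetLinger(0); err != nil { // 0: discard unsent data on Close
		log.Print(err)
	}
}
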
diff --git a/src/net/sockopt_solaris.go b/src/net/sockopt_solaris.go
new file mode 100644
index 0000000..3d54429
--- /dev/null
+++ b/src/net/sockopt_solaris.go
@@ -0,0 +1,35 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
+ if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW {
+ // Allow both IP versions even if the OS default
+ // is otherwise. Note that some operating systems
+ // never admit this option.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
+ }
+ if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX {
+ // Allow broadcast.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
+ }
+ return nil
+}
+
+func setDefaultListenerSockopts(s int) error {
+ // Allow reuse of recently-used addresses.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
+
+func setDefaultMulticastSockopts(s int) error {
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
diff --git a/src/net/sockopt_stub.go b/src/net/sockopt_stub.go
new file mode 100644
index 0000000..186d891
--- /dev/null
+++ b/src/net/sockopt_stub.go
@@ -0,0 +1,37 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1
+
+package net
+
+import "syscall"
+
+func setDefaultSockopts(s, family, sotype int, ipv6only bool) error {
+ return nil
+}
+
+func setDefaultListenerSockopts(s int) error {
+ return nil
+}
+
+func setDefaultMulticastSockopts(s int) error {
+ return nil
+}
+
+func setReadBuffer(fd *netFD, bytes int) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setWriteBuffer(fd *netFD, bytes int) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setKeepAlive(fd *netFD, keepalive bool) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setLinger(fd *netFD, sec int) error {
+ return syscall.ENOPROTOOPT
+}
diff --git a/src/net/sockopt_windows.go b/src/net/sockopt_windows.go
new file mode 100644
index 0000000..8afaf34
--- /dev/null
+++ b/src/net/sockopt_windows.go
@@ -0,0 +1,40 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "syscall"
+)
+
+func setDefaultSockopts(s syscall.Handle, family, sotype int, ipv6only bool) error {
+ if family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW {
+ // Allow both IP versions even if the OS default
+ // is otherwise. Note that some operating systems
+ // never admit this option.
+ syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only))
+ }
+ if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX && family != syscall.AF_INET6 {
+ // Allow broadcast.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1))
+ }
+ return nil
+}
+
+func setDefaultListenerSockopts(s syscall.Handle) error {
+ // Windows will reuse recently-used addresses by default.
+ // SO_REUSEADDR should not be used here, as it allows
+ // a socket to forcibly bind to a port in use by another socket.
+ // This could lead to non-deterministic behavior, where
+ // connection requests over the port cannot be guaranteed
+ // to be handled by the correct socket.
+ return nil
+}
+
+func setDefaultMulticastSockopts(s syscall.Handle) error {
+ // Allow multicast UDP and raw IP datagram sockets to listen
+ // concurrently across multiple listeners.
+ return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1))
+}
diff --git a/src/net/sockoptip_bsdvar.go b/src/net/sockoptip_bsdvar.go
new file mode 100644
index 0000000..3e9ba1e
--- /dev/null
+++ b/src/net/sockoptip_bsdvar.go
@@ -0,0 +1,30 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || dragonfly || freebsd || netbsd || openbsd || solaris
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ ip, err := interfaceToIPv4Addr(ifi)
+ if err != nil {
+ return wrapSyscallError("setsockopt", err)
+ }
+ var a [4]byte
+ copy(a[:], ip.To4())
+ err = fd.pfd.SetsockoptInet4Addr(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ err := fd.pfd.SetsockoptByte(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/sockoptip_linux.go b/src/net/sockoptip_linux.go
new file mode 100644
index 0000000..bd7d834
--- /dev/null
+++ b/src/net/sockoptip_linux.go
@@ -0,0 +1,27 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int32
+ if ifi != nil {
+ v = int32(ifi.Index)
+ }
+ mreq := &syscall.IPMreqn{Ifindex: v}
+ err := fd.pfd.SetsockoptIPMreqn(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/sockoptip_posix.go b/src/net/sockoptip_posix.go
new file mode 100644
index 0000000..572ea45
--- /dev/null
+++ b/src/net/sockoptip_posix.go
@@ -0,0 +1,49 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || windows
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPMreq{Multiaddr: [4]byte{ip[0], ip[1], ip[2], ip[3]}}
+ if err := setIPv4MreqToInterface(mreq, ifi); err != nil {
+ return err
+ }
+ err := fd.pfd.SetsockoptIPMreq(syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
+ var v int
+ if ifi != nil {
+ v = ifi.Index
+ }
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setIPv6MulticastLoopback(fd *netFD, v bool) error {
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+ mreq := &syscall.IPv6Mreq{}
+ copy(mreq.Multiaddr[:], ip)
+ if ifi != nil {
+ mreq.Interface = uint32(ifi.Index)
+ }
+ err := fd.pfd.SetsockoptIPv6Mreq(syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/sockoptip_stub.go b/src/net/sockoptip_stub.go
new file mode 100644
index 0000000..a37c312
--- /dev/null
+++ b/src/net/sockoptip_stub.go
@@ -0,0 +1,33 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1
+
+package net
+
+import "syscall"
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ return syscall.ENOPROTOOPT
+}
+
+func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setIPv6MulticastLoopback(fd *netFD, v bool) error {
+ return syscall.ENOPROTOOPT
+}
+
+func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error {
+ return syscall.ENOPROTOOPT
+}
diff --git a/src/net/sockoptip_windows.go b/src/net/sockoptip_windows.go
new file mode 100644
index 0000000..6267603
--- /dev/null
+++ b/src/net/sockoptip_windows.go
@@ -0,0 +1,30 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "syscall"
+ "unsafe"
+)
+
+func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error {
+ ip, err := interfaceToIPv4Addr(ifi)
+ if err != nil {
+ return os.NewSyscallError("setsockopt", err)
+ }
+ var a [4]byte
+ copy(a[:], ip.To4())
+ err = fd.pfd.Setsockopt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
+
+func setIPv4MulticastLoopback(fd *netFD, v bool) error {
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go
new file mode 100644
index 0000000..ab2ab70
--- /dev/null
+++ b/src/net/splice_linux.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+)
+
+// splice transfers data from r to c using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection. Currently, splice
+// is only enabled if r is a TCP or a stream-oriented Unix connection.
+//
+// If splice returns handled == false, it has performed no work.
+func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ var remain int64 = 1<<63 - 1 // by default, copy until EOF
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ remain, r = lr.N, lr.R
+ if remain <= 0 {
+ return 0, nil, true
+ }
+ }
+
+ var s *netFD
+ if tc, ok := r.(*TCPConn); ok {
+ s = tc.fd
+ } else if uc, ok := r.(*UnixConn); ok {
+ if uc.fd.net != "unix" {
+ return 0, nil, false
+ }
+ s = uc.fd
+ } else {
+ return 0, nil, false
+ }
+
+ written, handled, sc, err := poll.Splice(&c.pfd, &s.pfd, remain)
+ if lr != nil {
+ lr.N -= written
+ }
+ return written, wrapSyscallError(sc, err), handled
+}
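
splice is not called directly by users; it is reached through TCPConn.ReadFrom, which io.Copy prefers when the destination implements io.ReaderFrom. A minimal proxy sketch that exercises this path on Linux; the listen and backend addresses are placeholders:

package main

import (
	"io"
	"log"
	"net"
)

// forward copies src into dst. Because *net.TCPConn implements
// io.ReaderFrom, io.Copy calls dst.ReadFrom(src), which tries the
// splice fast path when src is a TCP or stream-oriented Unix conn.
func forward(dst, src *net.TCPConn) {
	defer dst.Close()
	defer src.Close()
	if _, err := io.Copy(dst, src); err != nil {
		log.Print(err)
	}
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()
	backend := "127.0.0.1:8080" // placeholder backend address
	for {
		c, err := ln.Accept()
		if err != nil {
			log.Fatal(err)
		}
		b, err := net.Dial("tcp", backend)
		if err != nil {
			c.Close()
			continue
		}
		go forward(b.(*net.TCPConn), c.(*net.TCPConn))
		go forward(c.(*net.TCPConn), b.(*net.TCPConn))
	}
}
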
diff --git a/src/net/splice_stub.go b/src/net/splice_stub.go
new file mode 100644
index 0000000..3cdadb1
--- /dev/null
+++ b/src/net/splice_stub.go
@@ -0,0 +1,13 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !linux
+
+package net
+
+import "io"
+
+func splice(c *netFD, r io.Reader) (int64, error, bool) {
+ return 0, nil, false
+}
diff --git a/src/net/splice_test.go b/src/net/splice_test.go
new file mode 100644
index 0000000..75a8f27
--- /dev/null
+++ b/src/net/splice_test.go
@@ -0,0 +1,532 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build linux
+
+package net
+
+import (
+ "io"
+ "log"
+ "os"
+ "os/exec"
+ "strconv"
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestSplice(t *testing.T) {
+ t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
+ if !testableNetwork("unixgram") {
+ t.Skip("skipping unix-to-tcp tests")
+ }
+ t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+ t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
+ t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
+ t.Run("no-unixpacket", testSpliceNoUnixpacket)
+ t.Run("no-unixgram", testSpliceNoUnixgram)
+}
+
+func testSpliceToFile(t *testing.T, upNet, downNet string) {
+ t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
+ t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
+ t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
+ t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
+ t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
+ t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
+}
+
+func testSplice(t *testing.T, upNet, downNet string) {
+ t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
+ t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
+ t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
+ t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
+ t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
+ t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
+ t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
+ t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
+}
+
+type spliceTestCase struct {
+ upNet, downNet string
+
+ chunkSize, totalSize int
+ limitReadSize int
+}
+
+func (tc spliceTestCase) test(t *testing.T) {
+ clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
+ defer serverUp.Close()
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+ clientDown, serverDown := spliceTestSocketPair(t, tc.downNet)
+ defer serverDown.Close()
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+ var (
+ r io.Reader = serverUp
+ size = tc.totalSize
+ )
+ if tc.limitReadSize > 0 {
+ if tc.limitReadSize < size {
+ size = tc.limitReadSize
+ }
+
+ r = &io.LimitedReader{
+ N: int64(tc.limitReadSize),
+ R: serverUp,
+ }
+ defer serverUp.Close()
+ }
+ n, err := io.Copy(serverDown, r)
+ serverDown.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if want := int64(size); want != n {
+ t.Errorf("want %d bytes spliced, got %d", want, n)
+ }
+
+ if tc.limitReadSize > 0 {
+ wantN := 0
+ if tc.limitReadSize > size {
+ wantN = tc.limitReadSize - size
+ }
+
+ if n := r.(*io.LimitedReader).N; n != int64(wantN) {
+ t.Errorf("r.N = %d, want %d", n, wantN)
+ }
+ }
+}
+
+func (tc spliceTestCase) testFile(t *testing.T) {
+ f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer f.Close()
+
+ client, server := spliceTestSocketPair(t, tc.upNet)
+ defer server.Close()
+
+ cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
+ if err != nil {
+ client.Close()
+ t.Fatal("failed to start splice client:", err)
+ }
+ defer cleanup()
+
+ var (
+ r io.Reader = server
+ actualSize = tc.totalSize
+ )
+ if tc.limitReadSize > 0 {
+ if tc.limitReadSize < actualSize {
+ actualSize = tc.limitReadSize
+ }
+
+ r = &io.LimitedReader{
+ N: int64(tc.limitReadSize),
+ R: r,
+ }
+ }
+
+ got, err := io.Copy(f, r)
+ if err != nil {
+ t.Fatalf("failed to ReadFrom with error: %v", err)
+ }
+ if want := int64(actualSize); got != want {
+ t.Errorf("got %d bytes, want %d", got, want)
+ }
+ if tc.limitReadSize > 0 {
+ wantN := 0
+ if tc.limitReadSize > actualSize {
+ wantN = tc.limitReadSize - actualSize
+ }
+
+ if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
+ t.Errorf("r.N = %d, want %d", gotN, wantN)
+ }
+ }
+}
+
+func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+ clientUp, serverUp := spliceTestSocketPair(t, upNet)
+ defer clientUp.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, downNet)
+ defer clientDown.Close()
+
+ serverUp.Close()
+
+ // We'd like to call net.splice here and check the handled return
+ // value, but we disable splice on old Linux kernels.
+ //
+ // In that case, poll.Splice and net.splice return a non-nil error
+ // and handled == false. We'd ideally like to see handled == true
+ // because the source reader is at EOF, but if we're running on an old
+ // kernel, and splice is disabled, we won't see EOF from net.splice,
+ // because we won't touch the reader at all.
+ //
+ // Trying to untangle the errors from net.splice and match them
+ // against the errors created by the poll package would be brittle,
+ // so this is a higher level test.
+ //
+ // The following ReadFrom should return immediately, regardless of
+ // whether splice is disabled or not. The other side should then
+ // get a goodbye signal. Test for the goodbye signal.
+ msg := "bye"
+ go func() {
+ serverDown.(io.ReaderFrom).ReadFrom(serverUp)
+ io.WriteString(serverDown, msg)
+ serverDown.Close()
+ }()
+
+ buf := make([]byte, 3)
+ _, err := io.ReadFull(clientDown, buf)
+ if err != nil {
+ t.Errorf("clientDown: %v", err)
+ }
+ if string(buf) != msg {
+ t.Errorf("clientDown got %q, want %q", buf, msg)
+ }
+}
+
+func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
+ front := newLocalListener(t, upNet)
+ defer front.Close()
+ back := newLocalListener(t, downNet)
+ defer back.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ proxy := func() {
+ src, err := front.Accept()
+ if err != nil {
+ return
+ }
+ dst, err := Dial(downNet, back.Addr().String())
+ if err != nil {
+ return
+ }
+ defer dst.Close()
+ defer src.Close()
+ go func() {
+ io.Copy(src, dst)
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(dst, src)
+ wg.Done()
+ }()
+ }
+
+ go proxy()
+
+ toFront, err := Dial(upNet, front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ io.WriteString(toFront, "foo")
+ toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+
+ _, err = io.ReadAll(fromProxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Wait()
+}
+
+func testSpliceNoUnixpacket(t *testing.T) {
+ clientUp, serverUp := spliceTestSocketPair(t, "unixpacket")
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
+ defer clientDown.Close()
+ defer serverDown.Close()
+ // If splice called poll.Splice here, we'd get err == syscall.EINVAL
+ // and handled == false. If poll.Splice gets an EINVAL on the first
+ // try, it assumes the kernel it's running on doesn't support splice
+ // for unix sockets and returns handled == false. This works for our
+ // purposes somewhat by accident, but is not entirely correct.
+ //
+ // What we want is err == nil and handled == false, i.e. we never
+ // called poll.Splice, because we know the unix socket's network.
+ _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ if err != nil || handled != false {
+ t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
+ }
+}
+
+func testSpliceNoUnixgram(t *testing.T) {
+ addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(addr.Name)
+ up, err := ListenUnixgram("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer up.Close()
+ clientDown, serverDown := spliceTestSocketPair(t, "tcp")
+ defer clientDown.Close()
+ defer serverDown.Close()
+ // Analogous to testSpliceNoUnixpacket.
+ _, err, handled := splice(serverDown.(*TCPConn).fd, up)
+ if err != nil || handled != false {
+ t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
+ }
+}
+
+func BenchmarkSplice(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
+ b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
+}
+
+func benchSplice(b *testing.B, upNet, downNet string) {
+ for i := 0; i <= 10; i++ {
+ chunkSize := 1 << uint(i+10)
+ tc := spliceTestCase{
+ upNet: upNet,
+ downNet: downNet,
+ chunkSize: chunkSize,
+ }
+
+ b.Run(strconv.Itoa(chunkSize), tc.bench)
+ }
+}
+
+func (tc spliceTestCase) bench(b *testing.B) {
+ // To benchmark the genericReadFrom code path, set this to false.
+ useSplice := true
+
+ clientUp, serverUp := spliceTestSocketPair(b, tc.upNet)
+ defer serverUp.Close()
+
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cleanup()
+
+ clientDown, serverDown := spliceTestSocketPair(b, tc.downNet)
+ defer serverDown.Close()
+
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cleanup()
+
+ b.SetBytes(int64(tc.chunkSize))
+ b.ResetTimer()
+
+ if useSplice {
+ _, err := io.Copy(serverDown, serverUp)
+ if err != nil {
+ b.Fatal(err)
+ }
+ } else {
+ type onlyReader struct {
+ io.Reader
+ }
+ _, err := io.Copy(serverDown, onlyReader{serverUp})
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) {
+ t.Helper()
+ ln := newLocalListener(t, net)
+ defer ln.Close()
+ var cerr, serr error
+ acceptDone := make(chan struct{})
+ go func() {
+ server, serr = ln.Accept()
+ acceptDone <- struct{}{}
+ }()
+ client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
+ <-acceptDone
+ if cerr != nil {
+ if server != nil {
+ server.Close()
+ }
+ t.Fatal(cerr)
+ }
+ if serr != nil {
+ if client != nil {
+ client.Close()
+ }
+ t.Fatal(serr)
+ }
+ return client, server
+}
+
+func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
+ f, err := conn.(interface{ File() (*os.File, error) }).File()
+ if err != nil {
+ return nil, err
+ }
+
+ cmd := exec.Command(os.Args[0], os.Args[1:]...)
+ cmd.Env = []string{
+ "GO_NET_TEST_SPLICE=1",
+ "GO_NET_TEST_SPLICE_OP=" + op,
+ "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
+ "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
+ "TMPDIR=" + os.Getenv("TMPDIR"),
+ }
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+
+ donec := make(chan struct{})
+ go func() {
+ cmd.Wait()
+ conn.Close()
+ f.Close()
+ close(donec)
+ }()
+
+ return func() {
+ select {
+ case <-donec:
+ case <-time.After(5 * time.Second):
+ log.Printf("killing splice client after 5 second shutdown timeout")
+ cmd.Process.Kill()
+ select {
+ case <-donec:
+ case <-time.After(5 * time.Second):
+ log.Printf("splice client didn't die after 10 seconds")
+ }
+ }
+ }, nil
+}
+
+func init() {
+ if os.Getenv("GO_NET_TEST_SPLICE") == "" {
+ return
+ }
+ defer os.Exit(0)
+
+ f := os.NewFile(uintptr(3), "splice-test-conn")
+ defer f.Close()
+
+ conn, err := FileConn(f)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ var chunkSize int
+ if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+ buf := make([]byte, chunkSize)
+
+ var totalSize int
+ if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+
+ var fn func([]byte) (int, error)
+ switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
+ case "r":
+ fn = conn.Read
+ case "w":
+ defer conn.Close()
+
+ fn = conn.Write
+ default:
+ log.Fatalf("unknown op %q", op)
+ }
+
+ var n int
+ for count := 0; count < totalSize; count += n {
+ if count+chunkSize > totalSize {
+ buf = buf[:totalSize-count]
+ }
+
+ var err error
+ if n, err = fn(buf); err != nil {
+ return
+ }
+ }
+}
+
+func BenchmarkSpliceFile(b *testing.B) {
+ b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
+ b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
+}
+
+func benchmarkSpliceFile(b *testing.B, proto string) {
+ for i := 0; i <= 10; i++ {
+ size := 1 << (i + 10)
+ bench := spliceFileBench{
+ proto: proto,
+ chunkSize: size,
+ }
+ b.Run(strconv.Itoa(size), bench.benchSpliceFile)
+ }
+}
+
+type spliceFileBench struct {
+ proto string
+ chunkSize int
+}
+
+func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
+ f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer f.Close()
+
+ totalSize := b.N * bench.chunkSize
+
+ client, server := spliceTestSocketPair(b, bench.proto)
+ defer server.Close()
+
+ cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
+ if err != nil {
+ client.Close()
+ b.Fatalf("failed to start splice client: %v", err)
+ }
+ defer cleanup()
+
+ b.ReportAllocs()
+ b.SetBytes(int64(bench.chunkSize))
+ b.ResetTimer()
+
+ got, err := io.Copy(f, server)
+ if err != nil {
+ b.Fatalf("failed to ReadFrom with error: %v", err)
+ }
+ if want := int64(totalSize); got != want {
+ b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
+ }
+}
diff --git a/src/net/sys_cloexec.go b/src/net/sys_cloexec.go
new file mode 100644
index 0000000..6e61d40
--- /dev/null
+++ b/src/net/sys_cloexec.go
@@ -0,0 +1,36 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file implements sysSocket for platforms that do not provide a fast path
+// for setting SetNonblock and CloseOnExec.
+
+//go:build aix || darwin
+
+package net
+
+import (
+ "internal/poll"
+ "os"
+ "syscall"
+)
+
+// Wrapper around the socket system call that marks the returned file
+// descriptor as nonblocking and close-on-exec.
+func sysSocket(family, sotype, proto int) (int, error) {
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err := socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
+ if err != nil {
+ return -1, os.NewSyscallError("socket", err)
+ }
+ if err = syscall.SetNonblock(s, true); err != nil {
+ poll.CloseFunc(s)
+ return -1, os.NewSyscallError("setnonblock", err)
+ }
+ return s, nil
+}
diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go
new file mode 100644
index 0000000..358e487
--- /dev/null
+++ b/src/net/tcpsock.go
@@ -0,0 +1,398 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/itoa"
+ "io"
+ "net/netip"
+ "os"
+ "syscall"
+ "time"
+)
+
+// BUG(mikio): On JS and Windows, the File method of TCPConn and
+// TCPListener is not implemented.
+
+// TCPAddr represents the address of a TCP end point.
+type TCPAddr struct {
+ IP IP
+ Port int
+ Zone string // IPv6 scoped addressing zone
+}
+
+// AddrPort returns the TCPAddr a as a netip.AddrPort.
+//
+// If a.Port does not fit in a uint16, it's silently truncated.
+//
+// If a is nil, a zero value is returned.
+func (a *TCPAddr) AddrPort() netip.AddrPort {
+ if a == nil {
+ return netip.AddrPort{}
+ }
+ na, _ := netip.AddrFromSlice(a.IP)
+ na = na.WithZone(a.Zone)
+ return netip.AddrPortFrom(na, uint16(a.Port))
+}
+
+// Network returns the address's network name, "tcp".
+func (a *TCPAddr) Network() string { return "tcp" }
+
+func (a *TCPAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ ip := ipEmptyString(a.IP)
+ if a.Zone != "" {
+ return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port))
+ }
+ return JoinHostPort(ip, itoa.Itoa(a.Port))
+}
+
+func (a *TCPAddr) isWildcard() bool {
+ if a == nil || a.IP == nil {
+ return true
+ }
+ return a.IP.IsUnspecified()
+}
+
+func (a *TCPAddr) opAddr() Addr {
+ if a == nil {
+ return nil
+ }
+ return a
+}
+
+// ResolveTCPAddr returns an address of TCP end point.
+//
+// The network must be a TCP network name.
+//
+// If the host in the address parameter is not a literal IP address or
+// the port is not a literal port number, ResolveTCPAddr resolves the
+// address to an address of TCP end point.
+// Otherwise, it parses the address as a pair of literal IP address
+// and port number.
+// The address parameter can use a host name, but this is not
+// recommended, because it will return at most one of the host name's
+// IP addresses.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func ResolveTCPAddr(network, address string) (*TCPAddr, error) {
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ case "": // a hint wildcard for Go 1.0 undocumented behavior
+ network = "tcp"
+ default:
+ return nil, UnknownNetworkError(network)
+ }
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address)
+ if err != nil {
+ return nil, err
+ }
+ return addrs.forResolve(network, address).(*TCPAddr), nil
+}
+
+// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false,
+// then the returned TCPAddr will contain a nil IP field, indicating an
+// address family-agnostic unspecified address.
+func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr {
+ return &TCPAddr{
+ IP: addr.Addr().AsSlice(),
+ Zone: addr.Addr().Zone(),
+ Port: int(addr.Port()),
+ }
+}
+
+// TCPConn is an implementation of the Conn interface for TCP network
+// connections.
+type TCPConn struct {
+ conn
+}
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+func (c *TCPConn) SyscallConn() (syscall.RawConn, error) {
+ if !c.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawConn(c.fd)
+}
+
+// ReadFrom implements the io.ReaderFrom ReadFrom method.
+func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.readFrom(r)
+ if err != nil && err != io.EOF {
+ err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, err
+}
+
+// CloseRead shuts down the reading side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseRead() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.closeRead(); err != nil {
+ return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// CloseWrite shuts down the writing side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseWrite() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.closeWrite(); err != nil {
+ return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// SetLinger sets the behavior of Close on a connection which still
+// has data waiting to be sent or to be acknowledged.
+//
+// If sec < 0 (the default), the operating system finishes sending the
+// data in the background.
+//
+// If sec == 0, the operating system discards any unsent or
+// unacknowledged data.
+//
+// If sec > 0, the data is sent in the background as with sec < 0.
+// On some operating systems including Linux, this may cause Close to block
+// until all data has been sent or discarded.
+// On some operating systems after sec seconds have elapsed any remaining
+// unsent data may be discarded.
+func (c *TCPConn) SetLinger(sec int) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setLinger(c.fd, sec); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// SetKeepAlive sets whether the operating system should send
+// keep-alive messages on the connection.
+func (c *TCPConn) SetKeepAlive(keepalive bool) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setKeepAlive(c.fd, keepalive); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// SetKeepAlivePeriod sets period between keep-alives.
+func (c *TCPConn) SetKeepAlivePeriod(d time.Duration) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setKeepAlivePeriod(c.fd, d); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// SetNoDelay controls whether the operating system should delay
+// packet transmission in hopes of sending fewer packets (Nagle's
+// algorithm). The default is true (no delay), meaning that data is
+// sent as soon as possible after a Write.
+func (c *TCPConn) SetNoDelay(noDelay bool) error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := setNoDelay(c.fd, noDelay); err != nil {
+ return &OpError{Op: "set", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// MultipathTCP reports whether the ongoing connection is using MPTCP.
+//
+// If Multipath TCP is not supported by the host, by the other peer or
+// intentionally / accidentally filtered out by a device in between, a
+// fallback to TCP will be done. This method does its best to check if
+// MPTCP is still being used or not.
+//
+// On Linux, more conditions are verified on kernels >= v5.16, improving
+// the results.
+func (c *TCPConn) MultipathTCP() (bool, error) {
+ if !c.ok() {
+ return false, syscall.EINVAL
+ }
+ return isUsingMultipathTCP(c.fd), nil
+}
+
+func newTCPConn(fd *netFD, keepAlive time.Duration, keepAliveHook func(time.Duration)) *TCPConn {
+ setNoDelay(fd, true)
+ if keepAlive == 0 {
+ keepAlive = defaultTCPKeepAlive
+ }
+ if keepAlive > 0 {
+ setKeepAlive(fd, true)
+ setKeepAlivePeriod(fd, keepAlive)
+ if keepAliveHook != nil {
+ keepAliveHook(keepAlive)
+ }
+ }
+ return &TCPConn{conn{fd}}
+}
+
+// DialTCP acts like Dial for TCP networks.
+//
+// The network must be a TCP network name; see func Dial for details.
+//
+// If laddr is nil, a local address is automatically chosen.
+// If the IP field of raddr is nil or an unspecified IP address, the
+// local system is assumed.
+func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ default:
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if raddr == nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
+ }
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialTCP(context.Background(), laddr, raddr)
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
+
+// TCPListener is a TCP network listener. Clients should typically
+// use variables of type Listener instead of assuming TCP.
+type TCPListener struct {
+ fd *netFD
+ lc ListenConfig
+}
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+//
+// The returned RawConn only supports calling Control. Read and
+// Write return an error.
+func (l *TCPListener) SyscallConn() (syscall.RawConn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawListener(l.fd)
+}
+
+// AcceptTCP accepts the next incoming call and returns the new
+// connection.
+func (l *TCPListener) AcceptTCP() (*TCPConn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ c, err := l.accept()
+ if err != nil {
+ return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return c, nil
+}
+
+// Accept implements the Accept method in the Listener interface; it
+// waits for the next call and returns a generic Conn.
+func (l *TCPListener) Accept() (Conn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ c, err := l.accept()
+ if err != nil {
+ return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return c, nil
+}
+
+// Close stops listening on the TCP address.
+// Already Accepted connections are not closed.
+func (l *TCPListener) Close() error {
+ if !l.ok() {
+ return syscall.EINVAL
+ }
+ if err := l.close(); err != nil {
+ return &OpError{Op: "close", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// Addr returns the listener's network address, a *TCPAddr.
+// The Addr returned is shared by all invocations of Addr, so
+// do not modify it.
+func (l *TCPListener) Addr() Addr { return l.fd.laddr }
+
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *TCPListener) SetDeadline(t time.Time) error {
+ if !l.ok() {
+ return syscall.EINVAL
+ }
+ if err := l.fd.pfd.SetDeadline(t); err != nil {
+ return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// File returns a copy of the underlying os.File.
+// It is the caller's responsibility to close f when finished.
+// Closing l does not affect f, and closing f does not affect l.
+//
+// The returned os.File's file descriptor is different from the
+// connection's. Attempting to change properties of the original
+// using this duplicate may or may not have the desired effect.
+func (l *TCPListener) File() (f *os.File, err error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ f, err = l.file()
+ if err != nil {
+ return nil, &OpError{Op: "file", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return
+}
+
+// ListenTCP acts like Listen for TCP networks.
+//
+// The network must be a TCP network name; see func Dial for details.
+//
+// If the IP field of laddr is nil or an unspecified IP address,
+// ListenTCP listens on all available unicast and anycast IP addresses
+// of the local system.
+// If the Port field of laddr is 0, a port number is automatically
+// chosen.
+func ListenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
+ switch network {
+ case "tcp", "tcp4", "tcp6":
+ default:
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if laddr == nil {
+ laddr = &TCPAddr{}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ ln, err := sl.listenTCP(context.Background(), laddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
+ }
+ return ln, nil
+}
+
+// roundDurationUp divides d by to, rounding up, so the result is the
+// number of to-sized units needed to cover d.
+func roundDurationUp(d time.Duration, to time.Duration) time.Duration {
+ return (d + to - 1) / to
+}
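
As a usage sketch of the exported API in this file: a loopback listener built with ListenTCP, a client built with DialTCP, and a couple of the per-connection setters. The addresses are illustrative; leaving Port as 0 asks the kernel for a free port.

package main

import (
	"bufio"
	"log"
	"net"
	"time"
)

func main() {
	ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}) // Port 0: kernel picks a free port
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close()

	go func() {
		for {
			c, err := ln.AcceptTCP() // returns a *TCPConn, unlike Accept
			if err != nil {
				return
			}
			go func(c *net.TCPConn) {
				defer c.Close()
				c.Write([]byte("hello\n"))
			}(c)
		}
	}()

	raddr := ln.Addr().(*net.TCPAddr)
	c, err := net.DialTCP("tcp", nil, raddr) // nil laddr: local address chosen automatically
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()
	c.SetNoDelay(true)                     // true (no delay) is already the default
	c.SetKeepAlivePeriod(30 * time.Second) // backed by setKeepAlive/setKeepAlivePeriod
	line, err := bufio.NewReader(c).ReadString('\n')
	if err != nil {
		log.Fatal(err)
	}
	log.Print(line)
}
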
diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go
new file mode 100644
index 0000000..d55948f
--- /dev/null
+++ b/src/net/tcpsock_plan9.go
@@ -0,0 +1,86 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "io"
+ "os"
+)
+
+func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
+ return genericReadFrom(c, r)
+}
+
+func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ if h := sd.testHookDialTCP; h != nil {
+ return h(ctx, sd.network, laddr, raddr)
+ }
+ if h := testHookDialTCP; h != nil {
+ return h(ctx, sd.network, laddr, raddr)
+ }
+ return sd.doDialTCP(ctx, laddr, raddr)
+}
+
+func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ switch sd.network {
+ case "tcp4":
+ // Plan 9 doesn't complain about [::]:0->127.0.0.1, so it's up to us.
+ if laddr != nil && len(laddr.IP) != 0 && laddr.IP.To4() == nil {
+ return nil, &AddrError{Err: "non-IPv4 local address", Addr: laddr.String()}
+ }
+ case "tcp", "tcp6":
+ default:
+ return nil, UnknownNetworkError(sd.network)
+ }
+ if raddr == nil {
+ return nil, errMissingAddress
+ }
+ fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
+ if err != nil {
+ return nil, err
+ }
+ return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
+}
+
+func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil && ln.fd.ctl != nil }
+
+func (ln *TCPListener) accept() (*TCPConn, error) {
+ fd, err := ln.fd.acceptPlan9()
+ if err != nil {
+ return nil, err
+ }
+ return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
+}
+
+func (ln *TCPListener) close() error {
+ if err := ln.fd.pfd.Close(); err != nil {
+ return err
+ }
+ if _, err := ln.fd.ctl.WriteString("hangup"); err != nil {
+ ln.fd.ctl.Close()
+ return err
+ }
+ if err := ln.fd.ctl.Close(); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (ln *TCPListener) file() (*os.File, error) {
+ f, err := ln.dup()
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+}
+
+func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ fd, err := listenPlan9(ctx, sl.network, laddr)
+ if err != nil {
+ return nil, err
+ }
+ return &TCPListener{fd: fd, lc: sl.ListenConfig}, nil
+}
diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go
new file mode 100644
index 0000000..e6f425b
--- /dev/null
+++ b/src/net/tcpsock_posix.go
@@ -0,0 +1,187 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "context"
+ "io"
+ "os"
+ "syscall"
+)
+
+func sockaddrToTCP(sa syscall.Sockaddr) Addr {
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port}
+ case *syscall.SockaddrInet6:
+ return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
+ }
+ return nil
+}
+
+func (a *TCPAddr) family() int {
+ if a == nil || len(a.IP) <= IPv4len {
+ return syscall.AF_INET
+ }
+ if a.IP.To4() != nil {
+ return syscall.AF_INET
+ }
+ return syscall.AF_INET6
+}
+
+func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
+ if a == nil {
+ return nil, nil
+ }
+ return ipToSockaddr(family, a.IP, a.Port, a.Zone)
+}
+
+func (a *TCPAddr) toLocal(net string) sockaddr {
+ return &TCPAddr{loopbackIP(net), a.Port, a.Zone}
+}
+
+func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
+ if n, err, handled := splice(c.fd, r); handled {
+ return n, err
+ }
+ if n, err, handled := sendFile(c.fd, r); handled {
+ return n, err
+ }
+ return genericReadFrom(c, r)
+}
+
+func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ if h := sd.testHookDialTCP; h != nil {
+ return h(ctx, sd.network, laddr, raddr)
+ }
+ if h := testHookDialTCP; h != nil {
+ return h(ctx, sd.network, laddr, raddr)
+ }
+ return sd.doDialTCP(ctx, laddr, raddr)
+}
+
+func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ return sd.doDialTCPProto(ctx, laddr, raddr, 0)
+}
+
+func (sd *sysDialer) doDialTCPProto(ctx context.Context, laddr, raddr *TCPAddr, proto int) (*TCPConn, error) {
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, proto, "dial", ctrlCtxFn)
+
+ // TCP has a rarely used mechanism called a 'simultaneous connection' in
+ // which Dial("tcp", addr1, addr2) run on the machine at addr1 can
+ // connect to a simultaneous Dial("tcp", addr2, addr1) run on the machine
+ // at addr2, without either machine executing Listen. If laddr == nil,
+ // it means we want the kernel to pick an appropriate originating local
+ // address. Some Linux kernels cycle blindly through a fixed range of
+ // local ports, regardless of destination port. If a kernel happens to
+ // pick local port 50001 as the source for a Dial("tcp", "", "localhost:50001"),
+ // then the Dial will succeed, having simultaneously connected to itself.
+ // This can only happen when we are letting the kernel pick a port
+ // (laddr == nil or laddr.Port == 0) and when there is no listener for
+ // the destination address.
+ // It's hard to argue this is anything other than a kernel bug. If we
+ // see this happen, rather than expose the buggy effect to users, we
+ // close the fd and try again. If it happens twice more, we relent and
+ // use the result. See also:
+ // https://golang.org/issue/2690
+ // https://stackoverflow.com/questions/4949858/
+ //
+ // The opposite can also happen: if we ask the kernel to pick an appropriate
+ // originating local address, sometimes it picks one that is already in use.
+ // So if the error is EADDRNOTAVAIL, we have to try again too, just for
+ // a different reason.
+ //
+ // The kernel socket code is no doubt enjoying watching us squirm.
+ for i := 0; i < 2 && (laddr == nil || laddr.Port == 0) && (selfConnect(fd, err) || spuriousENOTAVAIL(err)); i++ {
+ if err == nil {
+ fd.Close()
+ }
+ fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, proto, "dial", ctrlCtxFn)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+ return newTCPConn(fd, sd.Dialer.KeepAlive, testHookSetKeepAlive), nil
+}
+
+func selfConnect(fd *netFD, err error) bool {
+ // If the connect failed, we clearly didn't connect to ourselves.
+ if err != nil {
+ return false
+ }
+
+ // The socket constructor can return an fd with raddr nil under certain
+ // unknown conditions. The errors in the calls there to Getpeername
+ // are discarded, but we can't catch the problem there because those
+ // calls are sometimes legally erroneous with a "socket not connected".
+ // Since this code (selfConnect) is already trying to work around
+ // a problem, we make sure if this happens we recognize trouble and
+ // ask the DialTCP routine to try again.
+ // TODO: try to understand what's really going on.
+ if fd.laddr == nil || fd.raddr == nil {
+ return true
+ }
+ l := fd.laddr.(*TCPAddr)
+ r := fd.raddr.(*TCPAddr)
+ return l.Port == r.Port && l.IP.Equal(r.IP)
+}
+
+func spuriousENOTAVAIL(err error) bool {
+ if op, ok := err.(*OpError); ok {
+ err = op.Err
+ }
+ if sys, ok := err.(*os.SyscallError); ok {
+ err = sys.Err
+ }
+ return err == syscall.EADDRNOTAVAIL
+}
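+
+// Illustrative sketch (editorial, not part of the original change): the
+// kernel self-connect described in doDialTCPProto can be observed outside
+// this package by dialing a port nobody listens on, without an explicit
+// local address. The port number is arbitrary; the dial should normally
+// fail with "connection refused", but on affected kernels it can
+// occasionally succeed by connecting to itself.
+//
+// c, err := net.Dial("tcp", "127.0.0.1:50001") // no listener on port 50001
+// if err == nil && c.LocalAddr().String() == c.RemoteAddr().String() {
+// // the dial "succeeded" by connecting to itself
+// }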
+
+func (ln *TCPListener) ok() bool { return ln != nil && ln.fd != nil }
+
+func (ln *TCPListener) accept() (*TCPConn, error) {
+ fd, err := ln.fd.accept()
+ if err != nil {
+ return nil, err
+ }
+ return newTCPConn(fd, ln.lc.KeepAlive, nil), nil
+}
+
+func (ln *TCPListener) close() error {
+ return ln.fd.Close()
+}
+
+func (ln *TCPListener) file() (*os.File, error) {
+ f, err := ln.fd.dup()
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+}
+
+func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ return sl.listenTCPProto(ctx, laddr, 0)
+}
+
+func (sl *sysListener) listenTCPProto(ctx context.Context, laddr *TCPAddr, proto int) (*TCPListener, error) {
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, proto, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return &TCPListener{fd: fd, lc: sl.ListenConfig}, nil
+}
diff --git a/src/net/tcpsock_test.go b/src/net/tcpsock_test.go
new file mode 100644
index 0000000..f720a22
--- /dev/null
+++ b/src/net/tcpsock_test.go
@@ -0,0 +1,785 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "fmt"
+ "internal/testenv"
+ "io"
+ "os"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+)
+
+func BenchmarkTCP4OneShot(b *testing.B) {
+ benchmarkTCP(b, false, false, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4OneShotTimeout(b *testing.B) {
+ benchmarkTCP(b, false, true, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4Persistent(b *testing.B) {
+ benchmarkTCP(b, true, false, "127.0.0.1:0")
+}
+
+func BenchmarkTCP4PersistentTimeout(b *testing.B) {
+ benchmarkTCP(b, true, true, "127.0.0.1:0")
+}
+
+func BenchmarkTCP6OneShot(b *testing.B) {
+ if !supportsIPv6() {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, false, false, "[::1]:0")
+}
+
+func BenchmarkTCP6OneShotTimeout(b *testing.B) {
+ if !supportsIPv6() {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, false, true, "[::1]:0")
+}
+
+func BenchmarkTCP6Persistent(b *testing.B) {
+ if !supportsIPv6() {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, true, false, "[::1]:0")
+}
+
+func BenchmarkTCP6PersistentTimeout(b *testing.B) {
+ if !supportsIPv6() {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCP(b, true, true, "[::1]:0")
+}
+
+func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ const msgLen = 512
+ conns := b.N
+ numConcurrent := runtime.GOMAXPROCS(-1) * 2
+ msgs := 1
+ if persistent {
+ conns = numConcurrent
+ msgs = b.N / conns
+ if msgs == 0 {
+ msgs = 1
+ }
+ if conns > b.N {
+ conns = b.N
+ }
+ }
+ sendMsg := func(c Conn, buf []byte) bool {
+ n, err := c.Write(buf)
+ if n != len(buf) || err != nil {
+ b.Log(err)
+ return false
+ }
+ return true
+ }
+ recvMsg := func(c Conn, buf []byte) bool {
+ for read := 0; read != len(buf); {
+ n, err := c.Read(buf)
+ read += n
+ if err != nil {
+ b.Log(err)
+ return false
+ }
+ }
+ return true
+ }
+ ln, err := Listen("tcp", laddr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer ln.Close()
+ serverSem := make(chan bool, numConcurrent)
+ // Acceptor.
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ serverSem <- true
+ // Server connection.
+ go func(c Conn) {
+ defer func() {
+ c.Close()
+ <-serverSem
+ }()
+ if timeout {
+ c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire.
+ }
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !recvMsg(c, buf[:]) || !sendMsg(c, buf[:]) {
+ break
+ }
+ }
+ }(c)
+ }
+ }()
+ clientSem := make(chan bool, numConcurrent)
+ for i := 0; i < conns; i++ {
+ clientSem <- true
+ // Client connection.
+ go func() {
+ defer func() {
+ <-clientSem
+ }()
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ b.Log(err)
+ return
+ }
+ defer c.Close()
+ if timeout {
+ c.SetDeadline(time.Now().Add(time.Hour)) // Not intended to fire.
+ }
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !sendMsg(c, buf[:]) || !recvMsg(c, buf[:]) {
+ break
+ }
+ }
+ }()
+ }
+ for i := 0; i < numConcurrent; i++ {
+ clientSem <- true
+ serverSem <- true
+ }
+}
+
+func BenchmarkTCP4ConcurrentReadWrite(b *testing.B) {
+ benchmarkTCPConcurrentReadWrite(b, "127.0.0.1:0")
+}
+
+func BenchmarkTCP6ConcurrentReadWrite(b *testing.B) {
+ if !supportsIPv6() {
+ b.Skip("ipv6 is not supported")
+ }
+ benchmarkTCPConcurrentReadWrite(b, "[::1]:0")
+}
+
+func benchmarkTCPConcurrentReadWrite(b *testing.B, laddr string) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ // The benchmark creates GOMAXPROCS client/server pairs.
+ // Each pair creates 4 goroutines: client reader/writer and server reader/writer.
+ // The benchmark stresses concurrent reading and writing to the same connection.
+ // Such pattern is used in net/http and net/rpc.
+
+ b.StopTimer()
+
+ P := runtime.GOMAXPROCS(0)
+ N := b.N / P
+ W := 1000
+
+ // Setup P client/server connections.
+ clients := make([]Conn, P)
+ servers := make([]Conn, P)
+ ln, err := Listen("tcp", laddr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer ln.Close()
+ done := make(chan bool)
+ go func() {
+ for p := 0; p < P; p++ {
+ s, err := ln.Accept()
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ servers[p] = s
+ }
+ done <- true
+ }()
+ for p := 0; p < P; p++ {
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ b.Fatal(err)
+ }
+ clients[p] = c
+ }
+ <-done
+
+ b.StartTimer()
+
+ var wg sync.WaitGroup
+ wg.Add(4 * P)
+ for p := 0; p < P; p++ {
+ // Client writer.
+ go func(c Conn) {
+ defer wg.Done()
+ var buf [1]byte
+ for i := 0; i < N; i++ {
+ v := byte(i)
+ for w := 0; w < W; w++ {
+ v *= v
+ }
+ buf[0] = v
+ _, err := c.Write(buf[:])
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ }
+ }(clients[p])
+
+ // Pipe between server reader and server writer.
+ pipe := make(chan byte, 128)
+
+ // Server reader.
+ go func(s Conn) {
+ defer wg.Done()
+ var buf [1]byte
+ for i := 0; i < N; i++ {
+ _, err := s.Read(buf[:])
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ pipe <- buf[0]
+ }
+ }(servers[p])
+
+ // Server writer.
+ go func(s Conn) {
+ defer wg.Done()
+ var buf [1]byte
+ for i := 0; i < N; i++ {
+ v := <-pipe
+ for w := 0; w < W; w++ {
+ v *= v
+ }
+ buf[0] = v
+ _, err := s.Write(buf[:])
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ }
+ s.Close()
+ }(servers[p])
+
+ // Client reader.
+ go func(c Conn) {
+ defer wg.Done()
+ var buf [1]byte
+ for i := 0; i < N; i++ {
+ _, err := c.Read(buf[:])
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ }
+ c.Close()
+ }(clients[p])
+ }
+ wg.Wait()
+}
+
+type resolveTCPAddrTest struct {
+ network string
+ litAddrOrName string
+ addr *TCPAddr
+ err error
+}
+
+var resolveTCPAddrTests = []resolveTCPAddrTest{
+ {"tcp", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil},
+ {"tcp4", "127.0.0.1:65535", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil},
+
+ {"tcp", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil},
+ {"tcp6", "[::1]:65535", &TCPAddr{IP: ParseIP("::1"), Port: 65535}, nil},
+
+ {"tcp", "[::1%en0]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil},
+ {"tcp6", "[::1%911]:2", &TCPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil},
+
+ {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior
+ {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior
+
+ {"tcp", ":12345", &TCPAddr{Port: 12345}, nil},
+
+ {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")},
+
+ {"tcp", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 80}, nil},
+ {"tcp", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil},
+ {"tcp4", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp4", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp6", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil},
+
+ {"tcp4", "[2001:db8::1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"tcp6", "127.0.0.1:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"tcp6", "[::ffff:127.0.0.1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
+}
+
+func TestResolveTCPAddr(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+
+ for _, tt := range resolveTCPAddrTests {
+ addr, err := ResolveTCPAddr(tt.network, tt.litAddrOrName)
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
+ continue
+ }
+ if err == nil {
+ addr2, err := ResolveTCPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
+ }
+ }
+}
+
+var tcpListenerNameTests = []struct {
+ net string
+ laddr *TCPAddr
+}{
+ {"tcp4", &TCPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"tcp4", &TCPAddr{}},
+ {"tcp4", nil},
+}
+
+func TestTCPListenerName(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ for _, tt := range tcpListenerNameTests {
+ ln, err := ListenTCP(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ la := ln.Addr()
+ if a, ok := la.(*TCPAddr); !ok || a.Port == 0 {
+ t.Fatalf("got %v; expected a proper address with non-zero port number", la)
+ }
+ }
+}
+
+func TestIPv6LinkLocalUnicastTCP(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ if !supportsIPv6() {
+ t.Skip("IPv6 is not supported")
+ }
+
+ for i, tt := range ipv6LinkLocalUnicastTCPTests {
+ ln, err := Listen(tt.network, tt.address)
+ if err != nil {
+ // It might return "LookupHost returned no
+ // suitable address" error on some platforms.
+ t.Log(err)
+ continue
+ }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ defer ls.teardown()
+ ch := make(chan error, 1)
+ handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ if la, ok := ln.Addr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", la)
+ }
+
+ c, err := Dial(tt.network, ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if la, ok := c.LocalAddr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", la)
+ }
+ if ra, ok := c.RemoteAddr().(*TCPAddr); !ok || !tt.nameLookup && ra.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", ra)
+ }
+
+ if _, err := c.Write([]byte("TCP OVER IPV6 LINKLOCAL TEST")); err != nil {
+ t.Fatal(err)
+ }
+ b := make([]byte, 32)
+ if _, err := c.Read(b); err != nil {
+ t.Fatal(err)
+ }
+
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
+
+func TestTCPConcurrentAccept(t *testing.T) {
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ const N = 10
+ var wg sync.WaitGroup
+ wg.Add(N)
+ for i := 0; i < N; i++ {
+ go func() {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ c.Close()
+ }
+ wg.Done()
+ }()
+ }
+ attempts := 10 * N
+ fails := 0
+ d := &Dialer{Timeout: 200 * time.Millisecond}
+ for i := 0; i < attempts; i++ {
+ c, err := d.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ fails++
+ } else {
+ c.Close()
+ }
+ }
+ ln.Close()
+ wg.Wait()
+ if fails > attempts/9 { // see issues 7400 and 7541
+ t.Fatalf("too many Dial failed: %v", fails)
+ }
+ if fails > 0 {
+ t.Logf("# of failed Dials: %v", fails)
+ }
+}
+
+func TestTCPReadWriteAllocs(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ // The implementation of asynchronous cancelable
+ // I/O on Plan 9 allocates memory.
+ // See net/fd_io_plan9.go.
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ var server Conn
+ errc := make(chan error, 1)
+ go func() {
+ var err error
+ server, err = ln.Accept()
+ errc <- err
+ }()
+ client, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer client.Close()
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ defer server.Close()
+
+ var buf [128]byte
+ allocs := testing.AllocsPerRun(1000, func() {
+ _, err := server.Write(buf[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = io.ReadFull(client, buf[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if allocs > 0 {
+ t.Fatalf("got %v; want 0", allocs)
+ }
+
+ var bufwrt [128]byte
+ ch := make(chan bool)
+ defer close(ch)
+ go func() {
+ for <-ch {
+ _, err := server.Write(bufwrt[:])
+ errc <- err
+ }
+ }()
+ allocs = testing.AllocsPerRun(1000, func() {
+ ch <- true
+ if _, err = io.ReadFull(client, buf[:]); err != nil {
+ t.Fatal(err)
+ }
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ })
+ if allocs > 0 {
+ t.Fatalf("got %v; want 0", allocs)
+ }
+}
+
+func TestTCPStress(t *testing.T) {
+ const conns = 2
+ const msgLen = 512
+ msgs := int(1e4)
+ if testing.Short() {
+ msgs = 1e2
+ }
+
+ sendMsg := func(c Conn, buf []byte) bool {
+ n, err := c.Write(buf)
+ if n != len(buf) || err != nil {
+ t.Log(err)
+ return false
+ }
+ return true
+ }
+ recvMsg := func(c Conn, buf []byte) bool {
+ for read := 0; read != len(buf); {
+ n, err := c.Read(buf)
+ read += n
+ if err != nil {
+ t.Log(err)
+ return false
+ }
+ }
+ return true
+ }
+
+ ln, err := Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ done := make(chan bool)
+ // Acceptor.
+ go func() {
+ defer func() {
+ done <- true
+ }()
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ // Server connection.
+ go func(c Conn) {
+ defer c.Close()
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !recvMsg(c, buf[:]) || !sendMsg(c, buf[:]) {
+ break
+ }
+ }
+ }(c)
+ }
+ }()
+ for i := 0; i < conns; i++ {
+ // Client connection.
+ go func() {
+ defer func() {
+ done <- true
+ }()
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Log(err)
+ return
+ }
+ defer c.Close()
+ var buf [msgLen]byte
+ for m := 0; m < msgs; m++ {
+ if !sendMsg(c, buf[:]) || !recvMsg(c, buf[:]) {
+ break
+ }
+ }
+ }()
+ }
+ for i := 0; i < conns; i++ {
+ <-done
+ }
+ ln.Close()
+ <-done
+}
+
+// Test that >32-bit reads work on 64-bit systems.
+// On 32-bit systems this tests that maxint reads work.
+func TestTCPBig(t *testing.T) {
+ if !*testTCPBig {
+ t.Skip("test disabled; use -tcpbig to enable")
+ }
+
+ for _, writev := range []bool{false, true} {
+ t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ x := int(1 << 30)
+ x = x*5 + 1<<20 // just over 5 GB on 64-bit, just over 1GB on 32-bit
+ done := make(chan int)
+ go func() {
+ defer close(done)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ buf := make([]byte, x)
+ var n int
+ if writev {
+ var n64 int64
+ n64, err = (&Buffers{buf}).WriteTo(c)
+ n = int(n64)
+ } else {
+ n, err = c.Write(buf)
+ }
+ if n != len(buf) || err != nil {
+ t.Errorf("Write(buf) = %d, %v, want %d, nil", n, err, x)
+ }
+ c.Close()
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, x)
+ n, err := io.ReadFull(c, buf)
+ if n != len(buf) || err != nil {
+ t.Errorf("Read(buf) = %d, %v, want %d, nil", n, err, x)
+ }
+ c.Close()
+ <-done
+ })
+ }
+}
+
+func TestCopyPipeIntoTCP(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ errc := make(chan error, 1)
+ defer func() {
+ if err := <-errc; err != nil {
+ t.Error(err)
+ }
+ }()
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer c.Close()
+
+ buf := make([]byte, 100)
+ n, err := io.ReadFull(c, buf)
+ if err != io.ErrUnexpectedEOF || n != 2 {
+ errc <- fmt.Errorf("got err=%q n=%v; want err=%q n=2", err, n, io.ErrUnexpectedEOF)
+ return
+ }
+
+ errc <- nil
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ r, w, err := os.Pipe()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r.Close()
+
+ errc2 := make(chan error, 1)
+ defer func() {
+ if err := <-errc2; err != nil {
+ t.Error(err)
+ }
+ }()
+
+ defer w.Close()
+
+ go func() {
+ _, err := io.Copy(c, r)
+ errc2 <- err
+ }()
+
+ // Split the write into 2 packets. That makes Windows TransmitFile
+ // drop the second packet.
+ packet := make([]byte, 1)
+ _, err = w.Write(packet)
+ if err != nil {
+ t.Fatal(err)
+ }
+ time.Sleep(100 * time.Millisecond)
+ _, err = w.Write(packet)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func BenchmarkSetReadDeadline(b *testing.B) {
+ ln := newLocalListener(b, "tcp")
+ defer ln.Close()
+ var serv Conn
+ done := make(chan error)
+ go func() {
+ var err error
+ serv, err = ln.Accept()
+ done <- err
+ }()
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer c.Close()
+ if err := <-done; err != nil {
+ b.Fatal(err)
+ }
+ defer serv.Close()
+ c.SetWriteDeadline(time.Now().Add(2 * time.Hour))
+ deadline := time.Now().Add(time.Hour)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ c.SetReadDeadline(deadline)
+ deadline = deadline.Add(1)
+ }
+}
+
+func TestDialTCPDefaultKeepAlive(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ got := time.Duration(-1)
+ testHookSetKeepAlive = func(d time.Duration) { got = d }
+ defer func() { testHookSetKeepAlive = func(time.Duration) {} }()
+
+ c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ if got != defaultTCPKeepAlive {
+ t.Errorf("got keepalive %v; want %v", got, defaultTCPKeepAlive)
+ }
+}
diff --git a/src/net/tcpsock_unix_test.go b/src/net/tcpsock_unix_test.go
new file mode 100644
index 0000000..35fd937
--- /dev/null
+++ b/src/net/tcpsock_unix_test.go
@@ -0,0 +1,112 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1 && !windows
+
+package net
+
+import (
+ "context"
+ "math/rand"
+ "runtime"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+)
+
+// See golang.org/issue/14548.
+func TestTCPSpuriousConnSetupCompletion(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ ln := newLocalListener(t, "tcp")
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func(ln Listener) {
+ defer wg.Done()
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ wg.Add(1)
+ go func(c Conn) {
+ var b [1]byte
+ c.Read(b[:])
+ c.Close()
+ wg.Done()
+ }(c)
+ }
+ }(ln)
+
+ attempts := int(1e4) // larger is better
+ wg.Add(attempts)
+ throttle := make(chan struct{}, runtime.GOMAXPROCS(-1)*2)
+ for i := 0; i < attempts; i++ {
+ throttle <- struct{}{}
+ go func(i int) {
+ defer func() {
+ <-throttle
+ wg.Done()
+ }()
+ d := Dialer{Timeout: 50 * time.Millisecond}
+ c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Errorf("#%d: %v (original error: %v)", i, perr, err)
+ }
+ return
+ }
+ var b [1]byte
+ if _, err := c.Write(b[:]); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Errorf("#%d: %v", i, err)
+ }
+ if samePlatformError(err, syscall.ENOTCONN) {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+ c.Close()
+ }(i)
+ }
+
+ ln.Close()
+ wg.Wait()
+}
+
+// Issue 19289.
+// Test that a canceled Dial does not cause a subsequent Dial to succeed.
+func TestTCPSpuriousConnSetupCompletionWithCancel(t *testing.T) {
+ mustHaveExternalNetwork(t)
+
+ defer dnsWaitGroup.Wait()
+ t.Parallel()
+ const tries = 10000
+ var wg sync.WaitGroup
+ wg.Add(tries * 2)
+ sem := make(chan bool, 5)
+ for i := 0; i < tries; i++ {
+ sem <- true
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ defer wg.Done()
+ time.Sleep(time.Duration(rand.Int63n(int64(5 * time.Millisecond))))
+ cancel()
+ }()
+ go func(i int) {
+ defer wg.Done()
+ var dialer Dialer
+ // Try to connect to a real host on a port
+ // that it is not listening on.
+ _, err := dialer.DialContext(ctx, "tcp", "golang.org:3")
+ if err == nil {
+ t.Errorf("Dial to unbound port succeeded on attempt %d", i)
+ }
+ <-sem
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/src/net/tcpsockopt_darwin.go b/src/net/tcpsockopt_darwin.go
new file mode 100644
index 0000000..53c6756
--- /dev/null
+++ b/src/net/tcpsockopt_darwin.go
@@ -0,0 +1,25 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+ "time"
+)
+
+// syscall.TCP_KEEPINTVL is missing on some darwin architectures.
+const sysTCP_KEEPINTVL = 0x101
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // The kernel expects seconds so round to next highest second.
+ secs := int(roundDurationUp(d, time.Second))
+ if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err != nil {
+ return wrapSyscallError("setsockopt", err)
+ }
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
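+
+// Illustrative usage (editorial sketch): this helper is reached through the
+// exported (*TCPConn).SetKeepAlivePeriod; raddr below stands for some
+// *net.TCPAddr and is assumed, not taken from this change.
+//
+// tc, _ := net.DialTCP("tcp", nil, raddr)
+// tc.SetKeepAlive(true)
+// tc.SetKeepAlivePeriod(30 * time.Second) // rounded up to whole seconds here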
diff --git a/src/net/tcpsockopt_dragonfly.go b/src/net/tcpsockopt_dragonfly.go
new file mode 100644
index 0000000..b473c02
--- /dev/null
+++ b/src/net/tcpsockopt_dragonfly.go
@@ -0,0 +1,23 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+ "time"
+)
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // The kernel expects milliseconds so round to next highest
+ // millisecond.
+ msecs := int(roundDurationUp(d, time.Millisecond))
+ if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs); err != nil {
+ return wrapSyscallError("setsockopt", err)
+ }
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, msecs)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/tcpsockopt_openbsd.go b/src/net/tcpsockopt_openbsd.go
new file mode 100644
index 0000000..10e1bef
--- /dev/null
+++ b/src/net/tcpsockopt_openbsd.go
@@ -0,0 +1,16 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "syscall"
+ "time"
+)
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // OpenBSD has no user-settable per-socket TCP keepalive
+ // options.
+ return syscall.ENOPROTOOPT
+}
diff --git a/src/net/tcpsockopt_plan9.go b/src/net/tcpsockopt_plan9.go
new file mode 100644
index 0000000..264359d
--- /dev/null
+++ b/src/net/tcpsockopt_plan9.go
@@ -0,0 +1,24 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// TCP socket options for plan9
+
+package net
+
+import (
+ "internal/itoa"
+ "syscall"
+ "time"
+)
+
+func setNoDelay(fd *netFD, noDelay bool) error {
+ return syscall.EPLAN9
+}
+
+// Set keep alive period.
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ cmd := "keepalive " + itoa.Itoa(int(d/time.Millisecond))
+ _, e := fd.ctl.WriteAt([]byte(cmd), 0)
+ return e
+}
diff --git a/src/net/tcpsockopt_posix.go b/src/net/tcpsockopt_posix.go
new file mode 100644
index 0000000..d708f04
--- /dev/null
+++ b/src/net/tcpsockopt_posix.go
@@ -0,0 +1,18 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || windows
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func setNoDelay(fd *netFD, noDelay bool) error {
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/tcpsockopt_solaris.go b/src/net/tcpsockopt_solaris.go
new file mode 100644
index 0000000..f15e589
--- /dev/null
+++ b/src/net/tcpsockopt_solaris.go
@@ -0,0 +1,32 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+ "time"
+)
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // The kernel expects milliseconds so round to next highest
+ // millisecond.
+ msecs := int(roundDurationUp(d, time.Millisecond))
+
+ // Normally we'd do
+ // syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs)
+ // here, but we can't because Solaris does not have TCP_KEEPINTVL.
+ // Solaris has TCP_KEEPALIVE_ABORT_THRESHOLD, but it's not the same
+ // thing, it refers to the total time until aborting (not between
+ // probes), and it uses an exponential backoff algorithm instead of
+ // waiting the same time between probes. We can't hope for the best
+ // and do it anyway, like on Darwin, because Solaris might eventually
+ // allocate a constant with a different meaning for the value of
+ // TCP_KEEPINTVL on illumos.
+
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE_THRESHOLD, msecs)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/tcpsockopt_stub.go b/src/net/tcpsockopt_stub.go
new file mode 100644
index 0000000..f778143
--- /dev/null
+++ b/src/net/tcpsockopt_stub.go
@@ -0,0 +1,20 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1
+
+package net
+
+import (
+ "syscall"
+ "time"
+)
+
+func setNoDelay(fd *netFD, noDelay bool) error {
+ return syscall.ENOPROTOOPT
+}
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ return syscall.ENOPROTOOPT
+}
diff --git a/src/net/tcpsockopt_unix.go b/src/net/tcpsockopt_unix.go
new file mode 100644
index 0000000..bdcdc40
--- /dev/null
+++ b/src/net/tcpsockopt_unix.go
@@ -0,0 +1,24 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || freebsd || linux || netbsd
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+ "time"
+)
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // The kernel expects seconds so round to next highest second.
+ secs := int(roundDurationUp(d, time.Second))
+ if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil {
+ return wrapSyscallError("setsockopt", err)
+ }
+ err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)
+ runtime.KeepAlive(fd)
+ return wrapSyscallError("setsockopt", err)
+}
diff --git a/src/net/tcpsockopt_windows.go b/src/net/tcpsockopt_windows.go
new file mode 100644
index 0000000..4a0b094
--- /dev/null
+++ b/src/net/tcpsockopt_windows.go
@@ -0,0 +1,29 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "syscall"
+ "time"
+ "unsafe"
+)
+
+func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
+ // The kernel expects milliseconds so round to next highest
+ // millisecond.
+ msecs := uint32(roundDurationUp(d, time.Millisecond))
+ ka := syscall.TCPKeepalive{
+ OnOff: 1,
+ Time: msecs,
+ Interval: msecs,
+ }
+ ret := uint32(0)
+ size := uint32(unsafe.Sizeof(ka))
+ err := fd.pfd.WSAIoctl(syscall.SIO_KEEPALIVE_VALS, (*byte)(unsafe.Pointer(&ka)), size, nil, 0, &ret, nil, 0)
+ runtime.KeepAlive(fd)
+ return os.NewSyscallError("wsaioctl", err)
+}
diff --git a/src/net/testdata/aliases b/src/net/testdata/aliases
new file mode 100644
index 0000000..9330ba0
--- /dev/null
+++ b/src/net/testdata/aliases
@@ -0,0 +1,8 @@
+127.0.0.1 test
+127.0.0.2 test2.example.com 2.test
+127.0.0.3 3.test test3.example.com
+127.0.0.4 example.com
+127.0.0.5 test4.example.com 4.test 5.test test5.example.com
+
+# must be a non resolvable domain on the internet
+127.0.1.1 invalid.test invalid.invalid
diff --git a/src/net/testdata/case-hosts b/src/net/testdata/case-hosts
new file mode 100644
index 0000000..1f30df1
--- /dev/null
+++ b/src/net/testdata/case-hosts
@@ -0,0 +1,2 @@
+127.0.0.1 PreserveMe PreserveMe.local
+::1 PreserveMe PreserveMe.local
diff --git a/src/net/testdata/domain-resolv.conf b/src/net/testdata/domain-resolv.conf
new file mode 100644
index 0000000..ff26918
--- /dev/null
+++ b/src/net/testdata/domain-resolv.conf
@@ -0,0 +1,5 @@
+# /etc/resolv.conf
+
+search test invalid
+domain localdomain
+nameserver 8.8.8.8
diff --git a/src/net/testdata/empty-resolv.conf b/src/net/testdata/empty-resolv.conf
new file mode 100644
index 0000000..c4b2b57
--- /dev/null
+++ b/src/net/testdata/empty-resolv.conf
@@ -0,0 +1 @@
+# /etc/resolv.conf
diff --git a/src/net/testdata/freebsd-usevc-resolv.conf b/src/net/testdata/freebsd-usevc-resolv.conf
new file mode 100644
index 0000000..4afb281
--- /dev/null
+++ b/src/net/testdata/freebsd-usevc-resolv.conf
@@ -0,0 +1 @@
+options usevc
\ No newline at end of file
diff --git a/src/net/testdata/hosts b/src/net/testdata/hosts
new file mode 100644
index 0000000..3ed83ff
--- /dev/null
+++ b/src/net/testdata/hosts
@@ -0,0 +1,11 @@
+255.255.255.255 broadcasthost
+127.0.0.2 odin
+127.0.0.3 odin # inline comment
+::2 odin
+127.1.1.1 thor
+# aliases
+127.1.1.2 ullr ullrhost
+fe80::1%lo0 localhost
+# Bogus entries that must be ignored.
+123.123.123 loki
+321.321.321.321
diff --git a/src/net/testdata/igmp b/src/net/testdata/igmp
new file mode 100644
index 0000000..5f380a2
--- /dev/null
+++ b/src/net/testdata/igmp
@@ -0,0 +1,24 @@
+Idx Device : Count Querier Group Users Timer Reporter
+1 lo : 1 V3
+ 010000E0 1 0:00000000 0
+2 eth0 : 2 V2
+ FB0000E0 1 0:00000000 1
+ 010000E0 1 0:00000000 0
+3 eth1 : 1 V3
+ 010000E0 1 0:00000000 0
+4 eth2 : 1 V3
+ 010000E0 1 0:00000000 0
+5 eth0.100 : 2 V3
+ FB0000E0 1 0:00000000 0
+ 010000E0 1 0:00000000 0
+6 eth0.101 : 2 V3
+ FB0000E0 1 0:00000000 0
+ 010000E0 1 0:00000000 0
+7 eth0.102 : 2 V3
+ FB0000E0 1 0:00000000 0
+ 010000E0 1 0:00000000 0
+8 eth0.103 : 2 V3
+ FB0000E0 1 0:00000000 0
+ 010000E0 1 0:00000000 0
+9 device1tap2: 1 V3
+ 010000E0 1 0:00000000 0
diff --git a/src/net/testdata/igmp6 b/src/net/testdata/igmp6
new file mode 100644
index 0000000..6cd5a2d
--- /dev/null
+++ b/src/net/testdata/igmp6
@@ -0,0 +1,18 @@
+1 lo ff020000000000000000000000000001 1 0000000C 0
+2 eth0 ff0200000000000000000001ffac891e 1 00000006 0
+2 eth0 ff020000000000000000000000000001 1 0000000C 0
+3 eth1 ff0200000000000000000001ffac8928 2 00000006 0
+3 eth1 ff020000000000000000000000000001 1 0000000C 0
+4 eth2 ff0200000000000000000001ffac8932 2 00000006 0
+4 eth2 ff020000000000000000000000000001 1 0000000C 0
+5 eth0.100 ff0200000000000000000001ffac891e 1 00000004 0
+5 eth0.100 ff020000000000000000000000000001 1 0000000C 0
+6 pan0 ff020000000000000000000000000001 1 0000000C 0
+7 eth0.101 ff0200000000000000000001ffac891e 1 00000004 0
+7 eth0.101 ff020000000000000000000000000001 1 0000000C 0
+8 eth0.102 ff0200000000000000000001ffac891e 1 00000004 0
+8 eth0.102 ff020000000000000000000000000001 1 0000000C 0
+9 eth0.103 ff0200000000000000000001ffac891e 1 00000004 0
+9 eth0.103 ff020000000000000000000000000001 1 0000000C 0
+10 device1tap2 ff0200000000000000000001ff4cc3a3 1 00000004 0
+10 device1tap2 ff020000000000000000000000000001 1 0000000C 0
diff --git a/src/net/testdata/invalid-ndots-resolv.conf b/src/net/testdata/invalid-ndots-resolv.conf
new file mode 100644
index 0000000..084c164
--- /dev/null
+++ b/src/net/testdata/invalid-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:invalid
\ No newline at end of file
diff --git a/src/net/testdata/ipv4-hosts b/src/net/testdata/ipv4-hosts
new file mode 100644
index 0000000..6b99675
--- /dev/null
+++ b/src/net/testdata/ipv4-hosts
@@ -0,0 +1,8 @@
+# See https://tools.ietf.org/html/rfc1123.
+
+# internet address and host name
+127.0.0.1 localhost # inline comment separated by tab
+127.0.0.2 localhost # inline comment separated by space
+
+# internet address, host name and aliases
+127.0.0.3 localhost localhost.localdomain
diff --git a/src/net/testdata/ipv6-hosts b/src/net/testdata/ipv6-hosts
new file mode 100644
index 0000000..f78b7fc
--- /dev/null
+++ b/src/net/testdata/ipv6-hosts
@@ -0,0 +1,11 @@
+# See https://tools.ietf.org/html/rfc5952, https://tools.ietf.org/html/rfc4007.
+
+# internet address and host name
+::1 localhost # inline comment separated by tab
+fe80:0000:0000:0000:0000:0000:0000:0001 localhost # inline comment separated by space
+
+# internet address with zone identifier and host name
+fe80:0000:0000:0000:0000:0000:0000:0002%lo0 localhost
+
+# internet address, host name and aliases
+fe80::3%lo0 localhost localhost.localdomain
diff --git a/src/net/testdata/large-ndots-resolv.conf b/src/net/testdata/large-ndots-resolv.conf
new file mode 100644
index 0000000..72968ee
--- /dev/null
+++ b/src/net/testdata/large-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:16
\ No newline at end of file
diff --git a/src/net/testdata/linux-use-vc-resolv.conf b/src/net/testdata/linux-use-vc-resolv.conf
new file mode 100644
index 0000000..4e4a58b
--- /dev/null
+++ b/src/net/testdata/linux-use-vc-resolv.conf
@@ -0,0 +1 @@
+options use-vc
\ No newline at end of file
diff --git a/src/net/testdata/negative-ndots-resolv.conf b/src/net/testdata/negative-ndots-resolv.conf
new file mode 100644
index 0000000..c11e0cc
--- /dev/null
+++ b/src/net/testdata/negative-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:-1
\ No newline at end of file
diff --git a/src/net/testdata/openbsd-resolv.conf b/src/net/testdata/openbsd-resolv.conf
new file mode 100644
index 0000000..8281a91
--- /dev/null
+++ b/src/net/testdata/openbsd-resolv.conf
@@ -0,0 +1,5 @@
+# Generated by vio0 dhclient
+search c.symbolic-datum-552.internal.
+nameserver 169.254.169.254
+nameserver 10.240.0.1
+lookup file bind
diff --git a/src/net/testdata/openbsd-tcp-resolv.conf b/src/net/testdata/openbsd-tcp-resolv.conf
new file mode 100644
index 0000000..7929e50
--- /dev/null
+++ b/src/net/testdata/openbsd-tcp-resolv.conf
@@ -0,0 +1 @@
+options tcp
\ No newline at end of file
diff --git a/src/net/testdata/resolv.conf b/src/net/testdata/resolv.conf
new file mode 100644
index 0000000..04e87ee
--- /dev/null
+++ b/src/net/testdata/resolv.conf
@@ -0,0 +1,8 @@
+# /etc/resolv.conf
+
+domain localdomain
+nameserver 8.8.8.8
+nameserver 2001:4860:4860::8888
+nameserver fe80::1%lo0
+options ndots:5 timeout:10 attempts:3 rotate
+options attempts 3
diff --git a/src/net/testdata/search-resolv.conf b/src/net/testdata/search-resolv.conf
new file mode 100644
index 0000000..1c846bf
--- /dev/null
+++ b/src/net/testdata/search-resolv.conf
@@ -0,0 +1,5 @@
+# /etc/resolv.conf
+
+domain localdomain
+search test invalid
+nameserver 8.8.8.8
diff --git a/src/net/testdata/search-single-dot-resolv.conf b/src/net/testdata/search-single-dot-resolv.conf
new file mode 100644
index 0000000..934cd3e
--- /dev/null
+++ b/src/net/testdata/search-single-dot-resolv.conf
@@ -0,0 +1,5 @@
+# /etc/resolv.conf
+
+domain localdomain
+search .
+nameserver 8.8.8.8
diff --git a/src/net/testdata/single-request-reopen-resolv.conf b/src/net/testdata/single-request-reopen-resolv.conf
new file mode 100644
index 0000000..9bddeb3
--- /dev/null
+++ b/src/net/testdata/single-request-reopen-resolv.conf
@@ -0,0 +1 @@
+options single-request-reopen
\ No newline at end of file
diff --git a/src/net/testdata/single-request-resolv.conf b/src/net/testdata/single-request-resolv.conf
new file mode 100644
index 0000000..5595d29
--- /dev/null
+++ b/src/net/testdata/single-request-resolv.conf
@@ -0,0 +1 @@
+options single-request
\ No newline at end of file
diff --git a/src/net/testdata/singleline-hosts b/src/net/testdata/singleline-hosts
new file mode 100644
index 0000000..5f5f74a
--- /dev/null
+++ b/src/net/testdata/singleline-hosts
@@ -0,0 +1 @@
+127.0.0.2 odin
\ No newline at end of file
diff --git a/src/net/textproto/header.go b/src/net/textproto/header.go
new file mode 100644
index 0000000..a58df7a
--- /dev/null
+++ b/src/net/textproto/header.go
@@ -0,0 +1,56 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+// A MIMEHeader represents a MIME-style header mapping
+// keys to sets of values.
+type MIMEHeader map[string][]string
+
+// Add adds the key, value pair to the header.
+// It appends to any existing values associated with key.
+func (h MIMEHeader) Add(key, value string) {
+ key = CanonicalMIMEHeaderKey(key)
+ h[key] = append(h[key], value)
+}
+
+// Set sets the header entries associated with key to
+// the single element value. It replaces any existing
+// values associated with key.
+func (h MIMEHeader) Set(key, value string) {
+ h[CanonicalMIMEHeaderKey(key)] = []string{value}
+}
+
+// Get gets the first value associated with the given key.
+// It is case insensitive; CanonicalMIMEHeaderKey is used
+// to canonicalize the provided key.
+// If there are no values associated with the key, Get returns "".
+// To use non-canonical keys, access the map directly.
+func (h MIMEHeader) Get(key string) string {
+ if h == nil {
+ return ""
+ }
+ v := h[CanonicalMIMEHeaderKey(key)]
+ if len(v) == 0 {
+ return ""
+ }
+ return v[0]
+}
+
+// Values returns all values associated with the given key.
+// It is case insensitive; CanonicalMIMEHeaderKey is
+// used to canonicalize the provided key. To use non-canonical
+// keys, access the map directly.
+// The returned slice is not a copy.
+func (h MIMEHeader) Values(key string) []string {
+ if h == nil {
+ return nil
+ }
+ return h[CanonicalMIMEHeaderKey(key)]
+}
+
+// Del deletes the values associated with key.
+func (h MIMEHeader) Del(key string) {
+ delete(h, CanonicalMIMEHeaderKey(key))
+}
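+
+// Illustrative usage (editorial sketch, not part of the original change):
+// keys are canonicalized on every call, so mixed-case lookups agree.
+//
+// h := textproto.MIMEHeader{}
+// h.Add("Set-Cookie", "a=1")
+// h.Add("set-cookie", "b=2") // stored under "Set-Cookie"
+// _ = h.Get("SET-COOKIE") // "a=1"
+// _ = h.Values("set-cookie") // ["a=1" "b=2"]
+// h.Del("Set-Cookie")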
diff --git a/src/net/textproto/header_test.go b/src/net/textproto/header_test.go
new file mode 100644
index 0000000..de9405c
--- /dev/null
+++ b/src/net/textproto/header_test.go
@@ -0,0 +1,54 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import "testing"
+
+type canonicalHeaderKeyTest struct {
+ in, out string
+}
+
+var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{
+ {"a-b-c", "A-B-C"},
+ {"a-1-c", "A-1-C"},
+ {"User-Agent", "User-Agent"},
+ {"uSER-aGENT", "User-Agent"},
+ {"user-agent", "User-Agent"},
+ {"USER-AGENT", "User-Agent"},
+
+ // Other valid tchar bytes in tokens:
+ {"foo-bar_baz", "Foo-Bar_baz"},
+ {"foo-bar$baz", "Foo-Bar$baz"},
+ {"foo-bar~baz", "Foo-Bar~baz"},
+ {"foo-bar*baz", "Foo-Bar*baz"},
+
+ // Non-ASCII or anything with spaces or non-token chars is unchanged:
+ {"üser-agenT", "üser-agenT"},
+ {"a B", "a B"},
+
+ // This caused a panic due to mishandling of a space:
+ {"C Ontent-Transfer-Encoding", "C Ontent-Transfer-Encoding"},
+ {"foo bar", "foo bar"},
+}
+
+func TestCanonicalMIMEHeaderKey(t *testing.T) {
+ for _, tt := range canonicalHeaderKeyTests {
+ if s := CanonicalMIMEHeaderKey(tt.in); s != tt.out {
+ t.Errorf("CanonicalMIMEHeaderKey(%q) = %q, want %q", tt.in, s, tt.out)
+ }
+ }
+}
+
+// Issue #34799 add a Header method to get multiple values []string, with canonicalized key
+func TestMIMEHeaderMultipleValues(t *testing.T) {
+ testHeader := MIMEHeader{
+ "Set-Cookie": {"cookie 1", "cookie 2"},
+ }
+ values := testHeader.Values("set-cookie")
+ n := len(values)
+ if n != 2 {
+ t.Errorf("count: %d; want 2", n)
+ }
+}
diff --git a/src/net/textproto/pipeline.go b/src/net/textproto/pipeline.go
new file mode 100644
index 0000000..1928a30
--- /dev/null
+++ b/src/net/textproto/pipeline.go
@@ -0,0 +1,118 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import (
+ "sync"
+)
+
+// A Pipeline manages a pipelined in-order request/response sequence.
+//
+// To use a Pipeline p to manage multiple clients on a connection,
+// each client should run:
+//
+// id := p.Next() // take a number
+//
+// p.StartRequest(id) // wait for turn to send request
+// «send request»
+// p.EndRequest(id) // notify Pipeline that request is sent
+//
+// p.StartResponse(id) // wait for turn to read response
+// «read response»
+// p.EndResponse(id) // notify Pipeline that response is read
+//
+// A pipelined server can use the same calls to ensure that
+// responses computed in parallel are written in the correct order.
+type Pipeline struct {
+ mu sync.Mutex
+ id uint
+ request sequencer
+ response sequencer
+}
+
+// Next returns the next id for a request/response pair.
+func (p *Pipeline) Next() uint {
+ p.mu.Lock()
+ id := p.id
+ p.id++
+ p.mu.Unlock()
+ return id
+}
+
+// StartRequest blocks until it is time to send (or, if this is a server, receive)
+// the request with the given id.
+func (p *Pipeline) StartRequest(id uint) {
+ p.request.Start(id)
+}
+
+// EndRequest notifies p that the request with the given id has been sent
+// (or, if this is a server, received).
+func (p *Pipeline) EndRequest(id uint) {
+ p.request.End(id)
+}
+
+// StartResponse blocks until it is time to receive (or, if this is a server, send)
+// the request with the given id.
+func (p *Pipeline) StartResponse(id uint) {
+ p.response.Start(id)
+}
+
+// EndResponse notifies p that the response with the given id has been received
+// (or, if this is a server, sent).
+func (p *Pipeline) EndResponse(id uint) {
+ p.response.End(id)
+}
+
+// A sequencer schedules a sequence of numbered events that must
+// happen in order, one after the other. The event numbering must start
+// at 0 and increment without skipping. The event number wraps around
+// safely as long as there are not 2^32 simultaneous events pending.
+type sequencer struct {
+ mu sync.Mutex
+ id uint
+ wait map[uint]chan struct{}
+}
+
+// Start waits until it is time for the event numbered id to begin.
+// That is, except for the first event, it waits until End(id-1) has
+// been called.
+func (s *sequencer) Start(id uint) {
+ s.mu.Lock()
+ if s.id == id {
+ s.mu.Unlock()
+ return
+ }
+ c := make(chan struct{})
+ if s.wait == nil {
+ s.wait = make(map[uint]chan struct{})
+ }
+ s.wait[id] = c
+ s.mu.Unlock()
+ <-c
+}
+
+// End notifies the sequencer that the event numbered id has completed,
+// allowing it to schedule the event numbered id+1. It is a run-time error
+// to call End with an id that is not the number of the active event.
+func (s *sequencer) End(id uint) {
+ s.mu.Lock()
+ if s.id != id {
+ s.mu.Unlock()
+ panic("out of sync")
+ }
+ id++
+ s.id = id
+ if s.wait == nil {
+ s.wait = make(map[uint]chan struct{})
+ }
+ c, ok := s.wait[id]
+ if ok {
+ delete(s.wait, id)
+ }
+ s.mu.Unlock()
+ if ok {
+ close(c)
+ }
+}
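+
+// Illustrative sketch (editorial, not part of the original change): a
+// sequencer releases waiters strictly in id order, so the goroutine holding
+// id 1 blocks in Start(1) until the holder of id 0 has called End(0).
+//
+// var s sequencer
+// go func() { s.Start(1); /* runs second */ s.End(1) }()
+// s.Start(0) // returns immediately: id 0 is the active event
+// s.End(0)   // now the other goroutine's Start(1) returns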
diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go
new file mode 100644
index 0000000..fcd1a01
--- /dev/null
+++ b/src/net/textproto/reader.go
@@ -0,0 +1,840 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "strconv"
+ "strings"
+ "sync"
+)
+
+// TODO: This should be a distinguishable error (ErrMessageTooLarge)
+// to allow mime/multipart to detect it.
+var errMessageTooLarge = errors.New("message too large")
+
+// A Reader implements convenience methods for reading requests
+// or responses from a text protocol network connection.
+type Reader struct {
+ R *bufio.Reader
+ dot *dotReader
+ buf []byte // a re-usable buffer for readContinuedLineSlice
+}
+
+// NewReader returns a new Reader reading from r.
+//
+// To avoid denial of service attacks, the provided bufio.Reader
+// should be reading from an io.LimitReader or similar Reader to bound
+// the size of responses.
+func NewReader(r *bufio.Reader) *Reader {
+ return &Reader{R: r}
+}
+
+// ReadLine reads a single line from r,
+// eliding the final \n or \r\n from the returned string.
+func (r *Reader) ReadLine() (string, error) {
+ line, err := r.readLineSlice(-1)
+ return string(line), err
+}
+
+// ReadLineBytes is like ReadLine but returns a []byte instead of a string.
+func (r *Reader) ReadLineBytes() ([]byte, error) {
+ line, err := r.readLineSlice(-1)
+ if line != nil {
+ line = bytes.Clone(line)
+ }
+ return line, err
+}
+
+// readLineSlice reads a single line from r,
+// up to lim bytes long (or unlimited if lim is less than 0),
+// eliding the final \n or \r\n from the returned string.
+func (r *Reader) readLineSlice(lim int64) ([]byte, error) {
+ r.closeDot()
+ var line []byte
+ for {
+ l, more, err := r.R.ReadLine()
+ if err != nil {
+ return nil, err
+ }
+ if lim >= 0 && int64(len(line))+int64(len(l)) > lim {
+ return nil, errMessageTooLarge
+ }
+ // Avoid the copy if the first call produced a full line.
+ if line == nil && !more {
+ return l, nil
+ }
+ line = append(line, l...)
+ if !more {
+ break
+ }
+ }
+ return line, nil
+}
+
+// ReadContinuedLine reads a possibly continued line from r,
+// eliding the final trailing ASCII white space.
+// Lines after the first are considered continuations if they
+// begin with a space or tab character. In the returned data,
+// continuation lines are separated from the previous line
+// only by a single space: the newline and leading white space
+// are removed.
+//
+// For example, consider this input:
+//
+// Line 1
+// continued...
+// Line 2
+//
+// The first call to ReadContinuedLine will return "Line 1 continued..."
+// and the second will return "Line 2".
+//
+// Empty lines are never continued.
+func (r *Reader) ReadContinuedLine() (string, error) {
+ line, err := r.readContinuedLineSlice(-1, noValidation)
+ return string(line), err
+}
+
+// trim returns s with leading and trailing spaces and tabs removed.
+// It does not assume Unicode or UTF-8.
+func trim(s []byte) []byte {
+ i := 0
+ for i < len(s) && (s[i] == ' ' || s[i] == '\t') {
+ i++
+ }
+ n := len(s)
+ for n > i && (s[n-1] == ' ' || s[n-1] == '\t') {
+ n--
+ }
+ return s[i:n]
+}
+
+// ReadContinuedLineBytes is like ReadContinuedLine but
+// returns a []byte instead of a string.
+func (r *Reader) ReadContinuedLineBytes() ([]byte, error) {
+ line, err := r.readContinuedLineSlice(-1, noValidation)
+ if line != nil {
+ line = bytes.Clone(line)
+ }
+ return line, err
+}
+
+// readContinuedLineSlice reads continued lines from the reader buffer,
+// returning a byte slice with all lines. The validateFirstLine function
+// is run on the first read line, and if it returns an error then this
+// error is returned from readContinuedLineSlice.
+// It reads up to lim bytes of data (or unlimited if lim is less than 0).
+func (r *Reader) readContinuedLineSlice(lim int64, validateFirstLine func([]byte) error) ([]byte, error) {
+ if validateFirstLine == nil {
+ return nil, fmt.Errorf("missing validateFirstLine func")
+ }
+
+ // Read the first line.
+ line, err := r.readLineSlice(lim)
+ if err != nil {
+ return nil, err
+ }
+ if len(line) == 0 { // blank line - no continuation
+ return line, nil
+ }
+
+ if err := validateFirstLine(line); err != nil {
+ return nil, err
+ }
+
+ // Optimistically assume that we have started to buffer the next line
+ // and it starts with an ASCII letter (the next header key), or a blank
+ // line, so we can avoid copying that buffered data around in memory
+ // and skipping over non-existent whitespace.
+ if r.R.Buffered() > 1 {
+ peek, _ := r.R.Peek(2)
+ if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') ||
+ len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' {
+ return trim(line), nil
+ }
+ }
+
+ // ReadByte or the next readLineSlice will flush the read buffer;
+ // copy the slice into buf.
+ r.buf = append(r.buf[:0], trim(line)...)
+
+ if lim < 0 {
+ lim = math.MaxInt64
+ }
+ lim -= int64(len(r.buf))
+
+ // Read continuation lines.
+ for r.skipSpace() > 0 {
+ r.buf = append(r.buf, ' ')
+ if int64(len(r.buf)) >= lim {
+ return nil, errMessageTooLarge
+ }
+ line, err := r.readLineSlice(lim - int64(len(r.buf)))
+ if err != nil {
+ break
+ }
+ r.buf = append(r.buf, trim(line)...)
+ }
+ return r.buf, nil
+}
+
+// skipSpace skips R over all spaces and returns the number of bytes skipped.
+func (r *Reader) skipSpace() int {
+ n := 0
+ for {
+ c, err := r.R.ReadByte()
+ if err != nil {
+ // Bufio will keep err until next read.
+ break
+ }
+ if c != ' ' && c != '\t' {
+ r.R.UnreadByte()
+ break
+ }
+ n++
+ }
+ return n
+}
+
+func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) {
+ line, err := r.ReadLine()
+ if err != nil {
+ return
+ }
+ return parseCodeLine(line, expectCode)
+}
+
+func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) {
+ if len(line) < 4 || line[3] != ' ' && line[3] != '-' {
+ err = ProtocolError("short response: " + line)
+ return
+ }
+ continued = line[3] == '-'
+ code, err = strconv.Atoi(line[0:3])
+ if err != nil || code < 100 {
+ err = ProtocolError("invalid response code: " + line)
+ return
+ }
+ message = line[4:]
+ if 1 <= expectCode && expectCode < 10 && code/100 != expectCode ||
+ 10 <= expectCode && expectCode < 100 && code/10 != expectCode ||
+ 100 <= expectCode && expectCode < 1000 && code != expectCode {
+ err = &Error{code, message}
+ }
+ return
+}
+
+// ReadCodeLine reads a response code line of the form
+//
+// code message
+//
+// where code is a three-digit status code and the message
+// extends to the rest of the line. An example of such a line is:
+//
+// 220 plan9.bell-labs.com ESMTP
+//
+// If the prefix of the status does not match the digits in expectCode,
+// ReadCodeLine returns with err set to &Error{code, message}.
+// For example, if expectCode is 31, an error will be returned if
+// the status is not in the range [310,319].
+//
+// If the response is multi-line, ReadCodeLine returns an error.
+//
+// An expectCode <= 0 disables the check of the status code.
+func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) {
+ code, continued, message, err := r.readCodeLine(expectCode)
+ if err == nil && continued {
+ err = ProtocolError("unexpected multi-line response: " + message)
+ }
+ return
+}
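+
+// Illustrative usage (editorial sketch): reading an SMTP-style greeting.
+//
+// in := "220 mx.example.com ESMTP\r\n"
+// r := textproto.NewReader(bufio.NewReader(strings.NewReader(in)))
+// code, msg, err := r.ReadCodeLine(2) // expect a 2xx code
+// // code == 220, msg == "mx.example.com ESMTP", err == nil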
+
+// ReadResponse reads a multi-line response of the form:
+//
+// code-message line 1
+// code-message line 2
+// ...
+// code message line n
+//
+// where code is a three-digit status code. The first line starts with the
+// code and a hyphen. The response is terminated by a line that starts
+// with the same code followed by a space. Each line in message is
+// separated by a newline (\n).
+//
+// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for
+// details of another form of response accepted:
+//
+// code-message line 1
+// message line 2
+// ...
+// code message line n
+//
+// If the prefix of the status does not match the digits in expectCode,
+// ReadResponse returns with err set to &Error{code, message}.
+// For example, if expectCode is 31, an error will be returned if
+// the status is not in the range [310,319].
+//
+// An expectCode <= 0 disables the check of the status code.
+func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
+ code, continued, message, err := r.readCodeLine(expectCode)
+ multi := continued
+ for continued {
+ line, err := r.ReadLine()
+ if err != nil {
+ return 0, "", err
+ }
+
+ var code2 int
+ var moreMessage string
+ code2, continued, moreMessage, err = parseCodeLine(line, 0)
+ if err != nil || code2 != code {
+ message += "\n" + strings.TrimRight(line, "\r\n")
+ continued = true
+ continue
+ }
+ message += "\n" + moreMessage
+ }
+ if err != nil && multi && message != "" {
+ // replace one line error message with all lines (full message)
+ err = &Error{code, message}
+ }
+ return
+}
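+
+// Illustrative usage (editorial sketch): a two-line FTP-style reply.
+//
+// in := "230-Logged in\r\n230 Proceed\r\n"
+// r := textproto.NewReader(bufio.NewReader(strings.NewReader(in)))
+// code, msg, err := r.ReadResponse(2)
+// // code == 230, msg == "Logged in\nProceed", err == nil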
+
+// DotReader returns a new Reader that satisfies Reads using the
+// decoded text of a dot-encoded block read from r.
+// The returned Reader is only valid until the next call
+// to a method on r.
+//
+// Dot encoding is a common framing used for data blocks
+// in text protocols such as SMTP. The data consists of a sequence
+// of lines, each of which ends in "\r\n". The sequence itself
+// ends at a line containing just a dot: ".\r\n". Lines beginning
+// with a dot are escaped with an additional dot to avoid
+// looking like the end of the sequence.
+//
+// The decoded form returned by the Reader's Read method
+// rewrites the "\r\n" line endings into the simpler "\n",
+// removes leading dot escapes if present, and stops with error io.EOF
+// after consuming (and discarding) the end-of-sequence line.
+func (r *Reader) DotReader() io.Reader {
+ r.closeDot()
+ r.dot = &dotReader{r: r}
+ return r.dot
+}
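// A minimal usage sketch for DotReader, decoding a dot-encoded block; the
// SMTP-like payload is invented for the example.
//
//	package main
//
//	import (
//		"bufio"
//		"fmt"
//		"io"
//		"net/textproto"
//		"strings"
//	)
//
//	func main() {
//		data := "Subject: test\r\n\r\n..a line starting with a dot\r\n.\r\n"
//		r := textproto.NewReader(bufio.NewReader(strings.NewReader(data)))
//		body, err := io.ReadAll(r.DotReader())
//		if err != nil {
//			panic(err)
//		}
//		fmt.Printf("%q\n", body) // "Subject: test\n\n.a line starting with a dot\n"
//	}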
+
+type dotReader struct {
+ r *Reader
+ state int
+}
+
+// Read satisfies reads by decoding dot-encoded data read from d.r.
+func (d *dotReader) Read(b []byte) (n int, err error) {
+ // Run data through a simple state machine to
+ // elide leading dots, rewrite trailing \r\n into \n,
+ // and detect ending .\r\n line.
+ const (
+ stateBeginLine = iota // beginning of line; initial state; must be zero
+ stateDot // read . at beginning of line
+ stateDotCR // read .\r at beginning of line
+ stateCR // read \r (possibly at end of line)
+ stateData // reading data in middle of line
+ stateEOF // reached .\r\n end marker line
+ )
+ br := d.r.R
+ for n < len(b) && d.state != stateEOF {
+ var c byte
+ c, err = br.ReadByte()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+ switch d.state {
+ case stateBeginLine:
+ if c == '.' {
+ d.state = stateDot
+ continue
+ }
+ if c == '\r' {
+ d.state = stateCR
+ continue
+ }
+ d.state = stateData
+
+ case stateDot:
+ if c == '\r' {
+ d.state = stateDotCR
+ continue
+ }
+ if c == '\n' {
+ d.state = stateEOF
+ continue
+ }
+ d.state = stateData
+
+ case stateDotCR:
+ if c == '\n' {
+ d.state = stateEOF
+ continue
+ }
+ // Not part of .\r\n.
+ // Consume leading dot and emit saved \r.
+ br.UnreadByte()
+ c = '\r'
+ d.state = stateData
+
+ case stateCR:
+ if c == '\n' {
+ d.state = stateBeginLine
+ break
+ }
+ // Not part of \r\n. Emit saved \r
+ br.UnreadByte()
+ c = '\r'
+ d.state = stateData
+
+ case stateData:
+ if c == '\r' {
+ d.state = stateCR
+ continue
+ }
+ if c == '\n' {
+ d.state = stateBeginLine
+ }
+ }
+ b[n] = c
+ n++
+ }
+ if err == nil && d.state == stateEOF {
+ err = io.EOF
+ }
+ if err != nil && d.r.dot == d {
+ d.r.dot = nil
+ }
+ return
+}
+
+// closeDot drains the current DotReader if any,
+// making sure that it reads until the ending dot line.
+func (r *Reader) closeDot() {
+ if r.dot == nil {
+ return
+ }
+ buf := make([]byte, 128)
+ for r.dot != nil {
+ // When Read reaches EOF or an error,
+ // it will set r.dot == nil.
+ r.dot.Read(buf)
+ }
+}
+
+// ReadDotBytes reads a dot-encoding and returns the decoded data.
+//
+// See the documentation for the DotReader method for details about dot-encoding.
+func (r *Reader) ReadDotBytes() ([]byte, error) {
+ return io.ReadAll(r.DotReader())
+}
+
+// ReadDotLines reads a dot-encoding and returns a slice
+// containing the decoded lines, with the final \r\n or \n elided from each.
+//
+// See the documentation for the DotReader method for details about dot-encoding.
+func (r *Reader) ReadDotLines() ([]string, error) {
+ // We could use ReadDotBytes and then Split it,
+ // but reading a line at a time avoids needing a
+ // large contiguous block of memory and is simpler.
+ var v []string
+ var err error
+ for {
+ var line string
+ line, err = r.ReadLine()
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ break
+ }
+
+ // Dot by itself marks end; otherwise cut one dot.
+ if len(line) > 0 && line[0] == '.' {
+ if len(line) == 1 {
+ break
+ }
+ line = line[1:]
+ }
+ v = append(v, line)
+ }
+ return v, err
+}
+
+var colon = []byte(":")
+
+// ReadMIMEHeader reads a MIME-style header from r.
+// The header is a sequence of possibly continued Key: Value lines
+// ending in a blank line.
+// The returned map m maps CanonicalMIMEHeaderKey(key) to a
+// sequence of values in the same order encountered in the input.
+//
+// For example, consider this input:
+//
+// My-Key: Value 1
+// Long-Key: Even
+// Longer Value
+// My-Key: Value 2
+//
+// Given that input, ReadMIMEHeader returns the map:
+//
+// map[string][]string{
+// "My-Key": {"Value 1", "Value 2"},
+// "Long-Key": {"Even Longer Value"},
+// }
+func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) {
+ return readMIMEHeader(r, math.MaxInt64, math.MaxInt64)
+}
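// A minimal usage sketch for ReadMIMEHeader, parsing the header shown in the
// doc comment above from an in-memory string.
//
//	package main
//
//	import (
//		"bufio"
//		"fmt"
//		"net/textproto"
//		"strings"
//	)
//
//	func main() {
//		in := "My-Key: Value 1\r\nLong-Key: Even\r\n Longer Value\r\nMy-Key: Value 2\r\n\r\n"
//		r := textproto.NewReader(bufio.NewReader(strings.NewReader(in)))
//		h, err := r.ReadMIMEHeader()
//		if err != nil {
//			panic(err)
//		}
//		fmt.Println(h["My-Key"])       // [Value 1 Value 2]
//		fmt.Println(h.Get("Long-Key")) // Even Longer Value
//	}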
+
+// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size.
+// It is called by the mime/multipart package.
+func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) {
+ // Avoid lots of small slice allocations later by allocating one
+ // large one ahead of time which we'll cut up into smaller
+ // slices. If this isn't big enough later, we allocate small ones.
+ var strs []string
+ hint := r.upcomingHeaderKeys()
+ if hint > 0 {
+ if hint > 1000 {
+ hint = 1000 // set a cap to avoid overallocation
+ }
+ strs = make([]string, hint)
+ }
+
+ m := make(MIMEHeader, hint)
+
+ // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry.
+ // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large
+ // MIMEHeaders average about 200 bytes per entry.
+ maxMemory -= 400
+ const mapEntryOverhead = 200
+
+ // The first line cannot start with a leading space.
+ if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') {
+ const errorLimit = 80 // arbitrary limit on how much of the line we'll quote
+ line, err := r.readLineSlice(errorLimit)
+ if err != nil {
+ return m, err
+ }
+ return m, ProtocolError("malformed MIME header initial line: " + string(line))
+ }
+
+ for {
+ kv, err := r.readContinuedLineSlice(maxMemory, mustHaveFieldNameColon)
+ if len(kv) == 0 {
+ return m, err
+ }
+
+ // Key ends at first colon.
+ k, v, ok := bytes.Cut(kv, colon)
+ if !ok {
+ return m, ProtocolError("malformed MIME header line: " + string(kv))
+ }
+ key, ok := canonicalMIMEHeaderKey(k)
+ if !ok {
+ return m, ProtocolError("malformed MIME header line: " + string(kv))
+ }
+ for _, c := range v {
+ if !validHeaderValueByte(c) {
+ return m, ProtocolError("malformed MIME header line: " + string(kv))
+ }
+ }
+
+ // As per RFC 7230 field-name is a token, tokens consist of one or more chars.
+ // We could return a ProtocolError here, but better to be liberal in what we
+ // accept, so if we get an empty key, skip it.
+ if key == "" {
+ continue
+ }
+
+ maxHeaders--
+ if maxHeaders < 0 {
+ return nil, errMessageTooLarge
+ }
+
+ // Skip initial spaces in value.
+ value := string(bytes.TrimLeft(v, " \t"))
+
+ vv := m[key]
+ if vv == nil {
+ maxMemory -= int64(len(key))
+ maxMemory -= mapEntryOverhead
+ }
+ maxMemory -= int64(len(value))
+ if maxMemory < 0 {
+ return m, errMessageTooLarge
+ }
+ if vv == nil && len(strs) > 0 {
+ // More than likely this will be a single-element key.
+ // Most headers aren't multi-valued.
+ // Set the capacity on strs[0] to 1, so any future append
+ // won't extend the slice into the other strings.
+ vv, strs = strs[:1:1], strs[1:]
+ vv[0] = value
+ m[key] = vv
+ } else {
+ m[key] = append(vv, value)
+ }
+
+ if err != nil {
+ return m, err
+ }
+ }
+}
+
+// noValidation is a no-op validation func for readContinuedLineSlice
+// that permits any lines.
+func noValidation(_ []byte) error { return nil }
+
+// mustHaveFieldNameColon ensures that, per RFC 7230, the
+// field-name is on a single line, so the first line must
+// contain a colon.
+func mustHaveFieldNameColon(line []byte) error {
+ if bytes.IndexByte(line, ':') < 0 {
+ return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line))
+ }
+ return nil
+}
+
+var nl = []byte("\n")
+
+// upcomingHeaderKeys returns an approximation of the number of keys
+// that will be in this header. If it gets confused, it returns 0.
+func (r *Reader) upcomingHeaderKeys() (n int) {
+ // Try to determine the 'hint' size.
+ r.R.Peek(1) // force a buffer load if empty
+ s := r.R.Buffered()
+ if s == 0 {
+ return
+ }
+ peek, _ := r.R.Peek(s)
+ for len(peek) > 0 && n < 1000 {
+ var line []byte
+ line, peek, _ = bytes.Cut(peek, nl)
+ if len(line) == 0 || (len(line) == 1 && line[0] == '\r') {
+ // Blank line separating headers from the body.
+ break
+ }
+ if line[0] == ' ' || line[0] == '\t' {
+ // Folded continuation of the previous line.
+ continue
+ }
+ n++
+ }
+ return n
+}
+
+// CanonicalMIMEHeaderKey returns the canonical format of the
+// MIME header key s. The canonicalization converts the first
+// letter and any letter following a hyphen to upper case;
+// the rest are converted to lowercase. For example, the
+// canonical key for "accept-encoding" is "Accept-Encoding".
+// MIME header keys are assumed to be ASCII only.
+// If s contains a space or invalid header field bytes, it is
+// returned without modifications.
+func CanonicalMIMEHeaderKey(s string) string {
+ // Quick check for canonical encoding.
+ upper := true
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if !validHeaderFieldByte(c) {
+ return s
+ }
+ if upper && 'a' <= c && c <= 'z' {
+ s, _ = canonicalMIMEHeaderKey([]byte(s))
+ return s
+ }
+ if !upper && 'A' <= c && c <= 'Z' {
+ s, _ = canonicalMIMEHeaderKey([]byte(s))
+ return s
+ }
+ upper = c == '-'
+ }
+ return s
+}
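// A short sketch of the canonicalization rules described above, including the
// pass-through of keys containing invalid bytes such as a space.
//
//	package main
//
//	import (
//		"fmt"
//		"net/textproto"
//	)
//
//	func main() {
//		fmt.Println(textproto.CanonicalMIMEHeaderKey("accept-encoding")) // Accept-Encoding
//		fmt.Println(textproto.CanonicalMIMEHeaderKey("uSER-aGENT"))      // User-Agent
//		fmt.Println(textproto.CanonicalMIMEHeaderKey("Bad Key"))         // Bad Key (space: returned unmodified)
//	}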
+
+const toLower = 'a' - 'A'
+
+// validHeaderFieldByte reports whether c is a valid byte in a header
+// field name. RFC 7230 says:
+//
+// header-field = field-name ":" OWS field-value OWS
+// field-name = token
+// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
+// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
+// token = 1*tchar
+func validHeaderFieldByte(c byte) bool {
+ // mask is a 128-bit bitmap with 1s for allowed bytes,
+ // so that the byte c can be tested with a shift and an and.
+ // If c >= 128, then 1<<c and 1<<(c-64) will both be zero,
+ // and this function will return false.
+ const mask = 0 |
+ (1<<(10)-1)<<'0' |
+ (1<<(26)-1)<<'a' |
+ (1<<(26)-1)<<'A' |
+ 1<<'!' |
+ 1<<'#' |
+ 1<<'$' |
+ 1<<'%' |
+ 1<<'&' |
+ 1<<'\'' |
+ 1<<'*' |
+ 1<<'+' |
+ 1<<'-' |
+ 1<<'.' |
+ 1<<'^' |
+ 1<<'_' |
+ 1<<'`' |
+ 1<<'|' |
+ 1<<'~'
+ return ((uint64(1)<<c)&(mask&(1<<64-1)) |
+ (uint64(1)<<(c-64))&(mask>>64)) != 0
+}
+
+// validHeaderValueByte reports whether c is a valid byte in a header
+// field value. RFC 7230 says:
+//
+// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
+// field-vchar = VCHAR / obs-text
+// obs-text = %x80-FF
+//
+// RFC 5234 says:
+//
+// HTAB = %x09
+// SP = %x20
+// VCHAR = %x21-7E
+func validHeaderValueByte(c byte) bool {
+ // mask is a 128-bit bitmap with 1s for allowed bytes,
+ // so that the byte c can be tested with a shift and an and.
+ // If c >= 128, then 1<<c and 1<<(c-64) will both be zero.
+ // Since this is the obs-text range, we invert the mask to
+ // create a bitmap with 1s for disallowed bytes.
+ const mask = 0 |
+ (1<<(0x7f-0x21)-1)<<0x21 | // VCHAR: %x21-7E
+ 1<<0x20 | // SP: %x20
+ 1<<0x09 // HTAB: %x09
+ return ((uint64(1)<<c)&^(mask&(1<<64-1)) |
+ (uint64(1)<<(c-64))&^(mask>>64)) == 0
+}
+
+// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is
+// allowed to mutate the provided byte slice before returning the
+// string.
+//
+// For invalid inputs (if a contains spaces or non-token bytes), a
+// is unchanged and a string copy is returned.
+//
+// ok is true if the header key contains only valid characters and spaces.
+// ReadMIMEHeader accepts header keys containing spaces, but does not
+// canonicalize them.
+func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) {
+ // See if a looks like a header key. If not, return it unchanged.
+ noCanon := false
+ for _, c := range a {
+ if validHeaderFieldByte(c) {
+ continue
+ }
+ // Don't canonicalize.
+ if c == ' ' {
+ // We accept invalid headers with a space before the
+ // colon, but must not canonicalize them.
+ // See https://go.dev/issue/34540.
+ noCanon = true
+ continue
+ }
+ return string(a), false
+ }
+ if noCanon {
+ return string(a), true
+ }
+
+ upper := true
+ for i, c := range a {
+ // Canonicalize: first letter upper case
+ // and upper case after each dash.
+ // (Host, User-Agent, If-Modified-Since).
+ // MIME headers are ASCII only, so no Unicode issues.
+ if upper && 'a' <= c && c <= 'z' {
+ c -= toLower
+ } else if !upper && 'A' <= c && c <= 'Z' {
+ c += toLower
+ }
+ a[i] = c
+ upper = c == '-' // for next time
+ }
+ commonHeaderOnce.Do(initCommonHeader)
+ // The compiler recognizes m[string(byteSlice)] as a special
+ // case, so a copy of a's bytes into a new string does not
+ // happen in this map lookup:
+ if v := commonHeader[string(a)]; v != "" {
+ return v, true
+ }
+ return string(a), true
+}
+
+// commonHeader interns common header strings.
+var commonHeader map[string]string
+
+var commonHeaderOnce sync.Once
+
+func initCommonHeader() {
+ commonHeader = make(map[string]string)
+ for _, v := range []string{
+ "Accept",
+ "Accept-Charset",
+ "Accept-Encoding",
+ "Accept-Language",
+ "Accept-Ranges",
+ "Cache-Control",
+ "Cc",
+ "Connection",
+ "Content-Id",
+ "Content-Language",
+ "Content-Length",
+ "Content-Transfer-Encoding",
+ "Content-Type",
+ "Cookie",
+ "Date",
+ "Dkim-Signature",
+ "Etag",
+ "Expires",
+ "From",
+ "Host",
+ "If-Modified-Since",
+ "If-None-Match",
+ "In-Reply-To",
+ "Last-Modified",
+ "Location",
+ "Message-Id",
+ "Mime-Version",
+ "Pragma",
+ "Received",
+ "Return-Path",
+ "Server",
+ "Set-Cookie",
+ "Subject",
+ "To",
+ "User-Agent",
+ "Via",
+ "X-Forwarded-For",
+ "X-Imforwards",
+ "X-Powered-By",
+ } {
+ commonHeader[v] = v
+ }
+}
diff --git a/src/net/textproto/reader_test.go b/src/net/textproto/reader_test.go
new file mode 100644
index 0000000..26ff617
--- /dev/null
+++ b/src/net/textproto/reader_test.go
@@ -0,0 +1,537 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import (
+ "bufio"
+ "bytes"
+ "io"
+ "net"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+)
+
+func reader(s string) *Reader {
+ return NewReader(bufio.NewReader(strings.NewReader(s)))
+}
+
+func TestReadLine(t *testing.T) {
+ r := reader("line1\nline2\n")
+ s, err := r.ReadLine()
+ if s != "line1" || err != nil {
+ t.Fatalf("Line 1: %s, %v", s, err)
+ }
+ s, err = r.ReadLine()
+ if s != "line2" || err != nil {
+ t.Fatalf("Line 2: %s, %v", s, err)
+ }
+ s, err = r.ReadLine()
+ if s != "" || err != io.EOF {
+ t.Fatalf("EOF: %s, %v", s, err)
+ }
+}
+
+func TestReadLineLongLine(t *testing.T) {
+ line := strings.Repeat("12345", 10000)
+ r := reader(line + "\r\n")
+ s, err := r.ReadLine()
+ if err != nil {
+ t.Fatalf("Line 1: %v", err)
+ }
+ if s != line {
+ t.Fatalf("%v-byte line does not match expected %v-byte line", len(s), len(line))
+ }
+}
+
+func TestReadContinuedLine(t *testing.T) {
+ r := reader("line1\nline\n 2\nline3\n")
+ s, err := r.ReadContinuedLine()
+ if s != "line1" || err != nil {
+ t.Fatalf("Line 1: %s, %v", s, err)
+ }
+ s, err = r.ReadContinuedLine()
+ if s != "line 2" || err != nil {
+ t.Fatalf("Line 2: %s, %v", s, err)
+ }
+ s, err = r.ReadContinuedLine()
+ if s != "line3" || err != nil {
+ t.Fatalf("Line 3: %s, %v", s, err)
+ }
+ s, err = r.ReadContinuedLine()
+ if s != "" || err != io.EOF {
+ t.Fatalf("EOF: %s, %v", s, err)
+ }
+}
+
+func TestReadCodeLine(t *testing.T) {
+ r := reader("123 hi\n234 bye\n345 no way\n")
+ code, msg, err := r.ReadCodeLine(0)
+ if code != 123 || msg != "hi" || err != nil {
+ t.Fatalf("Line 1: %d, %s, %v", code, msg, err)
+ }
+ code, msg, err = r.ReadCodeLine(23)
+ if code != 234 || msg != "bye" || err != nil {
+ t.Fatalf("Line 2: %d, %s, %v", code, msg, err)
+ }
+ code, msg, err = r.ReadCodeLine(346)
+ if code != 345 || msg != "no way" || err == nil {
+ t.Fatalf("Line 3: %d, %s, %v", code, msg, err)
+ }
+ if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg {
+ t.Fatalf("Line 3: wrong error %v\n", err)
+ }
+ code, msg, err = r.ReadCodeLine(1)
+ if code != 0 || msg != "" || err != io.EOF {
+ t.Fatalf("EOF: %d, %s, %v", code, msg, err)
+ }
+}
+
+func TestReadDotLines(t *testing.T) {
+ r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n")
+ s, err := r.ReadDotLines()
+ want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""}
+ if !reflect.DeepEqual(s, want) || err != nil {
+ t.Fatalf("ReadDotLines: %v, %v", s, err)
+ }
+
+ s, err = r.ReadDotLines()
+ want = []string{"another"}
+ if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF {
+ t.Fatalf("ReadDotLines2: %v, %v", s, err)
+ }
+}
+
+func TestReadDotBytes(t *testing.T) {
+ r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n")
+ b, err := r.ReadDotBytes()
+ want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n")
+ if !reflect.DeepEqual(b, want) || err != nil {
+ t.Fatalf("ReadDotBytes: %q, %v", b, err)
+ }
+
+ b, err = r.ReadDotBytes()
+ want = []byte("anot.her\n")
+ if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF {
+ t.Fatalf("ReadDotBytes2: %q, %v", b, err)
+ }
+}
+
+func TestReadMIMEHeader(t *testing.T) {
+ r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n")
+ m, err := r.ReadMIMEHeader()
+ want := MIMEHeader{
+ "My-Key": {"Value 1", "Value 2"},
+ "Long-Key": {"Even Longer Value"},
+ }
+ if !reflect.DeepEqual(m, want) || err != nil {
+ t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+ }
+}
+
+func TestReadMIMEHeaderSingle(t *testing.T) {
+ r := reader("Foo: bar\n\n")
+ m, err := r.ReadMIMEHeader()
+ want := MIMEHeader{"Foo": {"bar"}}
+ if !reflect.DeepEqual(m, want) || err != nil {
+ t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+ }
+}
+
+// TestReaderUpcomingHeaderKeys is testing an internal function, but it's very
+// difficult to test well via the external API.
+func TestReaderUpcomingHeaderKeys(t *testing.T) {
+ for _, test := range []struct {
+ input string
+ want int
+ }{{
+ input: "",
+ want: 0,
+ }, {
+ input: "A: v",
+ want: 1,
+ }, {
+ input: "A: v\r\nB: v\r\n",
+ want: 2,
+ }, {
+ input: "A: v\nB: v\n",
+ want: 2,
+ }, {
+ input: "A: v\r\n continued\r\n still continued\r\nB: v\r\n\r\n",
+ want: 2,
+ }, {
+ input: "A: v\r\n\r\nB: v\r\nC: v\r\n",
+ want: 1,
+ }, {
+ input: "A: v" + strings.Repeat("\n", 1000),
+ want: 1,
+ }} {
+ r := reader(test.input)
+ got := r.upcomingHeaderKeys()
+ if test.want != got {
+ t.Fatalf("upcomingHeaderKeys(%q): %v; want %v", test.input, got, test.want)
+ }
+ }
+}
+
+func TestReadMIMEHeaderNoKey(t *testing.T) {
+ r := reader(": bar\ntest-1: 1\n\n")
+ m, err := r.ReadMIMEHeader()
+ want := MIMEHeader{"Test-1": {"1"}}
+ if !reflect.DeepEqual(m, want) || err != nil {
+ t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want)
+ }
+}
+
+func TestLargeReadMIMEHeader(t *testing.T) {
+ data := make([]byte, 16*1024)
+ for i := 0; i < len(data); i++ {
+ data[i] = 'x'
+ }
+ sdata := string(data)
+ r := reader("Cookie: " + sdata + "\r\n\n")
+ m, err := r.ReadMIMEHeader()
+ if err != nil {
+ t.Fatalf("ReadMIMEHeader: %v", err)
+ }
+ cookie := m.Get("Cookie")
+ if cookie != sdata {
+ t.Fatalf("ReadMIMEHeader: %v bytes, want %v bytes", len(cookie), len(sdata))
+ }
+}
+
+// TestReadMIMEHeaderNonCompliant checks that we don't normalize headers
+// with spaces before colons, and accept spaces in keys.
+func TestReadMIMEHeaderNonCompliant(t *testing.T) {
+ // These invalid headers will be rejected by net/http according to RFC 7230.
+ r := reader("Foo: bar\r\n" +
+ "Content-Language: en\r\n" +
+ "SID : 0\r\n" +
+ "Audio Mode : None\r\n" +
+ "Privilege : 127\r\n\r\n")
+ m, err := r.ReadMIMEHeader()
+ want := MIMEHeader{
+ "Foo": {"bar"},
+ "Content-Language": {"en"},
+ "SID ": {"0"},
+ "Audio Mode ": {"None"},
+ "Privilege ": {"127"},
+ }
+ if !reflect.DeepEqual(m, want) || err != nil {
+ t.Fatalf("ReadMIMEHeader =\n%v, %v; want:\n%v", m, err, want)
+ }
+}
+
+func TestReadMIMEHeaderMalformed(t *testing.T) {
+ inputs := []string{
+ "No colon first line\r\nFoo: foo\r\n\r\n",
+ " No colon first line with leading space\r\nFoo: foo\r\n\r\n",
+ "\tNo colon first line with leading tab\r\nFoo: foo\r\n\r\n",
+ " First: line with leading space\r\nFoo: foo\r\n\r\n",
+ "\tFirst: line with leading tab\r\nFoo: foo\r\n\r\n",
+ "Foo: foo\r\nNo colon second line\r\n\r\n",
+ "Foo-\n\tBar: foo\r\n\r\n",
+ "Foo-\r\n\tBar: foo\r\n\r\n",
+ "Foo\r\n\t: foo\r\n\r\n",
+ "Foo-\n\tBar",
+ "Foo \tBar: foo\r\n\r\n",
+ }
+ for _, input := range inputs {
+ r := reader(input)
+ if m, err := r.ReadMIMEHeader(); err == nil || err == io.EOF {
+ t.Errorf("ReadMIMEHeader(%q) = %v, %v; want nil, err", input, m, err)
+ }
+ }
+}
+
+func TestReadMIMEHeaderBytes(t *testing.T) {
+ for i := 0; i <= 0xff; i++ {
+ s := "Foo" + string(rune(i)) + "Bar: foo\r\n\r\n"
+ r := reader(s)
+ wantErr := true
+ switch {
+ case i >= '0' && i <= '9':
+ wantErr = false
+ case i >= 'a' && i <= 'z':
+ wantErr = false
+ case i >= 'A' && i <= 'Z':
+ wantErr = false
+ case i == '!' || i == '#' || i == '$' || i == '%' || i == '&' || i == '\'' || i == '*' || i == '+' || i == '-' || i == '.' || i == '^' || i == '_' || i == '`' || i == '|' || i == '~':
+ wantErr = false
+ case i == ':':
+ // Special case: "Foo:Bar: foo" is the header "Foo".
+ wantErr = false
+ case i == ' ':
+ wantErr = false
+ }
+ m, err := r.ReadMIMEHeader()
+ if err != nil != wantErr {
+ t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr)
+ }
+ }
+ for i := 0; i <= 0xff; i++ {
+ s := "Foo: foo" + string(rune(i)) + "bar\r\n\r\n"
+ r := reader(s)
+ wantErr := true
+ switch {
+ case i >= 0x21 && i <= 0x7e:
+ wantErr = false
+ case i == ' ':
+ wantErr = false
+ case i == '\t':
+ wantErr = false
+ case i >= 0x80 && i <= 0xff:
+ wantErr = false
+ }
+ m, err := r.ReadMIMEHeader()
+ if (err != nil) != wantErr {
+ t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr)
+ }
+ }
+}
+
+// Test that continued lines are properly trimmed. Issue 11204.
+func TestReadMIMEHeaderTrimContinued(t *testing.T) {
+ // In this header, \n and \r\n terminated lines are mixed on purpose.
+ // We expect each line to be trimmed (prefix and suffix) before being concatenated.
+ // Keep the spaces as they are.
+ r := reader("" + // for code formatting purpose.
+ "a:\n" +
+ " 0 \r\n" +
+ "b:1 \t\r\n" +
+ "c: 2\r\n" +
+ " 3\t\n" +
+ " \t 4 \r\n\n")
+ m, err := r.ReadMIMEHeader()
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := MIMEHeader{
+ "A": {"0"},
+ "B": {"1"},
+ "C": {"2 3 4"},
+ }
+ if !reflect.DeepEqual(m, want) {
+ t.Fatalf("ReadMIMEHeader mismatch.\n got: %q\nwant: %q", m, want)
+ }
+}
+
+// Test that reading a header doesn't overallocate. Issue 58975.
+func TestReadMIMEHeaderAllocations(t *testing.T) {
+ var totalAlloc uint64
+ const count = 200
+ for i := 0; i < count; i++ {
+ r := reader("A: b\r\n\r\n" + strings.Repeat("\n", 4096))
+ var m1, m2 runtime.MemStats
+ runtime.ReadMemStats(&m1)
+ _, err := r.ReadMIMEHeader()
+ if err != nil {
+ t.Fatalf("ReadMIMEHeader: %v", err)
+ }
+ runtime.ReadMemStats(&m2)
+ totalAlloc += m2.TotalAlloc - m1.TotalAlloc
+ }
+ // 32k is large and we actually allocate substantially less,
+ // but prior to the fix for #58975 we allocated ~400k in this case.
+ if got, want := totalAlloc/count, uint64(32768); got > want {
+ t.Fatalf("ReadMIMEHeader allocated %v bytes, want < %v", got, want)
+ }
+}
+
+type readResponseTest struct {
+ in string
+ inCode int
+ wantCode int
+ wantMsg string
+}
+
+var readResponseTests = []readResponseTest{
+ {"230-Anonymous access granted, restrictions apply\n" +
+ "Read the file README.txt,\n" +
+ "230 please",
+ 23,
+ 230,
+ "Anonymous access granted, restrictions apply\nRead the file README.txt,\n please",
+ },
+
+ {"230 Anonymous access granted, restrictions apply\n",
+ 23,
+ 230,
+ "Anonymous access granted, restrictions apply",
+ },
+
+ {"400-A\n400-B\n400 C",
+ 4,
+ 400,
+ "A\nB\nC",
+ },
+
+ {"400-A\r\n400-B\r\n400 C\r\n",
+ 4,
+ 400,
+ "A\nB\nC",
+ },
+}
+
+// See https://www.ietf.org/rfc/rfc959.txt page 36.
+func TestRFC959Lines(t *testing.T) {
+ for i, tt := range readResponseTests {
+ r := reader(tt.in + "\nFOLLOWING DATA")
+ code, msg, err := r.ReadResponse(tt.inCode)
+ if err != nil {
+ t.Errorf("#%d: ReadResponse: %v", i, err)
+ continue
+ }
+ if code != tt.wantCode {
+ t.Errorf("#%d: code=%d, want %d", i, code, tt.wantCode)
+ }
+ if msg != tt.wantMsg {
+ t.Errorf("#%d: msg=%q, want %q", i, msg, tt.wantMsg)
+ }
+ }
+}
+
+// Test that multi-line errors are appropriately and fully read. Issue 10230.
+func TestReadMultiLineError(t *testing.T) {
+ r := reader("550-5.1.1 The email account that you tried to reach does not exist. Please try\n" +
+ "550-5.1.1 double-checking the recipient's email address for typos or\n" +
+ "550-5.1.1 unnecessary spaces. Learn more at\n" +
+ "Unexpected but legal text!\n" +
+ "550 5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp\n")
+
+ wantMsg := "5.1.1 The email account that you tried to reach does not exist. Please try\n" +
+ "5.1.1 double-checking the recipient's email address for typos or\n" +
+ "5.1.1 unnecessary spaces. Learn more at\n" +
+ "Unexpected but legal text!\n" +
+ "5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp"
+
+ code, msg, err := r.ReadResponse(250)
+ if err == nil {
+ t.Errorf("ReadResponse: no error, want error")
+ }
+ if code != 550 {
+ t.Errorf("ReadResponse: code=%d, want %d", code, 550)
+ }
+ if msg != wantMsg {
+ t.Errorf("ReadResponse: msg=%q, want %q", msg, wantMsg)
+ }
+ if err != nil && err.Error() != "550 "+wantMsg {
+ t.Errorf("ReadResponse: error=%q, want %q", err.Error(), "550 "+wantMsg)
+ }
+}
+
+func TestCommonHeaders(t *testing.T) {
+ commonHeaderOnce.Do(initCommonHeader)
+ for h := range commonHeader {
+ if h != CanonicalMIMEHeaderKey(h) {
+ t.Errorf("Non-canonical header %q in commonHeader", h)
+ }
+ }
+ b := []byte("content-Length")
+ want := "Content-Length"
+ n := testing.AllocsPerRun(200, func() {
+ if x, _ := canonicalMIMEHeaderKey(b); x != want {
+ t.Fatalf("canonicalMIMEHeaderKey(%q) = %q; want %q", b, x, want)
+ }
+ })
+ if n > 0 {
+ t.Errorf("canonicalMIMEHeaderKey allocs = %v; want 0", n)
+ }
+}
+
+func TestIssue46363(t *testing.T) {
+ // Regression test for data race reported in issue 46363:
+ // ReadMIMEHeader reads commonHeader before commonHeader has been initialized.
+ // Run this test with the race detector enabled to catch the reported data race.
+
+ // Reset commonHeaderOnce, so that commonHeader will have to be initialized
+ commonHeaderOnce = sync.Once{}
+ commonHeader = nil
+
+ // Test for data race by calling ReadMIMEHeader and CanonicalMIMEHeaderKey concurrently
+
+ // Send MIME header over net.Conn
+ r, w := net.Pipe()
+ go func() {
+ // ReadMIMEHeader calls canonicalMIMEHeaderKey, which reads from commonHeader
+ NewConn(r).ReadMIMEHeader()
+ }()
+ w.Write([]byte("A: 1\r\nB: 2\r\nC: 3\r\n\r\n"))
+
+ // CanonicalMIMEHeaderKey calls commonHeaderOnce.Do(initCommonHeader) which initializes commonHeader
+ CanonicalMIMEHeaderKey("a")
+
+ if commonHeader == nil {
+ t.Fatal("CanonicalMIMEHeaderKey should initialize commonHeader")
+ }
+}
+
+var clientHeaders = strings.Replace(`Host: golang.org
+Connection: keep-alive
+Cache-Control: max-age=0
+Accept: application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5
+User-Agent: Mozilla/5.0 (X11; U; Linux x86_64; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.63 Safari/534.3
+Accept-Encoding: gzip,deflate,sdch
+Accept-Language: en-US,en;q=0.8,fr-CH;q=0.6
+Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
+COOKIE: __utma=000000000.0000000000.0000000000.0000000000.0000000000.00; __utmb=000000000.0.00.0000000000; __utmc=000000000; __utmz=000000000.0000000000.00.0.utmcsr=code.google.com|utmccn=(referral)|utmcmd=referral|utmcct=/p/go/issues/detail
+Non-Interned: test
+
+`, "\n", "\r\n", -1)
+
+var serverHeaders = strings.Replace(`Content-Type: text/html; charset=utf-8
+Content-Encoding: gzip
+Date: Thu, 27 Sep 2012 09:03:33 GMT
+Server: Google Frontend
+Cache-Control: private
+Content-Length: 2298
+VIA: 1.1 proxy.example.com:80 (XXX/n.n.n-nnn)
+Connection: Close
+Non-Interned: test
+
+`, "\n", "\r\n", -1)
+
+func BenchmarkReadMIMEHeader(b *testing.B) {
+ b.ReportAllocs()
+ for _, set := range []struct {
+ name string
+ headers string
+ }{
+ {"client_headers", clientHeaders},
+ {"server_headers", serverHeaders},
+ } {
+ b.Run(set.name, func(b *testing.B) {
+ var buf bytes.Buffer
+ br := bufio.NewReader(&buf)
+ r := NewReader(br)
+
+ for i := 0; i < b.N; i++ {
+ buf.WriteString(set.headers)
+ if _, err := r.ReadMIMEHeader(); err != nil {
+ b.Fatal(err)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkUncommon(b *testing.B) {
+ b.ReportAllocs()
+ var buf bytes.Buffer
+ br := bufio.NewReader(&buf)
+ r := NewReader(br)
+ for i := 0; i < b.N; i++ {
+ buf.WriteString("uncommon-header-for-benchmark: foo\r\n\r\n")
+ h, err := r.ReadMIMEHeader()
+ if err != nil {
+ b.Fatal(err)
+ }
+ if _, ok := h["Uncommon-Header-For-Benchmark"]; !ok {
+ b.Fatal("Missing result header.")
+ }
+ }
+}
diff --git a/src/net/textproto/textproto.go b/src/net/textproto/textproto.go
new file mode 100644
index 0000000..70038d5
--- /dev/null
+++ b/src/net/textproto/textproto.go
@@ -0,0 +1,152 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package textproto implements generic support for text-based request/response
+// protocols in the style of HTTP, NNTP, and SMTP.
+//
+// The package provides:
+//
+// Error, which represents a numeric error response from
+// a server.
+//
+// Pipeline, to manage pipelined requests and responses
+// in a client.
+//
+// Reader, to read numeric response code lines,
+// key: value headers, lines wrapped with leading spaces
+// on continuation lines, and whole text blocks ending
+// with a dot on a line by itself.
+//
+// Writer, to write dot-encoded text blocks.
+//
+// Conn, a convenient packaging of Reader, Writer, and Pipeline for use
+// with a single network connection.
+package textproto
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "net"
+)
+
+// An Error represents a numeric error response from a server.
+type Error struct {
+ Code int
+ Msg string
+}
+
+func (e *Error) Error() string {
+ return fmt.Sprintf("%03d %s", e.Code, e.Msg)
+}
+
+// A ProtocolError describes a protocol violation such
+// as an invalid response or a hung-up connection.
+type ProtocolError string
+
+func (p ProtocolError) Error() string {
+ return string(p)
+}
+
+// A Conn represents a textual network protocol connection.
+// It consists of a Reader and Writer to manage I/O
+// and a Pipeline to sequence concurrent requests on the connection.
+// These embedded types carry methods with them;
+// see the documentation of those types for details.
+type Conn struct {
+ Reader
+ Writer
+ Pipeline
+ conn io.ReadWriteCloser
+}
+
+// NewConn returns a new Conn using conn for I/O.
+func NewConn(conn io.ReadWriteCloser) *Conn {
+ return &Conn{
+ Reader: Reader{R: bufio.NewReader(conn)},
+ Writer: Writer{W: bufio.NewWriter(conn)},
+ conn: conn,
+ }
+}
+
+// Close closes the connection.
+func (c *Conn) Close() error {
+ return c.conn.Close()
+}
+
+// Dial connects to the given address on the given network using net.Dial
+// and then returns a new Conn for the connection.
+func Dial(network, addr string) (*Conn, error) {
+ c, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ return NewConn(c), nil
+}
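// A hedged sketch of a typical Dial-then-greeting handshake. The address and
// the expected 220 banner are assumptions about an SMTP-like server, not
// something this package requires.
//
//	package main
//
//	import (
//		"fmt"
//		"net/textproto"
//	)
//
//	func greet(addr string) error {
//		conn, err := textproto.Dial("tcp", addr)
//		if err != nil {
//			return err
//		}
//		defer conn.Close()
//		_, _, err = conn.ReadCodeLine(220) // expect the server's 220 banner
//		return err
//	}
//
//	func main() {
//		if err := greet("mail.example.com:25"); err != nil {
//			fmt.Println(err)
//		}
//	}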
+
+// Cmd is a convenience method that sends a command after
+// waiting its turn in the pipeline. The command text is the
+// result of formatting format with args and appending \r\n.
+// Cmd returns the id of the command, for use with StartResponse and EndResponse.
+//
+// For example, a client might run a HELP command that returns a dot-body
+// by using:
+//
+// id, err := c.Cmd("HELP")
+// if err != nil {
+// return nil, err
+// }
+//
+// c.StartResponse(id)
+// defer c.EndResponse(id)
+//
+// if _, _, err = c.ReadCodeLine(110); err != nil {
+// return nil, err
+// }
+// text, err := c.ReadDotBytes()
+// if err != nil {
+// return nil, err
+// }
+// return c.ReadCodeLine(250)
+func (c *Conn) Cmd(format string, args ...any) (id uint, err error) {
+ id = c.Next()
+ c.StartRequest(id)
+ err = c.PrintfLine(format, args...)
+ c.EndRequest(id)
+ if err != nil {
+ return 0, err
+ }
+ return id, nil
+}
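// A self-contained version of the Cmd / StartResponse / EndResponse pattern
// from the doc comment above, run over an in-memory net.Pipe instead of a
// real server. The HELP command, the 112 status code, and the reply body are
// invented for illustration.
//
//	package main
//
//	import (
//		"fmt"
//		"net"
//		"net/textproto"
//	)
//
//	func main() {
//		clientSide, serverSide := net.Pipe()
//
//		// Toy server: answer one command with a status line and a dot-encoded body.
//		go func() {
//			srv := textproto.NewConn(serverSide)
//			srv.ReadLine() // consume "HELP"
//			srv.PrintfLine("112 help text follows")
//			d := srv.DotWriter()
//			fmt.Fprintf(d, "commands: HELP QUIT\n")
//			d.Close() // writes the terminating ".\r\n" and flushes
//		}()
//
//		c := textproto.NewConn(clientSide)
//		id, err := c.Cmd("HELP")
//		if err != nil {
//			panic(err)
//		}
//		c.StartResponse(id)
//		defer c.EndResponse(id)
//		if _, _, err := c.ReadCodeLine(112); err != nil {
//			panic(err)
//		}
//		body, err := c.ReadDotBytes()
//		if err != nil {
//			panic(err)
//		}
//		fmt.Print(string(body)) // commands: HELP QUIT
//	}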
+
+// TrimString returns s without leading and trailing ASCII space.
+func TrimString(s string) string {
+ for len(s) > 0 && isASCIISpace(s[0]) {
+ s = s[1:]
+ }
+ for len(s) > 0 && isASCIISpace(s[len(s)-1]) {
+ s = s[:len(s)-1]
+ }
+ return s
+}
+
+// TrimBytes returns b without leading and trailing ASCII space.
+func TrimBytes(b []byte) []byte {
+ for len(b) > 0 && isASCIISpace(b[0]) {
+ b = b[1:]
+ }
+ for len(b) > 0 && isASCIISpace(b[len(b)-1]) {
+ b = b[:len(b)-1]
+ }
+ return b
+}
+
+func isASCIISpace(b byte) bool {
+ return b == ' ' || b == '\t' || b == '\n' || b == '\r'
+}
+
+func isASCIILetter(b byte) bool {
+ b |= 0x20 // make lower case
+ return 'a' <= b && b <= 'z'
+}
diff --git a/src/net/textproto/writer.go b/src/net/textproto/writer.go
new file mode 100644
index 0000000..2ece3f5
--- /dev/null
+++ b/src/net/textproto/writer.go
@@ -0,0 +1,119 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+)
+
+// A Writer implements convenience methods for writing
+// requests or responses to a text protocol network connection.
+type Writer struct {
+ W *bufio.Writer
+ dot *dotWriter
+}
+
+// NewWriter returns a new Writer writing to w.
+func NewWriter(w *bufio.Writer) *Writer {
+ return &Writer{W: w}
+}
+
+var crnl = []byte{'\r', '\n'}
+var dotcrnl = []byte{'.', '\r', '\n'}
+
+// PrintfLine writes the formatted output followed by \r\n.
+func (w *Writer) PrintfLine(format string, args ...any) error {
+ w.closeDot()
+ fmt.Fprintf(w.W, format, args...)
+ w.W.Write(crnl)
+ return w.W.Flush()
+}
+
+// DotWriter returns a writer that can be used to write a dot-encoding to w.
+// It takes care of inserting leading dots when necessary,
+// translating line-ending \n into \r\n, and adding the final .\r\n line
+// when the DotWriter is closed. The caller should close the
+// DotWriter before the next call to a method on w.
+//
+// See the documentation for Reader's DotReader method for details about dot-encoding.
+func (w *Writer) DotWriter() io.WriteCloser {
+ w.closeDot()
+ w.dot = &dotWriter{w: w}
+ return w.dot
+}
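// A minimal usage sketch for DotWriter, dot-encoding a short body into an
// in-memory buffer; the payload text is invented. Write escapes leading dots
// and rewrites \n to \r\n; Close appends the final ".\r\n" line.
//
//	package main
//
//	import (
//		"bufio"
//		"fmt"
//		"net/textproto"
//		"strings"
//	)
//
//	func main() {
//		var sb strings.Builder
//		w := textproto.NewWriter(bufio.NewWriter(&sb))
//		d := w.DotWriter()
//		fmt.Fprintf(d, "line 1\n.a line starting with a dot\n")
//		d.Close() // appends the terminating ".\r\n" and flushes
//		fmt.Printf("%q\n", sb.String())
//		// "line 1\r\n..a line starting with a dot\r\n.\r\n"
//	}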
+
+func (w *Writer) closeDot() {
+ if w.dot != nil {
+ w.dot.Close() // sets w.dot = nil
+ }
+}
+
+type dotWriter struct {
+ w *Writer
+ state int
+}
+
+const (
+ wstateBegin = iota // initial state; must be zero
+ wstateBeginLine // beginning of line
+ wstateCR // wrote \r (possibly at end of line)
+ wstateData // writing data in middle of line
+)
+
+func (d *dotWriter) Write(b []byte) (n int, err error) {
+ bw := d.w.W
+ for n < len(b) {
+ c := b[n]
+ switch d.state {
+ case wstateBegin, wstateBeginLine:
+ d.state = wstateData
+ if c == '.' {
+ // escape leading dot
+ bw.WriteByte('.')
+ }
+ fallthrough
+
+ case wstateData:
+ if c == '\r' {
+ d.state = wstateCR
+ }
+ if c == '\n' {
+ bw.WriteByte('\r')
+ d.state = wstateBeginLine
+ }
+
+ case wstateCR:
+ d.state = wstateData
+ if c == '\n' {
+ d.state = wstateBeginLine
+ }
+ }
+ if err = bw.WriteByte(c); err != nil {
+ break
+ }
+ n++
+ }
+ return
+}
+
+func (d *dotWriter) Close() error {
+ if d.w.dot == d {
+ d.w.dot = nil
+ }
+ bw := d.w.W
+ switch d.state {
+ default:
+ bw.WriteByte('\r')
+ fallthrough
+ case wstateCR:
+ bw.WriteByte('\n')
+ fallthrough
+ case wstateBeginLine:
+ bw.Write(dotcrnl)
+ }
+ return bw.Flush()
+}
diff --git a/src/net/textproto/writer_test.go b/src/net/textproto/writer_test.go
new file mode 100644
index 0000000..8f11b10
--- /dev/null
+++ b/src/net/textproto/writer_test.go
@@ -0,0 +1,61 @@
+// Copyright 2010 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package textproto
+
+import (
+ "bufio"
+ "strings"
+ "testing"
+)
+
+func TestPrintfLine(t *testing.T) {
+ var buf strings.Builder
+ w := NewWriter(bufio.NewWriter(&buf))
+ err := w.PrintfLine("foo %d", 123)
+ if s := buf.String(); s != "foo 123\r\n" || err != nil {
+ t.Fatalf("s=%q; err=%s", s, err)
+ }
+}
+
+func TestDotWriter(t *testing.T) {
+ var buf strings.Builder
+ w := NewWriter(bufio.NewWriter(&buf))
+ d := w.DotWriter()
+ n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n."))
+ if n != 21 || err != nil {
+ t.Fatalf("Write: %d, %s", n, err)
+ }
+ d.Close()
+ want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n"
+ if s := buf.String(); s != want {
+ t.Fatalf("wrote %q", s)
+ }
+}
+
+func TestDotWriterCloseEmptyWrite(t *testing.T) {
+ var buf strings.Builder
+ w := NewWriter(bufio.NewWriter(&buf))
+ d := w.DotWriter()
+ n, err := d.Write([]byte{})
+ if n != 0 || err != nil {
+ t.Fatalf("Write: %d, %s", n, err)
+ }
+ d.Close()
+ want := "\r\n.\r\n"
+ if s := buf.String(); s != want {
+ t.Fatalf("wrote %q; want %q", s, want)
+ }
+}
+
+func TestDotWriterCloseNoWrite(t *testing.T) {
+ var buf strings.Builder
+ w := NewWriter(bufio.NewWriter(&buf))
+ d := w.DotWriter()
+ d.Close()
+ want := "\r\n.\r\n"
+ if s := buf.String(); s != want {
+ t.Fatalf("wrote %q; want %q", s, want)
+ }
+}
diff --git a/src/net/timeout_test.go b/src/net/timeout_test.go
new file mode 100644
index 0000000..c0bce57
--- /dev/null
+++ b/src/net/timeout_test.go
@@ -0,0 +1,1161 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "errors"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "net/internal/socktest"
+ "os"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+)
+
+var dialTimeoutTests = []struct {
+ timeout time.Duration
+ delta time.Duration // for deadline
+
+ guard time.Duration
+}{
+ // Tests that dial timeouts and deadlines in the past work.
+ {-5 * time.Second, 0, -5 * time.Second},
+ {0, -5 * time.Second, -5 * time.Second},
+ {-5 * time.Second, 5 * time.Second, -5 * time.Second}, // timeout over deadline
+ {-1 << 63, 0, time.Second},
+ {0, -1 << 63, time.Second},
+
+ {50 * time.Millisecond, 0, 100 * time.Millisecond},
+ {0, 50 * time.Millisecond, 100 * time.Millisecond},
+ {50 * time.Millisecond, 5 * time.Second, 100 * time.Millisecond}, // timeout over deadline
+}
+
+func TestDialTimeout(t *testing.T) {
+ // Cannot use t.Parallel - modifies global hooks.
+ origTestHookDialChannel := testHookDialChannel
+ defer func() { testHookDialChannel = origTestHookDialChannel }()
+ defer sw.Set(socktest.FilterConnect, nil)
+
+ for i, tt := range dialTimeoutTests {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ testHookDialChannel = func() { time.Sleep(tt.guard) }
+ if runtime.GOOS == "plan9" {
+ break
+ }
+ fallthrough
+ default:
+ sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) {
+ time.Sleep(tt.guard)
+ return nil, errTimedout
+ })
+ }
+
+ d := Dialer{Timeout: tt.timeout}
+ if tt.delta != 0 {
+ d.Deadline = time.Now().Add(tt.delta)
+ }
+
+ // This dial never starts to send any TCP SYN
+ // segment because of the above socket filter and
+ // test hook.
+ c, err := d.Dial("tcp", "127.0.0.1:0")
+ if err == nil {
+ err = fmt.Errorf("unexpectedly established: tcp:%s->%s", c.LocalAddr(), c.RemoteAddr())
+ c.Close()
+ }
+
+ if perr := parseDialError(err); perr != nil {
+ t.Errorf("#%d: %v", i, perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ }
+}
+
+func TestDialTimeoutMaxDuration(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer func() {
+ if err := ln.Close(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ for _, tt := range []struct {
+ timeout time.Duration
+ delta time.Duration // for deadline
+ }{
+ // Large timeouts that will overflow an int64 unix nanos.
+ {1<<63 - 1, 0},
+ {0, 1<<63 - 1},
+ } {
+ t.Run(fmt.Sprintf("timeout=%s/delta=%s", tt.timeout, tt.delta), func(t *testing.T) {
+ d := Dialer{Timeout: tt.timeout}
+ if tt.delta != 0 {
+ d.Deadline = time.Now().Add(tt.delta)
+ }
+ c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := c.Close(); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+}
+
+var acceptTimeoutTests = []struct {
+ timeout time.Duration
+ xerrs [2]error // expected errors in transition
+}{
+ // Tests that accept deadlines in the past work, even if
+ // there's incoming connections available.
+ {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}},
+
+ {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}},
+}
+
+func TestAcceptTimeout(t *testing.T) {
+ testenv.SkipFlaky(t, 17948)
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ var wg sync.WaitGroup
+ for i, tt := range acceptTimeoutTests {
+ if tt.timeout < 0 {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ d := Dialer{Timeout: 100 * time.Millisecond}
+ c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ c.Close()
+ }()
+ }
+
+ if err := ln.(*TCPListener).SetDeadline(time.Now().Add(tt.timeout)); err != nil {
+ t.Fatalf("$%d: %v", i, err)
+ }
+ for j, xerr := range tt.xerrs {
+ for {
+ c, err := ln.Accept()
+ if xerr != nil {
+ if perr := parseAcceptError(err); perr != nil {
+ t.Errorf("#%d/%d: %v", i, j, perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatalf("#%d/%d: %v", i, j, err)
+ }
+ }
+ if err == nil {
+ c.Close()
+ time.Sleep(10 * time.Millisecond)
+ continue
+ }
+ break
+ }
+ }
+ }
+ wg.Wait()
+}
+
+func TestAcceptTimeoutMustReturn(t *testing.T) {
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ max := time.NewTimer(time.Second)
+ defer max.Stop()
+ ch := make(chan error)
+ go func() {
+ if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil {
+ t.Error(err)
+ }
+ if err := ln.(*TCPListener).SetDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
+ t.Error(err)
+ }
+ c, err := ln.Accept()
+ if err == nil {
+ c.Close()
+ }
+ ch <- err
+ }()
+
+ select {
+ case <-max.C:
+ ln.Close()
+ <-ch // wait for tester goroutine to stop
+ t.Fatal("Accept didn't return in an expected time")
+ case err := <-ch:
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestAcceptTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ maxch := make(chan *time.Timer)
+ ch := make(chan error)
+ go func() {
+ if err := ln.(*TCPListener).SetDeadline(time.Now().Add(-5 * time.Second)); err != nil {
+ t.Error(err)
+ }
+ if err := ln.(*TCPListener).SetDeadline(noDeadline); err != nil {
+ t.Error(err)
+ }
+ maxch <- time.NewTimer(100 * time.Millisecond)
+ _, err := ln.Accept()
+ ch <- err
+ }()
+
+ max := <-maxch
+ defer max.Stop()
+
+ select {
+ case err := <-ch:
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatalf("expected Accept to not return, but it returned with %v", err)
+ case <-max.C:
+ ln.Close()
+ <-ch // wait for tester goroutine to stop
+ }
+}
+
+var readTimeoutTests = []struct {
+ timeout time.Duration
+ xerrs [2]error // expected errors in transition
+}{
+ // Tests that read deadlines work, even if there's data ready
+ // to be read.
+ {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}},
+
+ {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}},
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestReadTimeout(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ c.Write([]byte("READ TIMEOUT TEST"))
+ defer c.Close()
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ for i, tt := range readTimeoutTests {
+ if err := c.SetReadDeadline(time.Now().Add(tt.timeout)); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ var b [1]byte
+ for j, xerr := range tt.xerrs {
+ for {
+ n, err := c.Read(b[:])
+ if xerr != nil {
+ if perr := parseReadError(err); perr != nil {
+ t.Errorf("#%d/%d: %v", i, j, perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatalf("#%d/%d: %v", i, j, err)
+ }
+ }
+ if err == nil {
+ time.Sleep(tt.timeout / 3)
+ continue
+ }
+ if n != 0 {
+ t.Fatalf("#%d/%d: read %d; want 0", i, j, n)
+ }
+ break
+ }
+ }
+ }
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestReadTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ maxch := make(chan *time.Timer)
+ ch := make(chan error)
+ go func() {
+ if err := c.SetDeadline(time.Now().Add(-5 * time.Second)); err != nil {
+ t.Error(err)
+ }
+ if err := c.SetWriteDeadline(time.Now().Add(-5 * time.Second)); err != nil {
+ t.Error(err)
+ }
+ if err := c.SetReadDeadline(noDeadline); err != nil {
+ t.Error(err)
+ }
+ maxch <- time.NewTimer(100 * time.Millisecond)
+ var b [1]byte
+ _, err := c.Read(b[:])
+ ch <- err
+ }()
+
+ max := <-maxch
+ defer max.Stop()
+
+ select {
+ case err := <-ch:
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatalf("expected Read to not return, but it returned with %v", err)
+ case <-max.C:
+ c.Close()
+ err := <-ch // wait for tester goroutine to stop
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || nerr.Timeout() || nerr.Temporary() {
+ t.Fatal(err)
+ }
+ }
+}
+
+var readFromTimeoutTests = []struct {
+ timeout time.Duration
+ xerrs [2]error // expected errors in transition
+}{
+ // Tests that read deadlines work, even if there's data ready
+ // to be read.
+ {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}},
+
+ {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}},
+}
+
+func TestReadFromTimeout(t *testing.T) {
+ ch := make(chan Addr)
+ defer close(ch)
+ handler := func(ls *localPacketServer, c PacketConn) {
+ if dst, ok := <-ch; ok {
+ c.WriteTo([]byte("READFROM TIMEOUT TEST"), dst)
+ }
+ }
+ ls := newLocalPacketServer(t, "udp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ host, _, err := SplitHostPort(ls.PacketConn.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenPacket(ls.PacketConn.LocalAddr().Network(), JoinHostPort(host, "0"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ ch <- c.LocalAddr()
+
+ for i, tt := range readFromTimeoutTests {
+ if err := c.SetReadDeadline(time.Now().Add(tt.timeout)); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ var b [1]byte
+ for j, xerr := range tt.xerrs {
+ for {
+ n, _, err := c.ReadFrom(b[:])
+ if xerr != nil {
+ if perr := parseReadError(err); perr != nil {
+ t.Errorf("#%d/%d: %v", i, j, perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatalf("#%d/%d: %v", i, j, err)
+ }
+ }
+ if err == nil {
+ time.Sleep(tt.timeout / 3)
+ continue
+ }
+ if nerr, ok := err.(Error); ok && nerr.Timeout() && n != 0 {
+ t.Fatalf("#%d/%d: read %d; want 0", i, j, n)
+ }
+ break
+ }
+ }
+ }
+}
+
+var writeTimeoutTests = []struct {
+ timeout time.Duration
+ xerrs [2]error // expected errors in transition
+}{
+ // Tests that write deadlines work, even if there's buffer
+ // space available to write.
+ {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}},
+
+ {10 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}},
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestWriteTimeout(t *testing.T) {
+ t.Parallel()
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ for i, tt := range writeTimeoutTests {
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ if err := c.SetWriteDeadline(time.Now().Add(tt.timeout)); err != nil {
+ t.Fatalf("#%d: %v", i, err)
+ }
+ for j, xerr := range tt.xerrs {
+ for {
+ n, err := c.Write([]byte("WRITE TIMEOUT TEST"))
+ if xerr != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Errorf("#%d/%d: %v", i, j, perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Fatalf("#%d/%d: %v", i, j, err)
+ }
+ }
+ if err == nil {
+ time.Sleep(tt.timeout / 3)
+ continue
+ }
+ if n != 0 {
+ t.Fatalf("#%d/%d: wrote %d; want 0", i, j, n)
+ }
+ break
+ }
+ }
+ }
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestWriteTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ maxch := make(chan *time.Timer)
+ ch := make(chan error)
+ go func() {
+ if err := c.SetDeadline(time.Now().Add(-5 * time.Second)); err != nil {
+ t.Error(err)
+ }
+ if err := c.SetReadDeadline(time.Now().Add(-5 * time.Second)); err != nil {
+ t.Error(err)
+ }
+ if err := c.SetWriteDeadline(noDeadline); err != nil {
+ t.Error(err)
+ }
+ maxch <- time.NewTimer(100 * time.Millisecond)
+ var b [1]byte
+ for {
+ if _, err := c.Write(b[:]); err != nil {
+ ch <- err
+ break
+ }
+ }
+ }()
+
+ max := <-maxch
+ defer max.Stop()
+
+ select {
+ case err := <-ch:
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatalf("expected Write to not return, but it returned with %v", err)
+ case <-max.C:
+ c.Close()
+ err := <-ch // wait for tester goroutine to stop
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || nerr.Timeout() || nerr.Temporary() {
+ t.Fatal(err)
+ }
+ }
+}
+
+func TestWriteToTimeout(t *testing.T) {
+ t.Parallel()
+
+ c1 := newLocalPacketListener(t, "udp")
+ defer c1.Close()
+
+ host, _, err := SplitHostPort(c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ timeouts := []time.Duration{
+ -5 * time.Second,
+ 10 * time.Millisecond,
+ }
+
+ for _, timeout := range timeouts {
+ t.Run(fmt.Sprint(timeout), func(t *testing.T) {
+ c2, err := ListenPacket(c1.LocalAddr().Network(), JoinHostPort(host, "0"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if err := c2.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
+ t.Fatalf("SetWriteDeadline: %v", err)
+ }
+ backoff := 1 * time.Millisecond
+ nDeadlineExceeded := 0
+ for j := 0; nDeadlineExceeded < 2; j++ {
+ n, err := c2.WriteTo([]byte("WRITETO TIMEOUT TEST"), c1.LocalAddr())
+ t.Logf("#%d: WriteTo: %d, %v", j, n, err)
+ if err == nil && timeout >= 0 && nDeadlineExceeded == 0 {
+ // If the timeout is nonnegative, some number of WriteTo calls may
+ // succeed before the timeout takes effect.
+ t.Logf("WriteTo succeeded; sleeping %v", timeout/3)
+ time.Sleep(timeout / 3)
+ continue
+ }
+ if isENOBUFS(err) {
+ t.Logf("WriteTo: %v", err)
+ // We're looking for a deadline exceeded error, but if the kernel's
+ // network buffers are saturated we may see ENOBUFS instead (see
+ // https://go.dev/issue/49930). Give it some time to unsaturate.
+ time.Sleep(backoff)
+ backoff *= 2
+ continue
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Errorf("failed to parse error: %v", perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("error is not 'deadline exceeded'")
+ }
+ if n != 0 {
+ t.Errorf("unexpectedly wrote %d bytes", n)
+ }
+ if !t.Failed() {
+ t.Logf("WriteTo timed out as expected")
+ }
+ nDeadlineExceeded++
+ }
+ })
+ }
+}
+
+const (
+ // minDynamicTimeout is the minimum timeout to attempt for
+ // tests that automatically increase timeouts until success.
+ //
+ // Lower values may allow tests to succeed more quickly if the value is close
+ // to the true minimum, but may require more iterations (and waste more time
+ // and CPU power on failed attempts) if the timeout is too low.
+ minDynamicTimeout = 1 * time.Millisecond
+
+ // maxDynamicTimeout is the maximum timeout to attempt for
+ // tests that automatically increase timeouts until success.
+ //
+ // This should be a strict upper bound on the latency required to hit a
+ // timeout accurately, even on a slow or heavily-loaded machine. If a test
+ // would increase the timeout beyond this value, the test fails.
+ maxDynamicTimeout = 4 * time.Second
+)
+
+// timeoutUpperBound returns the maximum time that we expect a timeout of
+// duration d to take to return the caller.
+func timeoutUpperBound(d time.Duration) time.Duration {
+ switch runtime.GOOS {
+ case "openbsd", "netbsd":
+ // NetBSD and OpenBSD seem to be unable to reliably hit deadlines even when
+ // the absolute durations are long.
+ // In https://build.golang.org/log/c34f8685d020b98377dd4988cd38f0c5bd72267e,
+ // we observed that an openbsd-amd64-68 builder took 4.090948779s for a
+ // 2.983020682s timeout (37.1% overhead).
+ // (See https://go.dev/issue/50189 for further detail.)
+ // Give them lots of slop to compensate.
+ return d * 3 / 2
+ }
+ // Other platforms seem to hit their deadlines more reliably,
+ // at least when they are long enough to cover scheduling jitter.
+ return d * 11 / 10
+}
+
+// nextTimeout returns the next timeout to try after an operation took the given
+// actual duration with a timeout shorter than that duration.
+func nextTimeout(actual time.Duration) (next time.Duration, ok bool) {
+ if actual >= maxDynamicTimeout {
+ return maxDynamicTimeout, false
+ }
+ // Since the previous attempt took actual, we can't expect to beat that
+ // duration by any significant margin. Try the next attempt with an arbitrary
+ // factor above that, so that our growth curve is at least exponential.
+ next = actual * 5 / 4
+ if next > maxDynamicTimeout {
+ return maxDynamicTimeout, true
+ }
+ return next, true
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestReadTimeoutFluctuation(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ d := minDynamicTimeout
+ b := make([]byte, 256)
+ for {
+ t.Logf("SetReadDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err = c.SetReadDeadline(deadline); err != nil {
+ t.Fatalf("SetReadDeadline(%v): %v", deadline, err)
+ }
+ var n int
+ n, err = c.Read(b)
+ t1 := time.Now()
+
+ if n != 0 || err == nil || !err.(Error).Timeout() {
+ t.Errorf("Read did not return (0, timeout): (%d, %v)", n, err)
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("Read error is not DeadlineExceeded: %v", err)
+ }
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("Read took %s; expected at least %s", actual, d)
+ }
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ next, ok := nextTimeout(actual)
+ if !ok {
+ t.Fatalf("Read took %s; expected at most %v", actual, want)
+ }
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("Read took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ continue
+ }
+
+ break
+ }
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestReadFromTimeoutFluctuation(t *testing.T) {
+ c1 := newLocalPacketListener(t, "udp")
+ defer c1.Close()
+
+ c2, err := Dial(c1.LocalAddr().Network(), c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ d := minDynamicTimeout
+ b := make([]byte, 256)
+ for {
+ t.Logf("SetReadDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err = c2.SetReadDeadline(deadline); err != nil {
+ t.Fatalf("SetReadDeadline(%v): %v", deadline, err)
+ }
+ var n int
+ n, _, err = c2.(PacketConn).ReadFrom(b)
+ t1 := time.Now()
+
+ if n != 0 || err == nil || !err.(Error).Timeout() {
+ t.Errorf("ReadFrom did not return (0, timeout): (%d, %v)", n, err)
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("ReadFrom error is not DeadlineExceeded: %v", err)
+ }
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("ReadFrom took %s; expected at least %s", actual, d)
+ }
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ next, ok := nextTimeout(actual)
+ if !ok {
+ t.Fatalf("ReadFrom took %s; expected at most %s", actual, want)
+ }
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("ReadFrom took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ continue
+ }
+
+ break
+ }
+}
+
+func TestWriteTimeoutFluctuation(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ d := minDynamicTimeout
+ for {
+ t.Logf("SetWriteDeadline(+%v)", d)
+ t0 := time.Now()
+ deadline := t0.Add(d)
+ if err := c.SetWriteDeadline(deadline); err != nil {
+ t.Fatalf("SetWriteDeadline(%v): %v", deadline, err)
+ }
+ var n int64
+ var err error
+ for {
+ var dn int
+ dn, err = c.Write([]byte("TIMEOUT TRANSMITTER"))
+ n += int64(dn)
+ if err != nil {
+ break
+ }
+ }
+ t1 := time.Now()
+ // Inv: err != nil
+ if !err.(Error).Timeout() {
+ t.Fatalf("Write did not return (any, timeout): (%d, %v)", n, err)
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if !isDeadlineExceeded(err) {
+ t.Errorf("Write error is not DeadlineExceeded: %v", err)
+ }
+
+ actual := t1.Sub(t0)
+ if t1.Before(deadline) {
+ t.Errorf("Write took %s; expected at least %s", actual, d)
+ }
+ if t.Failed() {
+ return
+ }
+ if want := timeoutUpperBound(d); actual > want {
+ if n > 0 {
+ // SetWriteDeadline specifies a time “after which I/O operations fail
+ // instead of blocking”. However, if the kernel's send buffer is not yet
+ // full, we may be able to write some arbitrary (but finite) number of
+ // bytes to it without blocking.
+ t.Logf("Wrote %d bytes into send buffer; retrying until buffer is full", n)
+ if d <= maxDynamicTimeout/2 {
+ // We don't know how long the actual write loop would have taken if
+ // the buffer were full, so just guess and double the duration so that
+ // the next attempt can make twice as much progress toward filling it.
+ d *= 2
+ }
+ } else if next, ok := nextTimeout(actual); !ok {
+ t.Fatalf("Write took %s; expected at most %s", actual, want)
+ } else {
+ // Maybe this machine is too slow to reliably schedule goroutines within
+ // the requested duration. Increase the timeout and try again.
+ t.Logf("Write took %s (expected %s); trying with longer timeout", actual, d)
+ d = next
+ }
+ continue
+ }
+
+ break
+ }
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestVariousDeadlines(t *testing.T) {
+ t.Parallel()
+ testVariousDeadlines(t)
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestVariousDeadlines1Proc(t *testing.T) {
+ // Cannot use t.Parallel - modifies global GOMAXPROCS.
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ testVariousDeadlines(t)
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestVariousDeadlines4Proc(t *testing.T) {
+ // Cannot use t.Parallel - modifies global GOMAXPROCS.
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+ testVariousDeadlines(t)
+}
+
+func testVariousDeadlines(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ for {
+ c, err := ln.Accept()
+ if err != nil {
+ break
+ }
+ c.Read(make([]byte, 1)) // wait for client to close connection
+ c.Close()
+ }
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ for _, timeout := range []time.Duration{
+ 1 * time.Nanosecond,
+ 2 * time.Nanosecond,
+ 5 * time.Nanosecond,
+ 50 * time.Nanosecond,
+ 100 * time.Nanosecond,
+ 200 * time.Nanosecond,
+ 500 * time.Nanosecond,
+ 750 * time.Nanosecond,
+ 1 * time.Microsecond,
+ 5 * time.Microsecond,
+ 25 * time.Microsecond,
+ 250 * time.Microsecond,
+ 500 * time.Microsecond,
+ 1 * time.Millisecond,
+ 5 * time.Millisecond,
+ 100 * time.Millisecond,
+ 250 * time.Millisecond,
+ 500 * time.Millisecond,
+ 1 * time.Second,
+ } {
+ numRuns := 3
+ if testing.Short() {
+ numRuns = 1
+ if timeout > 500*time.Microsecond {
+ continue
+ }
+ }
+ for run := 0; run < numRuns; run++ {
+ name := fmt.Sprintf("%v %d/%d", timeout, run, numRuns)
+ t.Log(name)
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t0 := time.Now()
+ if err := c.SetDeadline(t0.Add(timeout)); err != nil {
+ t.Error(err)
+ }
+ n, err := io.Copy(io.Discard, c)
+ dt := time.Since(t0)
+ c.Close()
+
+ if nerr, ok := err.(Error); ok && nerr.Timeout() {
+ t.Logf("%v: good timeout after %v; %d bytes", name, dt, n)
+ } else {
+ t.Fatalf("%v: Copy = %d, %v; want timeout", name, n, err)
+ }
+ }
+ }
+}
+
+// TestReadWriteProlongedTimeout tests concurrent deadline
+// modification. Known to cause data races in the past.
+func TestReadWriteProlongedTimeout(t *testing.T) {
+ t.Parallel()
+
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ var b [1]byte
+ for {
+ if err := c.SetReadDeadline(time.Now().Add(time.Hour)); err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ return
+ }
+ if _, err := c.Read(b[:]); err != nil {
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ return
+ }
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ var b [1]byte
+ for {
+ if err := c.SetWriteDeadline(time.Now().Add(time.Hour)); err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ return
+ }
+ if _, err := c.Write(b[:]); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ return
+ }
+ }
+ }()
+ wg.Wait()
+ }
+ ls := newLocalServer(t, "tcp")
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ var b [1]byte
+ for i := 0; i < 1000; i++ {
+ c.Write(b[:])
+ c.Read(b[:])
+ }
+}
+
+// There is a very similar copy of this in os/timeout_test.go.
+func TestReadWriteDeadlineRace(t *testing.T) {
+ t.Parallel()
+
+ N := 1000
+ if testing.Short() {
+ N = 50
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(3)
+ go func() {
+ defer wg.Done()
+ tic := time.NewTicker(2 * time.Microsecond)
+ defer tic.Stop()
+ for i := 0; i < N; i++ {
+ if err := c.SetReadDeadline(time.Now().Add(2 * time.Microsecond)); err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ break
+ }
+ if err := c.SetWriteDeadline(time.Now().Add(2 * time.Microsecond)); err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ break
+ }
+ <-tic.C
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ var b [1]byte
+ for i := 0; i < N; i++ {
+ c.Read(b[:]) // ignore possible timeout errors
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ var b [1]byte
+ for i := 0; i < N; i++ {
+ c.Write(b[:]) // ignore possible timeout errors
+ }
+ }()
+ wg.Wait() // wait for tester goroutine to stop
+}
+
+// Issue 35367.
+func TestConcurrentSetDeadline(t *testing.T) {
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ const goroutines = 8
+ const conns = 10
+ const tries = 100
+
+ var c [conns]Conn
+ for i := 0; i < conns; i++ {
+ var err error
+ c[i], err = Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c[i].Close()
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+ now := time.Now()
+ for i := 0; i < goroutines; i++ {
+ go func(i int) {
+ defer wg.Done()
+ // Make the deadlines steadily earlier,
+ // to trigger runtime adjusttimers calls.
+ for j := tries; j > 0; j-- {
+ for k := 0; k < conns; k++ {
+ c[k].SetReadDeadline(now.Add(2*time.Hour + time.Duration(i*j*k)*time.Second))
+ c[k].SetWriteDeadline(now.Add(1*time.Hour + time.Duration(i*j*k)*time.Second))
+ }
+ }
+ }(i)
+ }
+ wg.Wait()
+}
+
+// isDeadlineExceeded reports whether err is or wraps os.ErrDeadlineExceeded.
+// We also check that the error implements net.Error, and that the
+// Timeout method returns true.
+func isDeadlineExceeded(err error) bool {
+ nerr, ok := err.(Error)
+ if !ok {
+ return false
+ }
+ if !nerr.Timeout() {
+ return false
+ }
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ return false
+ }
+ return true
+}
diff --git a/src/net/udpsock.go b/src/net/udpsock.go
new file mode 100644
index 0000000..e30624d
--- /dev/null
+++ b/src/net/udpsock.go
@@ -0,0 +1,368 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "internal/itoa"
+ "net/netip"
+ "syscall"
+)
+
+// BUG(mikio): On Plan 9, the ReadMsgUDP and
+// WriteMsgUDP methods of UDPConn are not implemented.
+
+// BUG(mikio): On Windows, the File method of UDPConn is not
+// implemented.
+
+// BUG(mikio): On JS, methods and functions related to UDPConn are not
+// implemented.
+
+// UDPAddr represents the address of a UDP end point.
+type UDPAddr struct {
+ IP IP
+ Port int
+ Zone string // IPv6 scoped addressing zone
+}
+
+// AddrPort returns the UDPAddr a as a netip.AddrPort.
+//
+// If a.Port does not fit in a uint16, it's silently truncated.
+//
+// If a is nil, a zero value is returned.
+func (a *UDPAddr) AddrPort() netip.AddrPort {
+ if a == nil {
+ return netip.AddrPort{}
+ }
+ na, _ := netip.AddrFromSlice(a.IP)
+ na = na.WithZone(a.Zone)
+ return netip.AddrPortFrom(na, uint16(a.Port))
+}
+
+// Network returns the address's network name, "udp".
+func (a *UDPAddr) Network() string { return "udp" }
+
+func (a *UDPAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ ip := ipEmptyString(a.IP)
+ if a.Zone != "" {
+ return JoinHostPort(ip+"%"+a.Zone, itoa.Itoa(a.Port))
+ }
+ return JoinHostPort(ip, itoa.Itoa(a.Port))
+}
+
+func (a *UDPAddr) isWildcard() bool {
+ if a == nil || a.IP == nil {
+ return true
+ }
+ return a.IP.IsUnspecified()
+}
+
+func (a *UDPAddr) opAddr() Addr {
+ if a == nil {
+ return nil
+ }
+ return a
+}
+
+// ResolveUDPAddr returns an address of a UDP end point.
+//
+// The network must be a UDP network name.
+//
+// If the host in the address parameter is not a literal IP address or
+// the port is not a literal port number, ResolveUDPAddr resolves the
+// address to an address of a UDP end point.
+// Otherwise, it parses the address as a pair of literal IP address
+// and port number.
+// The address parameter can use a host name, but this is not
+// recommended, because it will return at most one of the host name's
+// IP addresses.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func ResolveUDPAddr(network, address string) (*UDPAddr, error) {
+ switch network {
+ case "udp", "udp4", "udp6":
+ case "": // a hint wildcard for Go 1.0 undocumented behavior
+ network = "udp"
+ default:
+ return nil, UnknownNetworkError(network)
+ }
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address)
+ if err != nil {
+ return nil, err
+ }
+ return addrs.forResolve(network, address).(*UDPAddr), nil
+}
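+
+// resolveUDPAddrSketch is an illustrative helper (hypothetical, not part of
+// the exported API): a minimal sketch of typical ResolveUDPAddr usage with a
+// literal host:port string. The address below is only a placeholder.
+func resolveUDPAddrSketch() (*UDPAddr, error) {
+ addr, err := ResolveUDPAddr("udp", "127.0.0.1:53")
+ if err != nil {
+ return nil, err
+ }
+ // addr.IP is 127.0.0.1 and addr.Port is 53.
+ return addr, nil
+}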
+
+// UDPAddrFromAddrPort returns addr as a UDPAddr. If addr.IsValid() is false,
+// then the returned UDPAddr will contain a nil IP field, indicating an
+// address family-agnostic unspecified address.
+func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr {
+ return &UDPAddr{
+ IP: addr.Addr().AsSlice(),
+ Zone: addr.Addr().Zone(),
+ Port: int(addr.Port()),
+ }
+}
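+
+// addrPortRoundTripSketch is an illustrative helper (hypothetical): it shows
+// the round trip between the two address representations, UDPAddr.AddrPort
+// and UDPAddrFromAddrPort. A Port outside the uint16 range would be silently
+// truncated by AddrPort, as documented above.
+func addrPortRoundTripSketch(a *UDPAddr) *UDPAddr {
+ ap := a.AddrPort() // netip.AddrPort view of a
+ return UDPAddrFromAddrPort(ap) // back to *UDPAddr, with the IP copied into a slice
+}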
+
+// An addrPortUDPAddr is a netip.AddrPort-based UDP address that satisfies the Addr interface.
+type addrPortUDPAddr struct {
+ netip.AddrPort
+}
+
+func (addrPortUDPAddr) Network() string { return "udp" }
+
+// UDPConn is the implementation of the Conn and PacketConn interfaces
+// for UDP network connections.
+type UDPConn struct {
+ conn
+}
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+func (c *UDPConn) SyscallConn() (syscall.RawConn, error) {
+ if !c.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawConn(c.fd)
+}
+
+// ReadFromUDP acts like ReadFrom but returns a UDPAddr.
+func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) {
+ // This function is designed to allow the caller to control the lifetime
+ // of the returned *UDPAddr and thereby prevent an allocation.
+ // See https://blog.filippo.io/efficient-go-apis-with-the-inliner/.
+ // The real work is done by readFromUDP, below.
+ return c.readFromUDP(b, &UDPAddr{})
+}
+
+// readFromUDP implements ReadFromUDP.
+func (c *UDPConn) readFromUDP(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
+ if !c.ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ n, addr, err := c.readFrom(b, addr)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, addr, err
+}
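+
+// readFromUDPLoopSketch is an illustrative helper (hypothetical): a minimal
+// receive loop built on ReadFromUDP. The buffer size and the handle callback
+// are placeholders; the loop stops on the first read error.
+func readFromUDPLoopSketch(c *UDPConn, handle func([]byte, *UDPAddr)) error {
+ buf := make([]byte, 1500)
+ for {
+ n, addr, err := c.ReadFromUDP(buf)
+ if err != nil {
+ return err
+ }
+ handle(buf[:n], addr)
+ }
+}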
+
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) {
+ n, addr, err := c.readFromUDP(b, &UDPAddr{})
+ if addr == nil {
+ // Return Addr(nil), not Addr(*UDPAddr(nil)).
+ return n, nil, err
+ }
+ return n, addr, err
+}
+
+// ReadFromUDPAddrPort acts like ReadFrom but returns a netip.AddrPort.
+//
+// If c is bound to an unspecified address, the returned
+// netip.AddrPort's address might be an IPv4-mapped IPv6 address.
+// Use netip.Addr.Unmap to get the address without the IPv6 prefix.
+func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
+ if !c.ok() {
+ return 0, netip.AddrPort{}, syscall.EINVAL
+ }
+ n, addr, err = c.readFromAddrPort(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, addr, err
+}
+
+// ReadMsgUDP reads a message from c, copying the payload into b and
+// the associated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the message and the source address of the message.
+//
+// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
+// used to manipulate IP-level socket options in oob.
+func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) {
+ var ap netip.AddrPort
+ n, oobn, flags, ap, err = c.ReadMsgUDPAddrPort(b, oob)
+ if ap.IsValid() {
+ addr = UDPAddrFromAddrPort(ap)
+ }
+ return
+}
+
+// ReadMsgUDPAddrPort is like ReadMsgUDP but returns a netip.AddrPort instead of a UDPAddr.
+func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
+ if !c.ok() {
+ return 0, 0, 0, netip.AddrPort{}, syscall.EINVAL
+ }
+ n, oobn, flags, addr, err = c.readMsg(b, oob)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return
+}
+
+// WriteToUDP acts like WriteTo but takes a UDPAddr.
+func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeTo(b, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteToUDPAddrPort acts like WriteTo but takes a netip.AddrPort.
+func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeToAddrPort(b, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
+ }
+ return n, err
+}
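+
+// echoOnceAddrPortSketch is an illustrative helper (hypothetical): it pairs
+// ReadFromUDPAddrPort with WriteToUDPAddrPort, the netip.AddrPort-based calls
+// that avoid allocating a *UDPAddr per packet, to echo one datagram back to
+// its sender.
+func echoOnceAddrPortSketch(c *UDPConn, buf []byte) error {
+ n, ap, err := c.ReadFromUDPAddrPort(buf)
+ if err != nil {
+ return err
+ }
+ _, err = c.WriteToUDPAddrPort(buf[:n], ap)
+ return err
+}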
+
+// WriteTo implements the PacketConn WriteTo method.
+func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ a, ok := addr.(*UDPAddr)
+ if !ok {
+ return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
+ }
+ n, err := c.writeTo(b, a)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteMsgUDP writes a message to addr via c if c isn't connected, or
+// to c's remote address if c is connected (in which case addr must be
+// nil). The payload is copied from b and the associated out-of-band
+// data is copied from oob. It returns the number of payload and
+// out-of-band bytes written.
+//
+// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be
+// used to manipulate IP-level socket options in oob.
+func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ n, oobn, err = c.writeMsg(b, oob, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return
+}
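+
+// writeMsgModeSketch is an illustrative helper (hypothetical): it restates the
+// connected/unconnected rule above in code. A connected UDPConn must pass a
+// nil addr; an unconnected one must supply an explicit destination.
+func writeMsgModeSketch(c *UDPConn, connected bool, b []byte, dst *UDPAddr) (int, int, error) {
+ if connected {
+ return c.WriteMsgUDP(b, nil, nil) // destination is the connected peer
+ }
+ return c.WriteMsgUDP(b, nil, dst) // unconnected: destination required
+}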
+
+// WriteMsgUDPAddrPort is like WriteMsgUDP but takes a netip.AddrPort instead of a UDPAddr.
+func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ n, oobn, err = c.writeMsgAddrPort(b, oob, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err}
+ }
+ return
+}
+
+func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} }
+
+// DialUDP acts like Dial for UDP networks.
+//
+// The network must be a UDP network name; see func Dial for details.
+//
+// If laddr is nil, a local address is automatically chosen.
+// If the IP field of raddr is nil or an unspecified IP address, the
+// local system is assumed.
+func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ switch network {
+ case "udp", "udp4", "udp6":
+ default:
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if raddr == nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
+ }
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialUDP(context.Background(), laddr, raddr)
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
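+
+// dialUDPSketch is an illustrative helper (hypothetical): it dials a UDP peer
+// and sends a single datagram. The destination address is only a placeholder.
+func dialUDPSketch() error {
+ raddr := &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 9}
+ c, err := DialUDP("udp", nil, raddr)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+ _, err = c.Write([]byte("ping"))
+ return err
+}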
+
+// ListenUDP acts like ListenPacket for UDP networks.
+//
+// The network must be a UDP network name; see func Dial for details.
+//
+// If the IP field of laddr is nil or an unspecified IP address,
+// ListenUDP listens on all available IP addresses of the local system
+// except multicast IP addresses.
+// If the Port field of laddr is 0, a port number is automatically
+// chosen.
+func ListenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
+ switch network {
+ case "udp", "udp4", "udp6":
+ default:
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if laddr == nil {
+ laddr = &UDPAddr{}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenUDP(context.Background(), laddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
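+
+// listenUDPSketch is an illustrative helper (hypothetical): it binds an
+// unconnected socket on an automatically chosen port, as described above, and
+// replies to a single request.
+func listenUDPSketch() error {
+ c, err := ListenUDP("udp", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0})
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+ buf := make([]byte, 1500)
+ n, addr, err := c.ReadFromUDP(buf)
+ if err != nil {
+ return err
+ }
+ _, err = c.WriteToUDP(buf[:n], addr)
+ return err
+}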
+
+// ListenMulticastUDP acts like ListenPacket for UDP networks but
+// takes a group address on a specific network interface.
+//
+// The network must be a UDP network name; see func Dial for details.
+//
+// ListenMulticastUDP listens on all available IP addresses of the
+// local system, including the group (multicast) IP address.
+// If ifi is nil, ListenMulticastUDP uses the system-assigned
+// multicast interface, although this is not recommended because the
+// assignment depends on platforms and sometimes it might require
+// routing configuration.
+// If the Port field of gaddr is 0, a port number is automatically
+// chosen.
+//
+// ListenMulticastUDP is provided only as a convenience for simple, small
+// applications. For general-purpose use, prefer the golang.org/x/net/ipv4
+// and golang.org/x/net/ipv6 packages.
+//
+// Note that ListenMulticastUDP will set the IP_MULTICAST_LOOP socket option
+// to 0 under IPPROTO_IP, to disable loopback of multicast packets.
+func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+ switch network {
+ case "udp", "udp4", "udp6":
+ default:
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if gaddr == nil || gaddr.IP == nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
+ }
+ sl := &sysListener{network: network, address: gaddr.String()}
+ c, err := sl.listenMulticastUDP(context.Background(), ifi, gaddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
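+
+// listenMulticastSketch is an illustrative helper (hypothetical): it joins an
+// administratively scoped IPv4 group on the given interface and receives one
+// datagram. The group address and port are placeholders.
+func listenMulticastSketch(ifi *Interface) error {
+ gaddr := &UDPAddr{IP: IPv4(239, 0, 0, 1), Port: 9999}
+ c, err := ListenMulticastUDP("udp4", ifi, gaddr)
+ if err != nil {
+ return err
+ }
+ defer c.Close()
+ buf := make([]byte, 1500)
+ _, _, err = c.ReadFromUDP(buf)
+ return err
+}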
diff --git a/src/net/udpsock_plan9.go b/src/net/udpsock_plan9.go
new file mode 100644
index 0000000..732a3b0
--- /dev/null
+++ b/src/net/udpsock_plan9.go
@@ -0,0 +1,182 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "errors"
+ "net/netip"
+ "os"
+ "syscall"
+)
+
+func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
+ buf := make([]byte, udpHeaderSize+len(b))
+ m, err := c.fd.Read(buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ if m < udpHeaderSize {
+ return 0, nil, errors.New("short read reading UDP header")
+ }
+ buf = buf[:m]
+
+ h, buf := unmarshalUDPHeader(buf)
+ n := copy(b, buf)
+ *addr = UDPAddr{IP: h.raddr, Port: int(h.rport)}
+ return n, addr, nil
+}
+
+func (c *UDPConn) readFromAddrPort(b []byte) (int, netip.AddrPort, error) {
+ // TODO: optimize. The equivalent code on posix is alloc-free.
+ buf := make([]byte, udpHeaderSize+len(b))
+ m, err := c.fd.Read(buf)
+ if err != nil {
+ return 0, netip.AddrPort{}, err
+ }
+ if m < udpHeaderSize {
+ return 0, netip.AddrPort{}, errors.New("short read reading UDP header")
+ }
+ buf = buf[:m]
+
+ h, buf := unmarshalUDPHeader(buf)
+ n := copy(b, buf)
+ ip, _ := netip.AddrFromSlice(h.raddr)
+ addr := netip.AddrPortFrom(ip, h.rport)
+ return n, addr, nil
+}
+
+func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
+ return 0, 0, 0, netip.AddrPort{}, syscall.EPLAN9
+}
+
+func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
+ if addr == nil {
+ return 0, errMissingAddress
+ }
+ h := new(udpHeader)
+ h.raddr = addr.IP.To16()
+ h.laddr = c.fd.laddr.(*UDPAddr).IP.To16()
+ h.ifcaddr = IPv6zero // ignored (receive only)
+ h.rport = uint16(addr.Port)
+ h.lport = uint16(c.fd.laddr.(*UDPAddr).Port)
+
+ buf := make([]byte, udpHeaderSize+len(b))
+ i := copy(buf, h.Bytes())
+ copy(buf[i:], b)
+ if _, err := c.fd.Write(buf); err != nil {
+ return 0, err
+ }
+ return len(b), nil
+}
+
+func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ return c.writeTo(b, UDPAddrFromAddrPort(addr)) // TODO: optimize instead of allocating
+}
+
+func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
+func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
+func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
+ if err != nil {
+ return nil, err
+ }
+ return newUDPConn(fd), nil
+}
+
+const udpHeaderSize = 16*3 + 2*2
+
+type udpHeader struct {
+ raddr, laddr, ifcaddr IP
+ rport, lport uint16
+}
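+
+// The wire layout packed by Bytes and parsed by unmarshalUDPHeader is, per
+// datagram, udpHeaderSize (52) bytes:
+//
+//	bytes  0..15  raddr   (16-byte IP)
+//	bytes 16..31  laddr   (16-byte IP)
+//	bytes 32..47  ifcaddr (16-byte IP)
+//	bytes 48..49  rport   (big endian)
+//	bytes 50..51  lport   (big endian)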
+
+func (h *udpHeader) Bytes() []byte {
+ b := make([]byte, udpHeaderSize)
+ i := 0
+ i += copy(b[i:i+16], h.raddr)
+ i += copy(b[i:i+16], h.laddr)
+ i += copy(b[i:i+16], h.ifcaddr)
+ b[i], b[i+1], i = byte(h.rport>>8), byte(h.rport), i+2
+ b[i], b[i+1], i = byte(h.lport>>8), byte(h.lport), i+2
+ return b
+}
+
+func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
+ h := new(udpHeader)
+ h.raddr, b = IP(b[:16]), b[16:]
+ h.laddr, b = IP(b[:16]), b[16:]
+ h.ifcaddr, b = IP(b[:16]), b[16:]
+ h.rport, b = uint16(b[0])<<8|uint16(b[1]), b[2:]
+ h.lport, b = uint16(b[0])<<8|uint16(b[1]), b[2:]
+ return h, b
+}
+
+func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
+ l, err := listenPlan9(ctx, sl.network, laddr)
+ if err != nil {
+ return nil, err
+ }
+ _, err = l.ctl.WriteString("headers")
+ if err != nil {
+ return nil, err
+ }
+ l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := l.netFD()
+ return newUDPConn(fd), err
+}
+
+func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+ // Plan 9 does not like announce command with a multicast address,
+ // so do not specify an IP address when listening.
+ l, err := listenPlan9(ctx, sl.network, &UDPAddr{IP: nil, Port: gaddr.Port, Zone: gaddr.Zone})
+ if err != nil {
+ return nil, err
+ }
+ _, err = l.ctl.WriteString("headers")
+ if err != nil {
+ return nil, err
+ }
+ var addrs []Addr
+ if ifi != nil {
+ addrs, err = ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ addrs, err = InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ have4 := gaddr.IP.To4() != nil
+ for _, addr := range addrs {
+ if ipnet, ok := addr.(*IPNet); ok && (ipnet.IP.To4() != nil) == have4 {
+ _, err = l.ctl.WriteString("addmulti " + ipnet.IP.String() + " " + gaddr.IP.String())
+ if err != nil {
+ return nil, &OpError{Op: "addmulti", Net: "", Source: nil, Addr: ipnet, Err: err}
+ }
+ }
+ }
+ l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := l.netFD()
+ if err != nil {
+ return nil, err
+ }
+ return newUDPConn(fd), nil
+}
diff --git a/src/net/udpsock_plan9_test.go b/src/net/udpsock_plan9_test.go
new file mode 100644
index 0000000..3febfcc
--- /dev/null
+++ b/src/net/udpsock_plan9_test.go
@@ -0,0 +1,69 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/testenv"
+ "runtime"
+ "testing"
+)
+
+func TestListenMulticastUDP(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ ifcs, err := Interfaces()
+ if err != nil {
+ t.Skip(err.Error())
+ }
+ if len(ifcs) == 0 {
+ t.Skip("no network interfaces found")
+ }
+
+ var mifc *Interface
+ for _, ifc := range ifcs {
+ if ifc.Flags&(FlagUp|FlagMulticast) != FlagUp|FlagMulticast {
+ continue
+ }
+ mifc = &ifc
+ break
+ }
+
+ if mifc == nil {
+ t.Skipf("no multicast interfaces found")
+ }
+
+ c1, err := ListenMulticastUDP("udp4", mifc, &UDPAddr{IP: ParseIP("224.0.0.254")})
+ if err != nil {
+ t.Fatalf("multicast not working on %s: %v", runtime.GOOS, err)
+ }
+ defer c1.Close()
+ c1addr := c1.LocalAddr().(*UDPAddr)
+
+ c2, err := ListenUDP("udp4", &UDPAddr{IP: IPv4zero, Port: 0})
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ c2addr := c2.LocalAddr().(*UDPAddr)
+
+ n, err := c2.WriteToUDP([]byte("data"), c1addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 4 {
+ t.Fatalf("got %d; want 4", n)
+ }
+
+ n, err = c1.WriteToUDP([]byte("data"), c2addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 4 {
+ t.Fatalf("got %d; want 4", n)
+ }
+}
diff --git a/src/net/udpsock_posix.go b/src/net/udpsock_posix.go
new file mode 100644
index 0000000..f3dbcfe
--- /dev/null
+++ b/src/net/udpsock_posix.go
@@ -0,0 +1,287 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "context"
+ "net/netip"
+ "syscall"
+)
+
+func sockaddrToUDP(sa syscall.Sockaddr) Addr {
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
+ case *syscall.SockaddrInet6:
+ return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
+ }
+ return nil
+}
+
+func (a *UDPAddr) family() int {
+ if a == nil || len(a.IP) <= IPv4len {
+ return syscall.AF_INET
+ }
+ if a.IP.To4() != nil {
+ return syscall.AF_INET
+ }
+ return syscall.AF_INET6
+}
+
+func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
+ if a == nil {
+ return nil, nil
+ }
+ return ipToSockaddr(family, a.IP, a.Port, a.Zone)
+}
+
+func (a *UDPAddr) toLocal(net string) sockaddr {
+ return &UDPAddr{loopbackIP(net), a.Port, a.Zone}
+}
+
+func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
+ var n int
+ var err error
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var from syscall.SockaddrInet4
+ n, err = c.fd.readFromInet4(b, &from)
+ if err == nil {
+ ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes
+ *addr = UDPAddr{IP: ip[:], Port: from.Port}
+ }
+ case syscall.AF_INET6:
+ var from syscall.SockaddrInet6
+ n, err = c.fd.readFromInet6(b, &from)
+ if err == nil {
+ ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
+ *addr = UDPAddr{IP: ip[:], Port: from.Port, Zone: zoneCache.name(int(from.ZoneId))}
+ }
+ }
+ if err != nil {
+ // No sockaddr, so don't return UDPAddr.
+ addr = nil
+ }
+ return n, addr, err
+}
+
+func (c *UDPConn) readFromAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
+ var ip netip.Addr
+ var port int
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var from syscall.SockaddrInet4
+ n, err = c.fd.readFromInet4(b, &from)
+ if err == nil {
+ ip = netip.AddrFrom4(from.Addr)
+ port = from.Port
+ }
+ case syscall.AF_INET6:
+ var from syscall.SockaddrInet6
+ n, err = c.fd.readFromInet6(b, &from)
+ if err == nil {
+ ip = netip.AddrFrom16(from.Addr).WithZone(zoneCache.name(int(from.ZoneId)))
+ port = from.Port
+ }
+ }
+ if err == nil {
+ addr = netip.AddrPortFrom(ip, uint16(port))
+ }
+ return n, addr, err
+}
+
+func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
+ switch c.fd.family {
+ case syscall.AF_INET:
+ var sa syscall.SockaddrInet4
+ n, oobn, flags, err = c.fd.readMsgInet4(b, oob, 0, &sa)
+ ip := netip.AddrFrom4(sa.Addr)
+ addr = netip.AddrPortFrom(ip, uint16(sa.Port))
+ case syscall.AF_INET6:
+ var sa syscall.SockaddrInet6
+ n, oobn, flags, err = c.fd.readMsgInet6(b, oob, 0, &sa)
+ ip := netip.AddrFrom16(sa.Addr).WithZone(zoneCache.name(int(sa.ZoneId)))
+ addr = netip.AddrPortFrom(ip, uint16(sa.Port))
+ }
+ return
+}
+
+func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
+ if c.fd.isConnected {
+ return 0, ErrWriteToConnected
+ }
+ if addr == nil {
+ return 0, errMissingAddress
+ }
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := ipToSockaddrInet4(addr.IP, addr.Port)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet4(b, &sa)
+ case syscall.AF_INET6:
+ sa, err := ipToSockaddrInet6(addr.IP, addr.Port, addr.Zone)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet6(b, &sa)
+ default:
+ return 0, &AddrError{Err: "invalid address family", Addr: addr.IP.String()}
+ }
+}
+
+func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
+ if c.fd.isConnected {
+ return 0, ErrWriteToConnected
+ }
+ if !addr.IsValid() {
+ return 0, errMissingAddress
+ }
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := addrPortToSockaddrInet4(addr)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet4(b, &sa)
+ case syscall.AF_INET6:
+ sa, err := addrPortToSockaddrInet6(addr)
+ if err != nil {
+ return 0, err
+ }
+ return c.fd.writeToInet6(b, &sa)
+ default:
+ return 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
+ }
+}
+
+func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
+ if c.fd.isConnected && addr != nil {
+ return 0, 0, ErrWriteToConnected
+ }
+ if !c.fd.isConnected && addr == nil {
+ return 0, 0, errMissingAddress
+ }
+ sa, err := addr.sockaddr(c.fd.family)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsg(b, oob, sa)
+}
+
+func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
+ if c.fd.isConnected && addr.IsValid() {
+ return 0, 0, ErrWriteToConnected
+ }
+ if !c.fd.isConnected && !addr.IsValid() {
+ return 0, 0, errMissingAddress
+ }
+
+ switch c.fd.family {
+ case syscall.AF_INET:
+ sa, err := addrPortToSockaddrInet4(addr)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsgInet4(b, oob, &sa)
+ case syscall.AF_INET6:
+ sa, err := addrPortToSockaddrInet6(addr)
+ if err != nil {
+ return 0, 0, err
+ }
+ return c.fd.writeMsgInet6(b, oob, &sa)
+ default:
+ return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
+ }
+}
+
+func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newUDPConn(fd), nil
+}
+
+func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newUDPConn(fd), nil
+}
+
+func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ c := newUDPConn(fd)
+ if ip4 := gaddr.IP.To4(); ip4 != nil {
+ if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil {
+ c.Close()
+ return nil, err
+ }
+ } else {
+ if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil {
+ c.Close()
+ return nil, err
+ }
+ }
+ return c, nil
+}
+
+func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ if ifi != nil {
+ if err := setIPv4MulticastInterface(c.fd, ifi); err != nil {
+ return err
+ }
+ }
+ if err := setIPv4MulticastLoopback(c.fd, false); err != nil {
+ return err
+ }
+ if err := joinIPv4Group(c.fd, ifi, ip); err != nil {
+ return err
+ }
+ return nil
+}
+
+func listenIPv6MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
+ if ifi != nil {
+ if err := setIPv6MulticastInterface(c.fd, ifi); err != nil {
+ return err
+ }
+ }
+ if err := setIPv6MulticastLoopback(c.fd, false); err != nil {
+ return err
+ }
+ if err := joinIPv6Group(c.fd, ifi, ip); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/src/net/udpsock_test.go b/src/net/udpsock_test.go
new file mode 100644
index 0000000..2afd4ac
--- /dev/null
+++ b/src/net/udpsock_test.go
@@ -0,0 +1,666 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "errors"
+ "internal/testenv"
+ "net/netip"
+ "os"
+ "reflect"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func BenchmarkUDP6LinkLocalUnicast(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ if !supportsIPv6() {
+ b.Skip("IPv6 is not supported")
+ }
+ ifi := loopbackInterface()
+ if ifi == nil {
+ b.Skip("loopback interface not found")
+ }
+ lla := ipv6LinkLocalUnicastAddr(ifi)
+ if lla == "" {
+ b.Skip("IPv6 link-local unicast address not found")
+ }
+
+ c1, err := ListenPacket("udp6", JoinHostPort(lla+"%"+ifi.Name, "0"))
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer c1.Close()
+ c2, err := ListenPacket("udp6", JoinHostPort(lla+"%"+ifi.Name, "0"))
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer c2.Close()
+
+ var buf [1]byte
+ for i := 0; i < b.N; i++ {
+ if _, err := c1.WriteTo(buf[:], c2.LocalAddr()); err != nil {
+ b.Fatal(err)
+ }
+ if _, _, err := c2.ReadFrom(buf[:]); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+type resolveUDPAddrTest struct {
+ network string
+ litAddrOrName string
+ addr *UDPAddr
+ err error
+}
+
+var resolveUDPAddrTests = []resolveUDPAddrTest{
+ {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil},
+ {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil},
+
+ {"udp", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil},
+ {"udp6", "[::1]:65535", &UDPAddr{IP: ParseIP("::1"), Port: 65535}, nil},
+
+ {"udp", "[::1%en0]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil},
+ {"udp6", "[::1%911]:2", &UDPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil},
+
+ {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior
+ {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior
+
+ {"udp", ":12345", &UDPAddr{Port: 12345}, nil},
+
+ {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")},
+
+ {"udp", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 53}, nil},
+ {"udp", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil},
+ {"udp4", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp4", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp6", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil},
+
+ {"udp4", "[2001:db8::1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"udp6", "127.0.0.1:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"udp6", "[::ffff:127.0.0.1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
+}
+
+func TestResolveUDPAddr(t *testing.T) {
+ origTestHookLookupIP := testHookLookupIP
+ defer func() { testHookLookupIP = origTestHookLookupIP }()
+ testHookLookupIP = lookupLocalhost
+
+ for _, tt := range resolveUDPAddrTests {
+ addr, err := ResolveUDPAddr(tt.network, tt.litAddrOrName)
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
+ continue
+ }
+ if err == nil {
+ addr2, err := ResolveUDPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
+ }
+ }
+}
+
+func TestWriteToUDP(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ c, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ testWriteToConn(t, c.LocalAddr().String())
+ testWriteToPacketConn(t, c.LocalAddr().String())
+}
+
+func testWriteToConn(t *testing.T, raddr string) {
+ c, err := Dial("udp", raddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ ra, err := ResolveUDPAddr("udp", raddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ b := []byte("CONNECTED-MODE SOCKET")
+ _, err = c.(*UDPConn).WriteToUDP(b, ra)
+ if err == nil {
+ t.Fatal("should fail")
+ }
+ if err != nil && err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ _, err = c.(*UDPConn).WriteTo(b, ra)
+ if err == nil {
+ t.Fatal("should fail")
+ }
+ if err != nil && err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ _, err = c.Write(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
+ if err == nil {
+ t.Fatal("should fail")
+ }
+ if err != nil && err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func testWriteToPacketConn(t *testing.T, raddr string) {
+ c, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ ra, err := ResolveUDPAddr("udp", raddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ b := []byte("UNCONNECTED-MODE SOCKET")
+ _, err = c.(*UDPConn).WriteToUDP(b, ra)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.WriteTo(b, ra)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.(*UDPConn).Write(b)
+ if err == nil {
+ t.Fatal("should fail")
+ }
+ _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
+ if err == nil {
+ t.Fatal("should fail")
+ }
+ if err != nil && err.(*OpError).Err != errMissingAddress {
+ t.Fatalf("should fail as errMissingAddress: %v", err)
+ }
+ _, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+var udpConnLocalNameTests = []struct {
+ net string
+ laddr *UDPAddr
+}{
+ {"udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}},
+ {"udp4", &UDPAddr{}},
+ {"udp4", nil},
+}
+
+func TestUDPConnLocalName(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ for _, tt := range udpConnLocalNameTests {
+ c, err := ListenUDP(tt.net, tt.laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ la := c.LocalAddr()
+ if a, ok := la.(*UDPAddr); !ok || a.Port == 0 {
+ t.Fatalf("got %v; expected a proper address with non-zero port number", la)
+ }
+ }
+}
+
+func TestUDPConnLocalAndRemoteNames(t *testing.T) {
+ for _, laddr := range []string{"", "127.0.0.1:0"} {
+ c1, err := ListenPacket("udp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+
+ var la *UDPAddr
+ if laddr != "" {
+ var err error
+ if la, err = ResolveUDPAddr("udp", laddr); err != nil {
+ t.Fatal(err)
+ }
+ }
+ c2, err := DialUDP("udp", la, c1.LocalAddr().(*UDPAddr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ var connAddrs = [4]struct {
+ got Addr
+ ok bool
+ }{
+ {c1.LocalAddr(), true},
+ {c1.(*UDPConn).RemoteAddr(), false},
+ {c2.LocalAddr(), true},
+ {c2.RemoteAddr(), true},
+ }
+ for _, ca := range connAddrs {
+ if a, ok := ca.got.(*UDPAddr); ok != ca.ok || ok && a.Port == 0 {
+ t.Fatalf("got %v; expected a proper address with non-zero port number", ca.got)
+ }
+ }
+ }
+}
+
+func TestIPv6LinkLocalUnicastUDP(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ if !supportsIPv6() {
+ t.Skip("IPv6 is not supported")
+ }
+
+ for i, tt := range ipv6LinkLocalUnicastUDPTests {
+ c1, err := ListenPacket(tt.network, tt.address)
+ if err != nil {
+ // It might return a "LookupHost returned no
+ // suitable address" error on some platforms.
+ t.Log(err)
+ continue
+ }
+ ls := (&packetListener{PacketConn: c1}).newLocalServer()
+ defer ls.teardown()
+ ch := make(chan error, 1)
+ handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, ch) }
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ if la, ok := c1.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", la)
+ }
+
+ c2, err := Dial(tt.network, ls.PacketConn.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+ if la, ok := c2.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", la)
+ }
+ if ra, ok := c2.RemoteAddr().(*UDPAddr); !ok || !tt.nameLookup && ra.Zone == "" {
+ t.Fatalf("got %v; expected a proper address with zone identifier", ra)
+ }
+
+ if _, err := c2.Write([]byte("UDP OVER IPV6 LINKLOCAL TEST")); err != nil {
+ t.Fatal(err)
+ }
+ b := make([]byte, 32)
+ if _, err := c2.Read(b); err != nil {
+ t.Fatal(err)
+ }
+
+ for err := range ch {
+ t.Errorf("#%d: %v", i, err)
+ }
+ }
+}
+
+func TestUDPZeroBytePayload(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ case "darwin", "ios":
+ testenv.SkipFlaky(t, 29225)
+ }
+
+ c := newLocalPacketListener(t, "udp")
+ defer c.Close()
+
+ for _, genericRead := range []bool{false, true} {
+ n, err := c.WriteTo(nil, c.LocalAddr())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 0 {
+ t.Errorf("got %d; want 0", n)
+ }
+ c.SetReadDeadline(time.Now().Add(30 * time.Second))
+ var b [1]byte
+ var name string
+ if genericRead {
+ _, err = c.(Conn).Read(b[:])
+ name = "Read"
+ } else {
+ _, _, err = c.ReadFrom(b[:])
+ name = "ReadFrom"
+ }
+ if err != nil {
+ t.Errorf("%s of zero byte packet failed: %v", name, err)
+ }
+ }
+}
+
+func TestUDPZeroByteBuffer(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ c := newLocalPacketListener(t, "udp")
+ defer c.Close()
+
+ b := []byte("UDP ZERO BYTE BUFFER TEST")
+ for _, genericRead := range []bool{false, true} {
+ n, err := c.WriteTo(b, c.LocalAddr())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != len(b) {
+ t.Errorf("got %d; want %d", n, len(b))
+ }
+ c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ if genericRead {
+ _, err = c.(Conn).Read(nil)
+ } else {
+ _, _, err = c.ReadFrom(nil)
+ }
+ switch err {
+ case nil: // ReadFrom succeeds
+ default: // Read may time out; it depends on the platform
+ if nerr, ok := err.(Error); (!ok || !nerr.Timeout()) && runtime.GOOS != "windows" { // Windows returns WSAEMSGSIZE
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func TestUDPReadSizeError(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ c1 := newLocalPacketListener(t, "udp")
+ defer c1.Close()
+
+ c2, err := Dial("udp", c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ b1 := []byte("READ SIZE ERROR TEST")
+ for _, genericRead := range []bool{false, true} {
+ n, err := c2.Write(b1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != len(b1) {
+ t.Errorf("got %d; want %d", n, len(b1))
+ }
+ b2 := make([]byte, len(b1)-1)
+ if genericRead {
+ n, err = c1.(Conn).Read(b2)
+ } else {
+ n, _, err = c1.ReadFrom(b2)
+ }
+ if err != nil && runtime.GOOS != "windows" { // Windows returns WSAEMSGSIZE
+ t.Fatal(err)
+ }
+ if n != len(b1)-1 {
+ t.Fatalf("got %d; want %d", n, len(b1)-1)
+ }
+ }
+}
+
+// TestUDPReadTimeout verifies that ReadFromUDP with timeout returns an error
+// without data or an address.
+func TestUDPReadTimeout(t *testing.T) {
+ la, err := ResolveUDPAddr("udp4", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenUDP("udp4", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ c.SetDeadline(time.Now())
+ b := make([]byte, 1)
+ n, addr, err := c.ReadFromUDP(b)
+ if !errors.Is(err, os.ErrDeadlineExceeded) {
+ t.Errorf("ReadFromUDP got err %v want os.ErrDeadlineExceeded", err)
+ }
+ if n != 0 {
+ t.Errorf("ReadFromUDP got n %d want 0", n)
+ }
+ if addr != nil {
+ t.Errorf("ReadFromUDP got addr %+#v want nil", addr)
+ }
+}
+
+func TestAllocs(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ // Plan9 wasn't optimized.
+ t.Skipf("skipping on %v", runtime.GOOS)
+ }
+ // Optimizations are required to remove the allocs.
+ testenv.SkipIfOptimizationOff(t)
+
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr()
+ addrPort := addr.(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+
+ allocs := testing.AllocsPerRun(1000, func() {
+ _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addrPort)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 0 {
+ t.Errorf("WriteMsgUDPAddrPort/ReadMsgUDPAddrPort allocated %d objects", got)
+ }
+
+ allocs = testing.AllocsPerRun(1000, func() {
+ _, err := conn.WriteToUDPAddrPort(buf, addrPort)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDPAddrPort(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 0 {
+ t.Errorf("WriteToUDPAddrPort/ReadFromUDPAddrPort allocated %d objects", got)
+ }
+
+ allocs = testing.AllocsPerRun(1000, func() {
+ _, err := conn.WriteTo(buf, addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDP(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ if got := int(allocs); got != 1 {
+ t.Errorf("WriteTo/ReadFromUDP allocated %d objects", got)
+ }
+}
+
+func BenchmarkReadWriteMsgUDPAddrPort(b *testing.B) {
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkWriteToReadFromUDP(b *testing.B) {
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr()
+ buf := make([]byte, 8)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ _, err := conn.WriteTo(buf, addr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDP(buf)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkWriteToReadFromUDPAddrPort(b *testing.B) {
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conn.Close()
+ addr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ b.ResetTimer()
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ _, err := conn.WriteToUDPAddrPort(buf, addr)
+ if err != nil {
+ b.Fatal(err)
+ }
+ _, _, err = conn.ReadFromUDPAddrPort(buf)
+ if err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func TestUDPIPVersionReadMsg(t *testing.T) {
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("skipping on %v", runtime.GOOS)
+ }
+ conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)})
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ daddr := conn.LocalAddr().(*UDPAddr).AddrPort()
+ buf := make([]byte, 8)
+ _, err = conn.WriteToUDPAddrPort(buf, daddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, saddr, err := conn.ReadMsgUDPAddrPort(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !saddr.Addr().Is4() {
+ t.Error("returned AddrPort is not IPv4")
+ }
+ _, err = conn.WriteToUDPAddrPort(buf, daddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, _, _, soldaddr, err := conn.ReadMsgUDP(buf, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(soldaddr.IP) != 4 {
+ t.Error("returned UDPAddr is not IPv4")
+ }
+}
+
+// TestIPv6WriteMsgUDPAddrPortTargetAddrIPVersion verifies that
+// WriteMsgUDPAddrPort accepts IPv4, IPv4-mapped IPv6, and IPv6 target addresses
+// on a UDPConn listening on "::".
+func TestIPv6WriteMsgUDPAddrPortTargetAddrIPVersion(t *testing.T) {
+ if !supportsIPv6() {
+ t.Skip("IPv6 is not supported")
+ }
+
+ switch runtime.GOOS {
+ case "dragonfly", "openbsd":
+ // DragonflyBSD's IPv6 sockets are always IPv6-only, according to the man page:
+ // https://www.dragonflybsd.org/cgi/web-man?command=ip6 (search for IPV6_V6ONLY).
+ // OpenBSD's IPv6 sockets are always IPv6-only, according to the man page:
+ // https://man.openbsd.org/ip6#IPV6_V6ONLY
+ t.Skipf("skipping on %v", runtime.GOOS)
+ }
+
+ conn, err := ListenUDP("udp", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ daddr4 := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), 12345)
+ daddr4in6 := netip.AddrPortFrom(netip.MustParseAddr("::ffff:127.0.0.1"), 12345)
+ daddr6 := netip.AddrPortFrom(netip.MustParseAddr("::1"), 12345)
+ buf := make([]byte, 8)
+
+ _, _, err = conn.WriteMsgUDPAddrPort(buf, nil, daddr4)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, _, err = conn.WriteMsgUDPAddrPort(buf, nil, daddr4in6)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ _, _, err = conn.WriteMsgUDPAddrPort(buf, nil, daddr6)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/src/net/unixsock.go b/src/net/unixsock.go
new file mode 100644
index 0000000..14fbac0
--- /dev/null
+++ b/src/net/unixsock.go
@@ -0,0 +1,352 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "os"
+ "sync"
+ "syscall"
+ "time"
+)
+
+// BUG(mikio): On JS, WASIP1 and Plan 9, methods and functions related
+// to UnixConn and UnixListener are not implemented.
+
+// BUG(mikio): On Windows, methods and functions related to UnixConn
+// and UnixListener don't work for "unixgram" and "unixpacket".
+
+// UnixAddr represents the address of a Unix domain socket end point.
+type UnixAddr struct {
+ Name string
+ Net string
+}
+
+// Network returns the address's network name, "unix", "unixgram" or
+// "unixpacket".
+func (a *UnixAddr) Network() string {
+ return a.Net
+}
+
+func (a *UnixAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ return a.Name
+}
+
+func (a *UnixAddr) isWildcard() bool {
+ return a == nil || a.Name == ""
+}
+
+func (a *UnixAddr) opAddr() Addr {
+ if a == nil {
+ return nil
+ }
+ return a
+}
+
+// ResolveUnixAddr returns an address of a Unix domain socket end point.
+//
+// The network must be a Unix network name.
+//
+// See func Dial for a description of the network and address
+// parameters.
+func ResolveUnixAddr(network, address string) (*UnixAddr, error) {
+ switch network {
+ case "unix", "unixgram", "unixpacket":
+ return &UnixAddr{Name: address, Net: network}, nil
+ default:
+ return nil, UnknownNetworkError(network)
+ }
+}
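+
+// resolveUnixAddrSketch is an illustrative helper (hypothetical): for Unix
+// networks no resolution is performed; the path and network name are simply
+// wrapped in a UnixAddr. The socket path below is only a placeholder.
+func resolveUnixAddrSketch() (*UnixAddr, error) {
+ // Returns &UnixAddr{Name: "/tmp/example.sock", Net: "unixgram"}.
+ return ResolveUnixAddr("unixgram", "/tmp/example.sock")
+}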
+
+// UnixConn is an implementation of the Conn interface for connections
+// to Unix domain sockets.
+type UnixConn struct {
+ conn
+}
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+func (c *UnixConn) SyscallConn() (syscall.RawConn, error) {
+ if !c.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawConn(c.fd)
+}
+
+// CloseRead shuts down the reading side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseRead() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.closeRead(); err != nil {
+ return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// CloseWrite shuts down the writing side of the Unix domain connection.
+// Most callers should just use Close.
+func (c *UnixConn) CloseWrite() error {
+ if !c.ok() {
+ return syscall.EINVAL
+ }
+ if err := c.fd.closeWrite(); err != nil {
+ return &OpError{Op: "close", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return nil
+}
+
+// ReadFromUnix acts like ReadFrom but returns a UnixAddr.
+func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) {
+ if !c.ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ n, addr, err := c.readFrom(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, addr, err
+}
+
+// ReadFrom implements the PacketConn ReadFrom method.
+func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
+ if !c.ok() {
+ return 0, nil, syscall.EINVAL
+ }
+ n, addr, err := c.readFrom(b)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ if addr == nil {
+ return n, nil, err
+ }
+ return n, addr, err
+}
+
+// ReadMsgUnix reads a message from c, copying the payload into b and
+// the associated out-of-band data into oob. It returns the number of
+// bytes copied into b, the number of bytes copied into oob, the flags
+// that were set on the message and the source address of the message.
+//
+// Note that if len(b) == 0 and len(oob) > 0, this function will still
+// read (and discard) 1 byte from the connection.
+func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
+ if !c.ok() {
+ return 0, 0, 0, nil, syscall.EINVAL
+ }
+ n, oobn, flags, addr, err = c.readMsg(b, oob)
+ if err != nil {
+ err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return
+}
+
+// WriteToUnix acts like WriteTo but takes a UnixAddr.
+func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.writeTo(b, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteTo implements the PacketConn WriteTo method.
+func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ a, ok := addr.(*UnixAddr)
+ if !ok {
+ return 0, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr, Err: syscall.EINVAL}
+ }
+ n, err := c.writeTo(b, a)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: a.opAddr(), Err: err}
+ }
+ return n, err
+}
+
+// WriteMsgUnix writes a message to addr via c, copying the payload
+// from b and the associated out-of-band data from oob. It returns the
+// number of payload and out-of-band bytes written.
+//
+// Note that if len(b) == 0 and len(oob) > 0, this function will still
+// write 1 byte to the connection.
+func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
+ if !c.ok() {
+ return 0, 0, syscall.EINVAL
+ }
+ n, oobn, err = c.writeMsg(b, oob, addr)
+ if err != nil {
+ err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addr.opAddr(), Err: err}
+ }
+ return
+}
+
+func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} }
+
+// DialUnix acts like Dial for Unix networks.
+//
+// The network must be a Unix network name; see func Dial for details.
+//
+// If laddr is non-nil, it is used as the local address for the
+// connection.
+func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ switch network {
+ case "unix", "unixgram", "unixpacket":
+ default:
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialUnix(context.Background(), laddr, raddr)
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
+
+// UnixListener is a Unix domain socket listener. Clients should
+// typically use variables of type Listener instead of assuming Unix
+// domain sockets.
+type UnixListener struct {
+ fd *netFD
+ path string
+ unlink bool
+ unlinkOnce sync.Once
+}
+
+func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil }
+
+// SyscallConn returns a raw network connection.
+// This implements the syscall.Conn interface.
+//
+// The returned RawConn only supports calling Control. Read and
+// Write return an error.
+func (l *UnixListener) SyscallConn() (syscall.RawConn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ return newRawListener(l.fd)
+}
+
+// AcceptUnix accepts the next incoming call and returns the new
+// connection.
+func (l *UnixListener) AcceptUnix() (*UnixConn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ c, err := l.accept()
+ if err != nil {
+ return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return c, nil
+}
+
+// Accept implements the Accept method in the Listener interface.
+// Returned connections will be of type *UnixConn.
+func (l *UnixListener) Accept() (Conn, error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ c, err := l.accept()
+ if err != nil {
+ return nil, &OpError{Op: "accept", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return c, nil
+}
+
+// Close stops listening on the Unix address. Already accepted
+// connections are not closed.
+func (l *UnixListener) Close() error {
+ if !l.ok() {
+ return syscall.EINVAL
+ }
+ if err := l.close(); err != nil {
+ return &OpError{Op: "close", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// Addr returns the listener's network address.
+// The Addr returned is shared by all invocations of Addr, so
+// do not modify it.
+func (l *UnixListener) Addr() Addr { return l.fd.laddr }
+
+// SetDeadline sets the deadline associated with the listener.
+// A zero time value disables the deadline.
+func (l *UnixListener) SetDeadline(t time.Time) error {
+ if !l.ok() {
+ return syscall.EINVAL
+ }
+ if err := l.fd.pfd.SetDeadline(t); err != nil {
+ return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return nil
+}
+
+// File returns a copy of the underlying os.File.
+// It is the caller's responsibility to close f when finished.
+// Closing l does not affect f, and closing f does not affect l.
+//
+// The returned os.File's file descriptor is different from the
+// connection's. Attempting to change properties of the original
+// using this duplicate may or may not have the desired effect.
+func (l *UnixListener) File() (f *os.File, err error) {
+ if !l.ok() {
+ return nil, syscall.EINVAL
+ }
+ f, err = l.file()
+ if err != nil {
+ err = &OpError{Op: "file", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err}
+ }
+ return
+}
+
+// ListenUnix acts like Listen for Unix networks.
+//
+// The network must be "unix" or "unixpacket".
+func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
+ switch network {
+ case "unix", "unixpacket":
+ default:
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if laddr == nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ ln, err := sl.listenUnix(context.Background(), laddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
+ }
+ return ln, nil
+}
+
+// ListenUnixgram acts like ListenPacket for Unix networks.
+//
+// The network must be "unixgram".
+func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
+ switch network {
+ case "unixgram":
+ default:
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)}
+ }
+ if laddr == nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: errMissingAddress}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenUnixgram(context.Background(), laddr)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
+ }
+ return c, nil
+}
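
For orientation, here is a minimal caller-side sketch of the stream API added above (ResolveUnixAddr, ListenUnix, AcceptUnix, DialUnix). The socket path /tmp/echo.sock is illustrative only and not taken from this patch; closing the listener unlinks the file by default.

package main

import (
	"log"
	"net"
)

func main() {
	addr, err := net.ResolveUnixAddr("unix", "/tmp/echo.sock")
	if err != nil {
		log.Fatal(err)
	}
	ln, err := net.ListenUnix("unix", addr)
	if err != nil {
		log.Fatal(err)
	}
	defer ln.Close() // unlinks /tmp/echo.sock; see SetUnlinkOnClose

	go func() {
		c, err := net.DialUnix("unix", nil, addr)
		if err != nil {
			log.Print(err)
			return
		}
		defer c.Close()
		c.Write([]byte("hello"))
	}()

	c, err := ln.AcceptUnix()
	if err != nil {
		log.Fatal(err)
	}
	defer c.Close()
	buf := make([]byte, 16)
	n, err := c.Read(buf)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("server got %q", buf[:n])
}
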
diff --git a/src/net/unixsock_linux_test.go b/src/net/unixsock_linux_test.go
new file mode 100644
index 0000000..d04007c
--- /dev/null
+++ b/src/net/unixsock_linux_test.go
@@ -0,0 +1,104 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bytes"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestUnixgramAutobind(t *testing.T) {
+ laddr := &UnixAddr{Name: "", Net: "unixgram"}
+ c1, err := ListenUnixgram("unixgram", laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+
+ // retrieve the autobind address
+ autoAddr := c1.LocalAddr().(*UnixAddr)
+ if len(autoAddr.Name) <= 1 {
+ t.Fatalf("invalid autobind address: %v", autoAddr)
+ }
+ if autoAddr.Name[0] != '@' {
+ t.Fatalf("invalid autobind address: %v", autoAddr)
+ }
+
+ c2, err := DialUnix("unixgram", nil, autoAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) {
+ t.Fatalf("expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr())
+ }
+}
+
+func TestUnixAutobindClose(t *testing.T) {
+ laddr := &UnixAddr{Name: "", Net: "unix"}
+ ln, err := ListenUnix("unix", laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln.Close()
+}
+
+func TestUnixgramLinuxAbstractLongName(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("abstract unix socket long name test")
+ }
+
+ // Create an abstract socket name whose length is exactly
+ // the maximum RawSockkaddrUnix Path len
+ rsu := syscall.RawSockaddrUnix{}
+ addrBytes := make([]byte, len(rsu.Path))
+ copy(addrBytes, "@abstract_test")
+ addr := string(addrBytes)
+
+ la, err := ResolveUnixAddr("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenUnixgram("unixgram", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ off := make(chan bool)
+ data := [5]byte{1, 2, 3, 4, 5}
+ go func() {
+ defer func() { off <- true }()
+ s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer syscall.Close(s)
+ rsa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Sendto(s, data[:], 0, rsa); err != nil {
+ t.Error(err)
+ return
+ }
+ }()
+
+ <-off
+ b := make([]byte, 64)
+ c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ n, from, err := c.ReadFrom(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if from != nil {
+ t.Fatalf("unexpected peer address: %v", from)
+ }
+ if !bytes.Equal(b[:n], data[:]) {
+ t.Fatalf("got %v; want %v", b[:n], data[:])
+ }
+}
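
The tests above lean on Linux's abstract socket namespace: an address whose name starts with '@' lives outside the file system, and binding an empty name asks the kernel to autobind a unique abstract address. A Linux-only sketch of that behaviour, with no names taken from the patch:

package main

import (
	"log"
	"net"
)

func main() {
	// An empty name autobinds to a kernel-chosen abstract address,
	// reported with a leading '@'; nothing is created on disk.
	srv, err := net.ListenUnixgram("unixgram", &net.UnixAddr{Name: "", Net: "unixgram"})
	if err != nil {
		log.Fatal(err)
	}
	defer srv.Close()

	auto := srv.LocalAddr().(*net.UnixAddr)
	log.Printf("autobound to %s", auto)

	cli, err := net.DialUnix("unixgram", nil, auto)
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()
	if _, err := cli.Write([]byte("ping")); err != nil {
		log.Fatal(err)
	}

	buf := make([]byte, 16)
	n, _, err := srv.ReadFromUnix(buf)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("received %q", buf[:n])
}
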
diff --git a/src/net/unixsock_plan9.go b/src/net/unixsock_plan9.go
new file mode 100644
index 0000000..6ebd4d7
--- /dev/null
+++ b/src/net/unixsock_plan9.go
@@ -0,0 +1,51 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "context"
+ "os"
+ "syscall"
+)
+
+func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
+ return 0, nil, syscall.EPLAN9
+}
+
+func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
+ return 0, 0, 0, nil, syscall.EPLAN9
+}
+
+func (c *UnixConn) writeTo(b []byte, addr *UnixAddr) (int, error) {
+ return 0, syscall.EPLAN9
+}
+
+func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
+ return 0, 0, syscall.EPLAN9
+}
+
+func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ return nil, syscall.EPLAN9
+}
+
+func (ln *UnixListener) accept() (*UnixConn, error) {
+ return nil, syscall.EPLAN9
+}
+
+func (ln *UnixListener) close() error {
+ return syscall.EPLAN9
+}
+
+func (ln *UnixListener) file() (*os.File, error) {
+ return nil, syscall.EPLAN9
+}
+
+func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
+ return nil, syscall.EPLAN9
+}
+
+func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
+ return nil, syscall.EPLAN9
+}
diff --git a/src/net/unixsock_posix.go b/src/net/unixsock_posix.go
new file mode 100644
index 0000000..c501b49
--- /dev/null
+++ b/src/net/unixsock_posix.go
@@ -0,0 +1,245 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm) || wasip1 || windows
+
+package net
+
+import (
+ "context"
+ "errors"
+ "os"
+ "syscall"
+)
+
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctxCtrlFn func(context.Context, string, string, syscall.RawConn) error) (*netFD, error) {
+ var sotype int
+ switch net {
+ case "unix":
+ sotype = syscall.SOCK_STREAM
+ case "unixgram":
+ sotype = syscall.SOCK_DGRAM
+ case "unixpacket":
+ sotype = syscall.SOCK_SEQPACKET
+ default:
+ return nil, UnknownNetworkError(net)
+ }
+
+ switch mode {
+ case "dial":
+ if laddr != nil && laddr.isWildcard() {
+ laddr = nil
+ }
+ if raddr != nil && raddr.isWildcard() {
+ raddr = nil
+ }
+ if raddr == nil && (sotype != syscall.SOCK_DGRAM || laddr == nil) {
+ return nil, errMissingAddress
+ }
+ case "listen":
+ default:
+ return nil, errors.New("unknown mode: " + mode)
+ }
+
+ fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctxCtrlFn)
+ if err != nil {
+ return nil, err
+ }
+ return fd, nil
+}
+
+func sockaddrToUnix(sa syscall.Sockaddr) Addr {
+ if s, ok := sa.(*syscall.SockaddrUnix); ok {
+ return &UnixAddr{Name: s.Name, Net: "unix"}
+ }
+ return nil
+}
+
+func sockaddrToUnixgram(sa syscall.Sockaddr) Addr {
+ if s, ok := sa.(*syscall.SockaddrUnix); ok {
+ return &UnixAddr{Name: s.Name, Net: "unixgram"}
+ }
+ return nil
+}
+
+func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr {
+ if s, ok := sa.(*syscall.SockaddrUnix); ok {
+ return &UnixAddr{Name: s.Name, Net: "unixpacket"}
+ }
+ return nil
+}
+
+func sotypeToNet(sotype int) string {
+ switch sotype {
+ case syscall.SOCK_STREAM:
+ return "unix"
+ case syscall.SOCK_DGRAM:
+ return "unixgram"
+ case syscall.SOCK_SEQPACKET:
+ return "unixpacket"
+ default:
+ panic("sotypeToNet unknown socket type")
+ }
+}
+
+func (a *UnixAddr) family() int {
+ return syscall.AF_UNIX
+}
+
+func (a *UnixAddr) sockaddr(family int) (syscall.Sockaddr, error) {
+ if a == nil {
+ return nil, nil
+ }
+ return &syscall.SockaddrUnix{Name: a.Name}, nil
+}
+
+func (a *UnixAddr) toLocal(net string) sockaddr {
+ return a
+}
+
+func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
+ var addr *UnixAddr
+ n, sa, err := c.fd.readFrom(b)
+ switch sa := sa.(type) {
+ case *syscall.SockaddrUnix:
+ if sa.Name != "" {
+ addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)}
+ }
+ }
+ return n, addr, err
+}
+
+func (c *UnixConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
+ var sa syscall.Sockaddr
+ n, oobn, flags, sa, err = c.fd.readMsg(b, oob, readMsgFlags)
+ if readMsgFlags == 0 && err == nil && oobn > 0 {
+ setReadMsgCloseOnExec(oob[:oobn])
+ }
+
+ switch sa := sa.(type) {
+ case *syscall.SockaddrUnix:
+ if sa.Name != "" {
+ addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)}
+ }
+ }
+ return
+}
+
+func (c *UnixConn) writeTo(b []byte, addr *UnixAddr) (int, error) {
+ if c.fd.isConnected {
+ return 0, ErrWriteToConnected
+ }
+ if addr == nil {
+ return 0, errMissingAddress
+ }
+ if addr.Net != sotypeToNet(c.fd.sotype) {
+ return 0, syscall.EAFNOSUPPORT
+ }
+ sa := &syscall.SockaddrUnix{Name: addr.Name}
+ return c.fd.writeTo(b, sa)
+}
+
+func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
+ if c.fd.sotype == syscall.SOCK_DGRAM && c.fd.isConnected {
+ return 0, 0, ErrWriteToConnected
+ }
+ var sa syscall.Sockaddr
+ if addr != nil {
+ if addr.Net != sotypeToNet(c.fd.sotype) {
+ return 0, 0, syscall.EAFNOSUPPORT
+ }
+ sa = &syscall.SockaddrUnix{Name: addr.Name}
+ }
+ return c.fd.writeMsg(b, oob, sa)
+}
+
+func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ ctrlCtxFn := sd.Dialer.ControlContext
+ if ctrlCtxFn == nil && sd.Dialer.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sd.Dialer.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newUnixConn(fd), nil
+}
+
+func (ln *UnixListener) accept() (*UnixConn, error) {
+ fd, err := ln.fd.accept()
+ if err != nil {
+ return nil, err
+ }
+ return newUnixConn(fd), nil
+}
+
+func (ln *UnixListener) close() error {
+ // The operating system doesn't clean up
+ // the socket file that listening created, so
+ // we have to clean it up ourselves.
+ // There's a race here--we can't know for
+ // sure whether someone else has come along
+ // and replaced our socket name already--
+ // but this sequence (remove then close)
+ // is at least compatible with the auto-remove
+ // sequence in ListenUnix. It's only non-Go
+ // programs that can mess us up.
+ // Even if there are racy calls to Close, we want to unlink only for the first one.
+ ln.unlinkOnce.Do(func() {
+ if ln.path[0] != '@' && ln.unlink {
+ syscall.Unlink(ln.path)
+ }
+ })
+ return ln.fd.Close()
+}
+
+func (ln *UnixListener) file() (*os.File, error) {
+ f, err := ln.fd.dup()
+ if err != nil {
+ return nil, err
+ }
+ return f, nil
+}
+
+// SetUnlinkOnClose sets whether the underlying socket file should be removed
+// from the file system when the listener is closed.
+//
+// The default behavior is to unlink the socket file only when package net created it.
+// That is, when the listener and the underlying socket file were created by a call to
+// Listen or ListenUnix, then by default closing the listener will remove the socket file.
+// But if the listener was created by a call to FileListener to use an already existing
+// socket file, then by default closing the listener will not remove the socket file.
+func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
+ l.unlink = unlink
+}
+
+func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
+}
+
+func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
+ var ctrlCtxFn func(cxt context.Context, network, address string, c syscall.RawConn) error
+ if sl.ListenConfig.Control != nil {
+ ctrlCtxFn = func(cxt context.Context, network, address string, c syscall.RawConn) error {
+ return sl.ListenConfig.Control(network, address, c)
+ }
+ }
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", ctrlCtxFn)
+ if err != nil {
+ return nil, err
+ }
+ return newUnixConn(fd), nil
+}
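
SetUnlinkOnClose, implemented above, only changes which side owns the socket file's lifetime. A small sketch of opting out of the default unlink for a listener created by package net; the path below is illustrative:

package main

import (
	"log"
	"net"
)

func main() {
	l, err := net.ListenUnix("unix", &net.UnixAddr{Name: "/tmp/app.sock", Net: "unix"})
	if err != nil {
		log.Fatal(err)
	}
	// By default Close would unlink /tmp/app.sock because package net
	// created it; keep the file so another process can adopt it later
	// via FileListener (which, in turn, does not unlink by default).
	l.SetUnlinkOnClose(false)
	l.Close()
}
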
diff --git a/src/net/unixsock_readmsg_cloexec.go b/src/net/unixsock_readmsg_cloexec.go
new file mode 100644
index 0000000..fa4fd7d
--- /dev/null
+++ b/src/net/unixsock_readmsg_cloexec.go
@@ -0,0 +1,30 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build aix || darwin || freebsd || solaris
+
+package net
+
+import "syscall"
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {
+ scms, err := syscall.ParseSocketControlMessage(oob)
+ if err != nil {
+ return
+ }
+
+ for _, scm := range scms {
+ if scm.Header.Level == syscall.SOL_SOCKET && scm.Header.Type == syscall.SCM_RIGHTS {
+ fds, err := syscall.ParseUnixRights(&scm)
+ if err != nil {
+ continue
+ }
+ for _, fd := range fds {
+ syscall.CloseOnExec(fd)
+ }
+ }
+ }
+}
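
The readMsgFlags / setReadMsgCloseOnExec pair exists because ReadMsgUnix can receive file descriptors in SCM_RIGHTS control messages, and those descriptors must come back close-on-exec. A hedged sketch of both sides of such a transfer over a connected *net.UnixConn; the package and helper names are made up for illustration:

package fdpass

import (
	"errors"
	"net"
	"os"
	"syscall"
)

// sendFD passes f's descriptor to the peer of a connected *net.UnixConn
// as an SCM_RIGHTS control message, alongside a single dummy payload byte.
func sendFD(c *net.UnixConn, f *os.File) error {
	rights := syscall.UnixRights(int(f.Fd()))
	_, _, err := c.WriteMsgUnix([]byte{0}, rights, nil)
	return err
}

// recvFD receives one descriptor; on the platforms covered by the files
// above, the net package has already marked it close-on-exec, either via
// MSG_CMSG_CLOEXEC or with an explicit fcntl on each received fd.
func recvFD(c *net.UnixConn) (int, error) {
	buf := make([]byte, 1)
	oob := make([]byte, syscall.CmsgSpace(4))
	_, oobn, _, _, err := c.ReadMsgUnix(buf, oob)
	if err != nil {
		return -1, err
	}
	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return -1, err
	}
	if len(scms) != 1 {
		return -1, errors.New("fdpass: expected exactly one control message")
	}
	fds, err := syscall.ParseUnixRights(&scms[0])
	if err != nil {
		return -1, err
	}
	if len(fds) != 1 {
		return -1, errors.New("fdpass: expected exactly one file descriptor")
	}
	return fds[0], nil
}
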
diff --git a/src/net/unixsock_readmsg_cmsg_cloexec.go b/src/net/unixsock_readmsg_cmsg_cloexec.go
new file mode 100644
index 0000000..6b0de87
--- /dev/null
+++ b/src/net/unixsock_readmsg_cmsg_cloexec.go
@@ -0,0 +1,13 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly || linux || netbsd || openbsd
+
+package net
+
+import "syscall"
+
+const readMsgFlags = syscall.MSG_CMSG_CLOEXEC
+
+func setReadMsgCloseOnExec(oob []byte) {}
diff --git a/src/net/unixsock_readmsg_other.go b/src/net/unixsock_readmsg_other.go
new file mode 100644
index 0000000..0899a6d
--- /dev/null
+++ b/src/net/unixsock_readmsg_other.go
@@ -0,0 +1,11 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build (js && wasm) || wasip1 || windows
+
+package net
+
+const readMsgFlags = 0
+
+func setReadMsgCloseOnExec(oob []byte) {}
diff --git a/src/net/unixsock_readmsg_test.go b/src/net/unixsock_readmsg_test.go
new file mode 100644
index 0000000..2d89dc4
--- /dev/null
+++ b/src/net/unixsock_readmsg_test.go
@@ -0,0 +1,105 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "internal/syscall/unix"
+ "os"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("not unix system")
+ }
+
+ scmFile, err := os.Open(os.DevNull)
+ if err != nil {
+ t.Fatalf("file open: %v", err)
+ }
+ defer scmFile.Close()
+
+ rights := syscall.UnixRights(int(scmFile.Fd()))
+ fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ t.Fatalf("Socketpair: %v", err)
+ }
+
+ writeFile := os.NewFile(uintptr(fds[0]), "write-socket")
+ defer writeFile.Close()
+ readFile := os.NewFile(uintptr(fds[1]), "read-socket")
+ defer readFile.Close()
+
+ cw, err := FileConn(writeFile)
+ if err != nil {
+ t.Fatalf("FileConn: %v", err)
+ }
+ defer cw.Close()
+ cr, err := FileConn(readFile)
+ if err != nil {
+ t.Fatalf("FileConn: %v", err)
+ }
+ defer cr.Close()
+
+ ucw, ok := cw.(*UnixConn)
+ if !ok {
+ t.Fatalf("got %T; want UnixConn", cw)
+ }
+ ucr, ok := cr.(*UnixConn)
+ if !ok {
+ t.Fatalf("got %T; want UnixConn", cr)
+ }
+
+ oob := make([]byte, syscall.CmsgSpace(4))
+ err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second))
+ if err != nil {
+ t.Fatalf("Can't set unix connection timeout: %v", err)
+ }
+ _, _, err = ucw.WriteMsgUnix(nil, rights, nil)
+ if err != nil {
+ t.Fatalf("UnixConn writeMsg: %v", err)
+ }
+ err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second))
+ if err != nil {
+ t.Fatalf("Can't set unix connection timeout: %v", err)
+ }
+ _, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob)
+ if err != nil {
+ t.Fatalf("UnixConn readMsg: %v", err)
+ }
+
+ scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
+ if err != nil {
+ t.Fatalf("ParseSocketControlMessage: %v", err)
+ }
+ if len(scms) != 1 {
+ t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms)
+ }
+ scm := scms[0]
+ gotFDs, err := syscall.ParseUnixRights(&scm)
+ if err != nil {
+ t.Fatalf("syscall.ParseUnixRights: %v", err)
+ }
+ if len(gotFDs) != 1 {
+ t.Fatalf("got FDs %#v: wanted only 1 fd", gotFDs)
+ }
+ defer func() {
+ if err := syscall.Close(gotFDs[0]); err != nil {
+ t.Fatalf("failed to close gotFDs: %v", err)
+ }
+ }()
+
+ flags, err := unix.Fcntl(gotFDs[0], syscall.F_GETFD, 0)
+ if err != nil {
+ t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFDs[0], err)
+ }
+ if flags&syscall.FD_CLOEXEC == 0 {
+ t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC)
+ }
+}
diff --git a/src/net/unixsock_test.go b/src/net/unixsock_test.go
new file mode 100644
index 0000000..8402519
--- /dev/null
+++ b/src/net/unixsock_test.go
@@ -0,0 +1,463 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !wasip1 && !windows
+
+package net
+
+import (
+ "bytes"
+ "internal/testenv"
+ "os"
+ "reflect"
+ "runtime"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func TestReadUnixgramWithUnnamedSocket(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+ if runtime.GOOS == "openbsd" {
+ testenv.SkipFlaky(t, 15157)
+ }
+
+ addr := testUnixAddr(t)
+ la, err := ResolveUnixAddr("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenUnixgram("unixgram", la)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ c.Close()
+ os.Remove(addr)
+ }()
+
+ off := make(chan bool)
+ data := [5]byte{1, 2, 3, 4, 5}
+ go func() {
+ defer func() { off <- true }()
+ s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer syscall.Close(s)
+ rsa := &syscall.SockaddrUnix{Name: addr}
+ if err := syscall.Sendto(s, data[:], 0, rsa); err != nil {
+ t.Error(err)
+ return
+ }
+ }()
+
+ <-off
+ b := make([]byte, 64)
+ c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ n, from, err := c.ReadFrom(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if from != nil {
+ t.Fatalf("unexpected peer address: %v", from)
+ }
+ if !bytes.Equal(b[:n], data[:]) {
+ t.Fatalf("got %v; want %v", b[:n], data[:])
+ }
+}
+
+func TestUnixgramZeroBytePayload(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+
+ c1 := newLocalPacketListener(t, "unixgram")
+ defer os.Remove(c1.LocalAddr().String())
+ defer c1.Close()
+
+ c2, err := Dial("unixgram", c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+
+ for _, genericRead := range []bool{false, true} {
+ n, err := c2.Write(nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 0 {
+ t.Errorf("got %d; want 0", n)
+ }
+ c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ var b [1]byte
+ var peer Addr
+ if genericRead {
+ _, err = c1.(Conn).Read(b[:])
+ } else {
+ _, peer, err = c1.ReadFrom(b[:])
+ }
+ switch err {
+ case nil: // ReadFrom succeeds
+ if peer != nil { // peer is connected-mode
+ t.Fatalf("unexpected peer address: %v", peer)
+ }
+ default: // Read may timeout, it depends on the platform
+ if !isDeadlineExceeded(err) {
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func TestUnixgramZeroByteBuffer(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+ // issue 4352: Recvfrom failed with "address family not
+ // supported by protocol family" if zero-length buffer provided
+
+ c1 := newLocalPacketListener(t, "unixgram")
+ defer os.Remove(c1.LocalAddr().String())
+ defer c1.Close()
+
+ c2, err := Dial("unixgram", c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(c2.LocalAddr().String())
+ defer c2.Close()
+
+ b := []byte("UNIXGRAM ZERO BYTE BUFFER TEST")
+ for _, genericRead := range []bool{false, true} {
+ n, err := c2.Write(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != len(b) {
+ t.Errorf("got %d; want %d", n, len(b))
+ }
+ c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ var peer Addr
+ if genericRead {
+ _, err = c1.(Conn).Read(nil)
+ } else {
+ _, peer, err = c1.ReadFrom(nil)
+ }
+ switch err {
+ case nil: // ReadFrom succeeds
+ if peer != nil { // peer is connected-mode
+ t.Fatalf("unexpected peer address: %v", peer)
+ }
+ default: // Read may timeout, it depends on the platform
+ if !isDeadlineExceeded(err) {
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func TestUnixgramWrite(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+
+ addr := testUnixAddr(t)
+ laddr, err := ResolveUnixAddr("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := ListenPacket("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(addr)
+ defer c.Close()
+
+ testUnixgramWriteConn(t, laddr)
+ testUnixgramWritePacketConn(t, laddr)
+}
+
+func testUnixgramWriteConn(t *testing.T, raddr *UnixAddr) {
+ c, err := Dial("unixgram", raddr.String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ b := []byte("CONNECTED-MODE SOCKET")
+ if _, err := c.(*UnixConn).WriteToUnix(b, raddr); err == nil {
+ t.Fatal("should fail")
+ } else if err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ if _, err = c.(*UnixConn).WriteTo(b, raddr); err == nil {
+ t.Fatal("should fail")
+ } else if err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ if _, _, err = c.(*UnixConn).WriteMsgUnix(b, nil, raddr); err == nil {
+ t.Fatal("should fail")
+ } else if err.(*OpError).Err != ErrWriteToConnected {
+ t.Fatalf("should fail as ErrWriteToConnected: %v", err)
+ }
+ if _, err := c.Write(b); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func testUnixgramWritePacketConn(t *testing.T, raddr *UnixAddr) {
+ addr := testUnixAddr(t)
+ c, err := ListenPacket("unixgram", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.Remove(addr)
+ defer c.Close()
+
+ b := []byte("UNCONNECTED-MODE SOCKET")
+ if _, err := c.(*UnixConn).WriteToUnix(b, raddr); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c.WriteTo(b, raddr); err != nil {
+ t.Fatal(err)
+ }
+ if _, _, err := c.(*UnixConn).WriteMsgUnix(b, nil, raddr); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := c.(*UnixConn).Write(b); err == nil {
+ t.Fatal("should fail")
+ }
+}
+
+func TestUnixConnLocalAndRemoteNames(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("unix test")
+ }
+
+ handler := func(ls *localServer, ln Listener) {}
+ for _, laddr := range []string{"", testUnixAddr(t)} {
+ laddr := laddr
+ taddr := testUnixAddr(t)
+ ta, err := ResolveUnixAddr("unix", taddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln, err := ListenUnix("unix", ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ la, err := ResolveUnixAddr("unix", laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := DialUnix("unix", la, ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ c.Close()
+ if la != nil {
+ defer os.Remove(laddr)
+ }
+ }()
+ if _, err := c.Write([]byte("UNIXCONN LOCAL AND REMOTE NAME TEST")); err != nil {
+ t.Fatal(err)
+ }
+
+ switch runtime.GOOS {
+ case "android", "linux":
+ if laddr == "" {
+ laddr = "@" // autobind feature
+ }
+ }
+ var connAddrs = [3]struct{ got, want Addr }{
+ {ln.Addr(), ta},
+ {c.LocalAddr(), &UnixAddr{Name: laddr, Net: "unix"}},
+ {c.RemoteAddr(), ta},
+ }
+ for _, ca := range connAddrs {
+ if !reflect.DeepEqual(ca.got, ca.want) {
+ t.Fatalf("got %#v, expected %#v", ca.got, ca.want)
+ }
+ }
+ }
+}
+
+func TestUnixgramConnLocalAndRemoteNames(t *testing.T) {
+ if !testableNetwork("unixgram") {
+ t.Skip("unixgram test")
+ }
+
+ for _, laddr := range []string{"", testUnixAddr(t)} {
+ laddr := laddr
+ taddr := testUnixAddr(t)
+ ta, err := ResolveUnixAddr("unixgram", taddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c1, err := ListenUnixgram("unixgram", ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ c1.Close()
+ os.Remove(taddr)
+ }()
+
+ var la *UnixAddr
+ if laddr != "" {
+ if la, err = ResolveUnixAddr("unixgram", laddr); err != nil {
+ t.Fatal(err)
+ }
+ }
+ c2, err := DialUnix("unixgram", la, ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ c2.Close()
+ if la != nil {
+ defer os.Remove(laddr)
+ }
+ }()
+
+ switch runtime.GOOS {
+ case "android", "linux":
+ if laddr == "" {
+ laddr = "@" // autobind feature
+ }
+ }
+
+ var connAddrs = [4]struct{ got, want Addr }{
+ {c1.LocalAddr(), ta},
+ {c1.RemoteAddr(), nil},
+ {c2.LocalAddr(), &UnixAddr{Name: laddr, Net: "unixgram"}},
+ {c2.RemoteAddr(), ta},
+ }
+ for _, ca := range connAddrs {
+ if !reflect.DeepEqual(ca.got, ca.want) {
+ t.Fatalf("got %#v; want %#v", ca.got, ca.want)
+ }
+ }
+ }
+}
+
+func TestUnixUnlink(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("unix test")
+ }
+ name := testUnixAddr(t)
+
+ listen := func(t *testing.T) *UnixListener {
+ l, err := Listen("unix", name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return l.(*UnixListener)
+ }
+ checkExists := func(t *testing.T, desc string) {
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("unix socket does not exist %s: %v", desc, err)
+ }
+ }
+ checkNotExists := func(t *testing.T, desc string) {
+ if _, err := os.Stat(name); err == nil {
+ t.Fatalf("unix socket does exist %s: %v", desc, err)
+ }
+ }
+
+ // Listener should remove on close.
+ t.Run("Listen", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ // FileListener should not.
+ t.Run("FileListener", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkExists(t, "after FileListener close")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ // Only first call to l.Close should remove.
+ t.Run("SecondClose", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ if err := os.WriteFile(name, []byte("hello world"), 0666); err != nil {
+ t.Fatalf("cannot recreate socket file: %v", err)
+ }
+ checkExists(t, "after writing temp file")
+ l.Close()
+ checkExists(t, "after second Listener close")
+ os.Remove(name)
+ })
+
+ // SetUnlinkOnClose should do what it says.
+
+ t.Run("Listen/SetUnlinkOnClose(true)", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.SetUnlinkOnClose(true)
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ t.Run("Listen/SetUnlinkOnClose(false)", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.SetUnlinkOnClose(false)
+ l.Close()
+ checkExists(t, "after Listener close")
+ os.Remove(name)
+ })
+
+ t.Run("FileListener/SetUnlinkOnClose(true)", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ l1.(*UnixListener).SetUnlinkOnClose(true)
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkNotExists(t, "after FileListener close")
+ l.Close()
+ })
+
+ t.Run("FileListener/SetUnlinkOnClose(false)", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ l1.(*UnixListener).SetUnlinkOnClose(false)
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkExists(t, "after FileListener close")
+ l.Close()
+ })
+}
diff --git a/src/net/unixsock_windows_test.go b/src/net/unixsock_windows_test.go
new file mode 100644
index 0000000..d541d89
--- /dev/null
+++ b/src/net/unixsock_windows_test.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build windows
+
+package net
+
+import (
+ "internal/syscall/windows/registry"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "testing"
+)
+
+func isBuild17063() bool {
+ k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.READ)
+ if err != nil {
+ return false
+ }
+ defer k.Close()
+
+ s, _, err := k.GetStringValue("CurrentBuild")
+ if err != nil {
+ return false
+ }
+ ver, err := strconv.Atoi(s)
+ if err != nil {
+ return false
+ }
+ return ver >= 17063
+}
+
+func TestUnixConnLocalWindows(t *testing.T) {
+ switch runtime.GOARCH {
+ case "386":
+ t.Skip("not supported on windows/386, see golang.org/issue/27943")
+ case "arm":
+ t.Skip("not supported on windows/arm, see golang.org/issue/28061")
+ }
+ if !isBuild17063() {
+ t.Skip("unix sockets require Windows build 17063 or later")
+ }
+
+ handler := func(ls *localServer, ln Listener) {}
+ for _, laddr := range []string{"", testUnixAddr(t)} {
+ laddr := laddr
+ taddr := testUnixAddr(t)
+ ta, err := ResolveUnixAddr("unix", taddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln, err := ListenUnix("unix", ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ls := (&streamListener{Listener: ln}).newLocalServer()
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ la, err := ResolveUnixAddr("unix", laddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := DialUnix("unix", la, ta)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ c.Close()
+ if la != nil {
+ defer os.Remove(laddr)
+ }
+ }()
+ if _, err := c.Write([]byte("UNIXCONN LOCAL AND REMOTE NAME TEST")); err != nil {
+ t.Fatal(err)
+ }
+
+ if laddr == "" {
+ laddr = "@"
+ }
+ var connAddrs = [3]struct{ got, want Addr }{
+ {ln.Addr(), ta},
+ {c.LocalAddr(), &UnixAddr{Name: laddr, Net: "unix"}},
+ {c.RemoteAddr(), ta},
+ }
+ for _, ca := range connAddrs {
+ if !reflect.DeepEqual(ca.got, ca.want) {
+ t.Fatalf("got %#v, expected %#v", ca.got, ca.want)
+ }
+ }
+ }
+}
diff --git a/src/net/url/example_test.go b/src/net/url/example_test.go
new file mode 100644
index 0000000..a191350
--- /dev/null
+++ b/src/net/url/example_test.go
@@ -0,0 +1,374 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package url_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/url"
+ "strings"
+)
+
+func ExamplePathEscape() {
+ path := url.PathEscape("my/cool+blog&about,stuff")
+ fmt.Println(path)
+
+ // Output:
+ // my%2Fcool+blog&about%2Cstuff
+}
+
+func ExamplePathUnescape() {
+ escapedPath := "my%2Fcool+blog&about%2Cstuff"
+ path, err := url.PathUnescape(escapedPath)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(path)
+
+ // Output:
+ // my/cool+blog&about,stuff
+}
+
+func ExampleQueryEscape() {
+ query := url.QueryEscape("my/cool+blog&about,stuff")
+ fmt.Println(query)
+
+ // Output:
+ // my%2Fcool%2Bblog%26about%2Cstuff
+}
+
+func ExampleQueryUnescape() {
+ escapedQuery := "my%2Fcool%2Bblog%26about%2Cstuff"
+ query, err := url.QueryUnescape(escapedQuery)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(query)
+
+ // Output:
+ // my/cool+blog&about,stuff
+}
+
+func ExampleValues() {
+ v := url.Values{}
+ v.Set("name", "Ava")
+ v.Add("friend", "Jess")
+ v.Add("friend", "Sarah")
+ v.Add("friend", "Zoe")
+ // v.Encode() == "name=Ava&friend=Jess&friend=Sarah&friend=Zoe"
+ fmt.Println(v.Get("name"))
+ fmt.Println(v.Get("friend"))
+ fmt.Println(v["friend"])
+ // Output:
+ // Ava
+ // Jess
+ // [Jess Sarah Zoe]
+}
+
+func ExampleValues_Add() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew")
+ v.Add("cat sounds", "mau")
+ fmt.Println(v["cat sounds"])
+
+ // Output:
+ // [meow mew mau]
+}
+
+func ExampleValues_Del() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew")
+ v.Add("cat sounds", "mau")
+ fmt.Println(v["cat sounds"])
+
+ v.Del("cat sounds")
+ fmt.Println(v["cat sounds"])
+
+ // Output:
+ // [meow mew mau]
+ // []
+}
+
+func ExampleValues_Encode() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew/")
+ v.Add("cat sounds", "mau$")
+ fmt.Println(v.Encode())
+
+ // Output:
+ // cat+sounds=meow&cat+sounds=mew%2F&cat+sounds=mau%24
+}
+
+func ExampleValues_Get() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew")
+ v.Add("cat sounds", "mau")
+ fmt.Printf("%q\n", v.Get("cat sounds"))
+ fmt.Printf("%q\n", v.Get("dog sounds"))
+
+ // Output:
+ // "meow"
+ // ""
+}
+
+func ExampleValues_Has() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew")
+ v.Add("cat sounds", "mau")
+ fmt.Println(v.Has("cat sounds"))
+ fmt.Println(v.Has("dog sounds"))
+
+ // Output:
+ // true
+ // false
+}
+
+func ExampleValues_Set() {
+ v := url.Values{}
+ v.Add("cat sounds", "meow")
+ v.Add("cat sounds", "mew")
+ v.Add("cat sounds", "mau")
+ fmt.Println(v["cat sounds"])
+
+ v.Set("cat sounds", "meow")
+ fmt.Println(v["cat sounds"])
+
+ // Output:
+ // [meow mew mau]
+ // [meow]
+}
+
+func ExampleURL() {
+ u, err := url.Parse("http://bing.com/search?q=dotnet")
+ if err != nil {
+ log.Fatal(err)
+ }
+ u.Scheme = "https"
+ u.Host = "google.com"
+ q := u.Query()
+ q.Set("q", "golang")
+ u.RawQuery = q.Encode()
+ fmt.Println(u)
+ // Output: https://google.com/search?q=golang
+}
+
+func ExampleURL_roundtrip() {
+ // Parse + String preserve the original encoding.
+ u, err := url.Parse("https://example.com/foo%2fbar")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.Path)
+ fmt.Println(u.RawPath)
+ fmt.Println(u.String())
+ // Output:
+ // /foo/bar
+ // /foo%2fbar
+ // https://example.com/foo%2fbar
+}
+
+func ExampleURL_ResolveReference() {
+ u, err := url.Parse("../../..//search?q=dotnet")
+ if err != nil {
+ log.Fatal(err)
+ }
+ base, err := url.Parse("http://example.com/directory/")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(base.ResolveReference(u))
+ // Output:
+ // http://example.com/search?q=dotnet
+}
+
+func ExampleParseQuery() {
+ m, err := url.ParseQuery(`x=1&y=2&y=3`)
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(toJSON(m))
+ // Output:
+ // {"x":["1"], "y":["2", "3"]}
+}
+
+func ExampleURL_EscapedPath() {
+ u, err := url.Parse("http://example.com/x/y%2Fz")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println("Path:", u.Path)
+ fmt.Println("RawPath:", u.RawPath)
+ fmt.Println("EscapedPath:", u.EscapedPath())
+ // Output:
+ // Path: /x/y/z
+ // RawPath: /x/y%2Fz
+ // EscapedPath: /x/y%2Fz
+}
+
+func ExampleURL_EscapedFragment() {
+ u, err := url.Parse("http://example.com/#x/y%2Fz")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println("Fragment:", u.Fragment)
+ fmt.Println("RawFragment:", u.RawFragment)
+ fmt.Println("EscapedFragment:", u.EscapedFragment())
+ // Output:
+ // Fragment: x/y/z
+ // RawFragment: x/y%2Fz
+ // EscapedFragment: x/y%2Fz
+}
+
+func ExampleURL_Hostname() {
+ u, err := url.Parse("https://example.org:8000/path")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.Hostname())
+ u, err = url.Parse("https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:17000")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.Hostname())
+ // Output:
+ // example.org
+ // 2001:0db8:85a3:0000:0000:8a2e:0370:7334
+}
+
+func ExampleURL_IsAbs() {
+ u := url.URL{Host: "example.com", Path: "foo"}
+ fmt.Println(u.IsAbs())
+ u.Scheme = "http"
+ fmt.Println(u.IsAbs())
+ // Output:
+ // false
+ // true
+}
+
+func ExampleURL_MarshalBinary() {
+ u, _ := url.Parse("https://example.org")
+ b, err := u.MarshalBinary()
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s\n", b)
+ // Output:
+ // https://example.org
+}
+
+func ExampleURL_Parse() {
+ u, err := url.Parse("https://example.org")
+ if err != nil {
+ log.Fatal(err)
+ }
+ rel, err := u.Parse("/foo")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(rel)
+ _, err = u.Parse(":foo")
+ if _, ok := err.(*url.Error); !ok {
+ log.Fatal(err)
+ }
+ // Output:
+ // https://example.org/foo
+}
+
+func ExampleURL_Port() {
+ u, err := url.Parse("https://example.org")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.Port())
+ u, err = url.Parse("https://example.org:8080")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.Port())
+ // Output:
+ //
+ // 8080
+}
+
+func ExampleURL_Query() {
+ u, err := url.Parse("https://example.org/?a=1&a=2&b=&=3&&&&")
+ if err != nil {
+ log.Fatal(err)
+ }
+ q := u.Query()
+ fmt.Println(q["a"])
+ fmt.Println(q.Get("b"))
+ fmt.Println(q.Get(""))
+ // Output:
+ // [1 2]
+ //
+ // 3
+}
+
+func ExampleURL_String() {
+ u := &url.URL{
+ Scheme: "https",
+ User: url.UserPassword("me", "pass"),
+ Host: "example.com",
+ Path: "foo/bar",
+ RawQuery: "x=1&y=2",
+ Fragment: "anchor",
+ }
+ fmt.Println(u.String())
+ u.Opaque = "opaque"
+ fmt.Println(u.String())
+ // Output:
+ // https://me:pass@example.com/foo/bar?x=1&y=2#anchor
+ // https:opaque?x=1&y=2#anchor
+}
+
+func ExampleURL_UnmarshalBinary() {
+ u := &url.URL{}
+ err := u.UnmarshalBinary([]byte("https://example.org/foo"))
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s\n", u)
+ // Output:
+ // https://example.org/foo
+}
+
+func ExampleURL_Redacted() {
+ u := &url.URL{
+ Scheme: "https",
+ User: url.UserPassword("user", "password"),
+ Host: "example.com",
+ Path: "foo/bar",
+ }
+ fmt.Println(u.Redacted())
+ u.User = url.UserPassword("me", "newerPassword")
+ fmt.Println(u.Redacted())
+ // Output:
+ // https://user:xxxxx@example.com/foo/bar
+ // https://me:xxxxx@example.com/foo/bar
+}
+
+func ExampleURL_RequestURI() {
+ u, err := url.Parse("https://example.org/path?foo=bar")
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Println(u.RequestURI())
+ // Output: /path?foo=bar
+}
+
+func toJSON(m any) string {
+ js, err := json.Marshal(m)
+ if err != nil {
+ log.Fatal(err)
+ }
+ return strings.ReplaceAll(string(js), ",", ", ")
+}
diff --git a/src/net/url/url.go b/src/net/url/url.go
new file mode 100644
index 0000000..501b263
--- /dev/null
+++ b/src/net/url/url.go
@@ -0,0 +1,1265 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package url parses URLs and implements query escaping.
+package url
+
+// See RFC 3986. This package generally follows RFC 3986, except where
+// it deviates for compatibility reasons. When sending changes, first
+// search old issues for history on decisions. Unit tests should also
+// contain references to issue numbers with details.
+
+import (
+ "errors"
+ "fmt"
+ "path"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+// Error reports an error and the operation and URL that caused it.
+type Error struct {
+ Op string
+ URL string
+ Err error
+}
+
+func (e *Error) Unwrap() error { return e.Err }
+func (e *Error) Error() string { return fmt.Sprintf("%s %q: %s", e.Op, e.URL, e.Err) }
+
+func (e *Error) Timeout() bool {
+ t, ok := e.Err.(interface {
+ Timeout() bool
+ })
+ return ok && t.Timeout()
+}
+
+func (e *Error) Temporary() bool {
+ t, ok := e.Err.(interface {
+ Temporary() bool
+ })
+ return ok && t.Temporary()
+}
+
+const upperhex = "0123456789ABCDEF"
+
+func ishex(c byte) bool {
+ switch {
+ case '0' <= c && c <= '9':
+ return true
+ case 'a' <= c && c <= 'f':
+ return true
+ case 'A' <= c && c <= 'F':
+ return true
+ }
+ return false
+}
+
+func unhex(c byte) byte {
+ switch {
+ case '0' <= c && c <= '9':
+ return c - '0'
+ case 'a' <= c && c <= 'f':
+ return c - 'a' + 10
+ case 'A' <= c && c <= 'F':
+ return c - 'A' + 10
+ }
+ return 0
+}
+
+type encoding int
+
+const (
+ encodePath encoding = 1 + iota
+ encodePathSegment
+ encodeHost
+ encodeZone
+ encodeUserPassword
+ encodeQueryComponent
+ encodeFragment
+)
+
+type EscapeError string
+
+func (e EscapeError) Error() string {
+ return "invalid URL escape " + strconv.Quote(string(e))
+}
+
+type InvalidHostError string
+
+func (e InvalidHostError) Error() string {
+ return "invalid character " + strconv.Quote(string(e)) + " in host name"
+}
+
+// Return true if the specified character should be escaped when
+// appearing in a URL string, according to RFC 3986.
+//
+// Note that, for now, shouldEscape does not check all
+// reserved characters correctly. See golang.org/issue/5684.
+func shouldEscape(c byte, mode encoding) bool {
+ // §2.3 Unreserved characters (alphanum)
+ if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' {
+ return false
+ }
+
+ if mode == encodeHost || mode == encodeZone {
+ // §3.2.2 Host allows
+ // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
+ // as part of reg-name.
+ // We add : because we include :port as part of host.
+ // We add [ ] because we include [ipv6]:port as part of host.
+ // We add < > because they're the only characters left that
+ // we could possibly allow, and Parse will reject them if we
+ // escape them (because hosts can't use %-encoding for
+ // ASCII bytes).
+ switch c {
+ case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
+ return false
+ }
+ }
+
+ switch c {
+ case '-', '_', '.', '~': // §2.3 Unreserved characters (mark)
+ return false
+
+ case '$', '&', '+', ',', '/', ':', ';', '=', '?', '@': // §2.2 Reserved characters (reserved)
+ // Different sections of the URL allow a few of
+ // the reserved characters to appear unescaped.
+ switch mode {
+ case encodePath: // §3.3
+ // The RFC allows : @ & = + $ but saves / ; , for assigning
+ // meaning to individual path segments. This package
+ // only manipulates the path as a whole, so we allow those
+ // last three as well. That leaves only ? to escape.
+ return c == '?'
+
+ case encodePathSegment: // §3.3
+ // The RFC allows : @ & = + $ but saves / ; , for assigning
+ // meaning to individual path segments.
+ return c == '/' || c == ';' || c == ',' || c == '?'
+
+ case encodeUserPassword: // §3.2.1
+ // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
+ // userinfo, so we must escape only '@', '/', and '?'.
+ // The parsing of userinfo treats ':' as special so we must escape
+ // that too.
+ return c == '@' || c == '/' || c == '?' || c == ':'
+
+ case encodeQueryComponent: // §3.4
+ // The RFC reserves (so we must escape) everything.
+ return true
+
+ case encodeFragment: // §4.1
+ // The RFC text is silent but the grammar allows
+ // everything, so escape nothing.
+ return false
+ }
+ }
+
+ if mode == encodeFragment {
+ // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are
+ // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not
+ // need to be escaped. To minimize potential breakage, we apply two restrictions:
+ // (1) we always escape sub-delims outside of the fragment, and (2) we always
+ // escape single quote to avoid breaking callers that had previously assumed that
+ // single quotes would be escaped. See issue #19917.
+ switch c {
+ case '!', '(', ')', '*':
+ return false
+ }
+ }
+
+ // Everything else must be escaped.
+ return true
+}
+
+// QueryUnescape does the inverse transformation of QueryEscape,
+// converting each 3-byte encoded substring of the form "%AB" into the
+// hex-decoded byte 0xAB.
+// It returns an error if any % is not followed by two hexadecimal
+// digits.
+func QueryUnescape(s string) (string, error) {
+ return unescape(s, encodeQueryComponent)
+}
+
+// PathUnescape does the inverse transformation of PathEscape,
+// converting each 3-byte encoded substring of the form "%AB" into the
+// hex-decoded byte 0xAB. It returns an error if any % is not followed
+// by two hexadecimal digits.
+//
+// PathUnescape is identical to QueryUnescape except that it does not
+// unescape '+' to ' ' (space).
+func PathUnescape(s string) (string, error) {
+ return unescape(s, encodePathSegment)
+}
+
+// unescape unescapes a string; the mode specifies
+// which section of the URL string is being unescaped.
+func unescape(s string, mode encoding) (string, error) {
+ // Count %, check that they're well-formed.
+ n := 0
+ hasPlus := false
+ for i := 0; i < len(s); {
+ switch s[i] {
+ case '%':
+ n++
+ if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
+ s = s[i:]
+ if len(s) > 3 {
+ s = s[:3]
+ }
+ return "", EscapeError(s)
+ }
+ // Per https://tools.ietf.org/html/rfc3986#page-21
+ // in the host component %-encoding can only be used
+ // for non-ASCII bytes.
+ // But https://tools.ietf.org/html/rfc6874#section-2
+ // introduces %25 being allowed to escape a percent sign
+ // in IPv6 scoped-address literals. Yay.
+ if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
+ return "", EscapeError(s[i : i+3])
+ }
+ if mode == encodeZone {
+ // RFC 6874 says basically "anything goes" for zone identifiers
+ // and that even non-ASCII can be redundantly escaped,
+ // but it seems prudent to restrict %-escaped bytes here to those
+ // that are valid host name bytes in their unescaped form.
+ // That is, you can use escaping in the zone identifier but not
+ // to introduce bytes you couldn't just write directly.
+ // But Windows puts spaces here! Yay.
+ v := unhex(s[i+1])<<4 | unhex(s[i+2])
+ if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) {
+ return "", EscapeError(s[i : i+3])
+ }
+ }
+ i += 3
+ case '+':
+ hasPlus = mode == encodeQueryComponent
+ i++
+ default:
+ if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
+ return "", InvalidHostError(s[i : i+1])
+ }
+ i++
+ }
+ }
+
+ if n == 0 && !hasPlus {
+ return s, nil
+ }
+
+ var t strings.Builder
+ t.Grow(len(s) - 2*n)
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '%':
+ t.WriteByte(unhex(s[i+1])<<4 | unhex(s[i+2]))
+ i += 2
+ case '+':
+ if mode == encodeQueryComponent {
+ t.WriteByte(' ')
+ } else {
+ t.WriteByte('+')
+ }
+ default:
+ t.WriteByte(s[i])
+ }
+ }
+ return t.String(), nil
+}
+
+// QueryEscape escapes the string so it can be safely placed
+// inside a URL query.
+func QueryEscape(s string) string {
+ return escape(s, encodeQueryComponent)
+}
+
+// PathEscape escapes the string so it can be safely placed inside a URL path segment,
+// replacing special characters (including /) with %XX sequences as needed.
+func PathEscape(s string) string {
+ return escape(s, encodePathSegment)
+}
+
+func escape(s string, mode encoding) string {
+ spaceCount, hexCount := 0, 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ if shouldEscape(c, mode) {
+ if c == ' ' && mode == encodeQueryComponent {
+ spaceCount++
+ } else {
+ hexCount++
+ }
+ }
+ }
+
+ if spaceCount == 0 && hexCount == 0 {
+ return s
+ }
+
+ var buf [64]byte
+ var t []byte
+
+ required := len(s) + 2*hexCount
+ if required <= len(buf) {
+ t = buf[:required]
+ } else {
+ t = make([]byte, required)
+ }
+
+ if hexCount == 0 {
+ copy(t, s)
+ for i := 0; i < len(s); i++ {
+ if s[i] == ' ' {
+ t[i] = '+'
+ }
+ }
+ return string(t)
+ }
+
+ j := 0
+ for i := 0; i < len(s); i++ {
+ switch c := s[i]; {
+ case c == ' ' && mode == encodeQueryComponent:
+ t[j] = '+'
+ j++
+ case shouldEscape(c, mode):
+ t[j] = '%'
+ t[j+1] = upperhex[c>>4]
+ t[j+2] = upperhex[c&15]
+ j += 3
+ default:
+ t[j] = s[i]
+ j++
+ }
+ }
+ return string(t)
+}
+
+// A URL represents a parsed URL (technically, a URI reference).
+//
+// The general form represented is:
+//
+// [scheme:][//[userinfo@]host][/]path[?query][#fragment]
+//
+// URLs that do not start with a slash after the scheme are interpreted as:
+//
+// scheme:opaque[?query][#fragment]
+//
+// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/.
+// A consequence is that it is impossible to tell which slashes in the Path were
+// slashes in the raw URL and which were %2f. This distinction is rarely important,
+// but when it is, the code should use the EscapedPath method, which preserves
+// the original encoding of Path.
+//
+// The RawPath field is an optional field which is only set when the default
+// encoding of Path is different from the escaped path. See the EscapedPath method
+// for more details.
+//
+// URL's String method uses the EscapedPath method to obtain the path.
+type URL struct {
+ Scheme string
+ Opaque string // encoded opaque data
+ User *Userinfo // username and password information
+ Host string // host or host:port
+ Path string // path (relative paths may omit leading slash)
+ RawPath string // encoded path hint (see EscapedPath method)
+ OmitHost bool // do not emit empty host (authority)
+ ForceQuery bool // append a query ('?') even if RawQuery is empty
+ RawQuery string // encoded query values, without '?'
+ Fragment string // fragment for references, without '#'
+ RawFragment string // encoded fragment hint (see EscapedFragment method)
+}
+
+// User returns a Userinfo containing the provided username
+// and no password set.
+func User(username string) *Userinfo {
+ return &Userinfo{username, "", false}
+}
+
+// UserPassword returns a Userinfo containing the provided username
+// and password.
+//
+// This functionality should only be used with legacy web sites.
+// RFC 2396 warns that interpreting Userinfo this way
+// “is NOT RECOMMENDED, because the passing of authentication
+// information in clear text (such as URI) has proven to be a
+// security risk in almost every case where it has been used.”
+func UserPassword(username, password string) *Userinfo {
+ return &Userinfo{username, password, true}
+}
+
+// The Userinfo type is an immutable encapsulation of username and
+// password details for a URL. An existing Userinfo value is guaranteed
+// to have a username set (potentially empty, as allowed by RFC 2396),
+// and optionally a password.
+type Userinfo struct {
+ username string
+ password string
+ passwordSet bool
+}
+
+// Username returns the username.
+func (u *Userinfo) Username() string {
+ if u == nil {
+ return ""
+ }
+ return u.username
+}
+
+// Password returns the password, if set, and reports whether it is set.
+func (u *Userinfo) Password() (string, bool) {
+ if u == nil {
+ return "", false
+ }
+ return u.password, u.passwordSet
+}
+
+// String returns the encoded userinfo information in the standard form
+// of "username[:password]".
+func (u *Userinfo) String() string {
+ if u == nil {
+ return ""
+ }
+ s := escape(u.username, encodeUserPassword)
+ if u.passwordSet {
+ s += ":" + escape(u.password, encodeUserPassword)
+ }
+ return s
+}
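+
+// userinfoDemo is a hypothetical helper added for illustration only; it
+// sketches how Userinfo values render via String: reserved characters such
+// as '@' are percent-encoded, and the password is emitted only when it was
+// explicitly set via UserPassword.
+func userinfoDemo() [2]string {
+ return [2]string{
+ User("gopher").String(), // "gopher"
+ UserPassword("gopher", "p@ss").String(), // "gopher:p%40ss"
+ }
+}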
+
+// Maybe rawURL is of the form scheme:path.
+// (Scheme must be [a-zA-Z][a-zA-Z0-9+.-]*)
+// If so, return scheme, path; else return "", rawURL.
+func getScheme(rawURL string) (scheme, path string, err error) {
+ for i := 0; i < len(rawURL); i++ {
+ c := rawURL[i]
+ switch {
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z':
+ // do nothing
+ case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.':
+ if i == 0 {
+ return "", rawURL, nil
+ }
+ case c == ':':
+ if i == 0 {
+ return "", "", errors.New("missing protocol scheme")
+ }
+ return rawURL[:i], rawURL[i+1:], nil
+ default:
+ // we have encountered an invalid character,
+ // so there is no valid scheme
+ return "", rawURL, nil
+ }
+ }
+ return "", rawURL, nil
+}
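+
+// getSchemeDemo is an editor-added sketch, not part of the original file:
+// a leading [a-zA-Z][a-zA-Z0-9+.-]* run terminated by ':' is treated as a
+// scheme; any other input is returned unchanged with an empty scheme.
+func getSchemeDemo() []string {
+ s1, r1, _ := getScheme("mailto:gopher@example.com") // "mailto", "gopher@example.com"
+ s2, r2, _ := getScheme("/just/a/path") // "", "/just/a/path"
+ return []string{s1, r1, s2, r2}
+}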
+
+// Parse parses a raw url into a URL structure.
+//
+// The url may be relative (a path, without a host) or absolute
+// (starting with a scheme). Trying to parse a hostname and path
+// without a scheme is invalid but may not necessarily return an
+// error, due to parsing ambiguities.
+func Parse(rawURL string) (*URL, error) {
+ // Cut off #frag
+ u, frag, _ := strings.Cut(rawURL, "#")
+ url, err := parse(u, false)
+ if err != nil {
+ return nil, &Error{"parse", u, err}
+ }
+ if frag == "" {
+ return url, nil
+ }
+ if err = url.setFragment(frag); err != nil {
+ return nil, &Error{"parse", rawURL, err}
+ }
+ return url, nil
+}
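+
+// parseDemo is an illustrative sketch (not part of the original file) of how
+// Parse splits a raw URL into fields; the example URL is made up.
+func parseDemo() (*URL, error) {
+ u, err := Parse("https://gopher@example.com/a%20b?q=go#frag")
+ if err != nil {
+ return nil, err
+ }
+ // u.Scheme == "https", u.User.Username() == "gopher", u.Host == "example.com",
+ // u.Path == "/a b" (decoded), u.RawQuery == "q=go", u.Fragment == "frag".
+ return u, nil
+}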
+
+// ParseRequestURI parses a raw url into a URL structure. It assumes that
+// url was received in an HTTP request, so the url is interpreted
+// only as an absolute URI or an absolute path.
+// The string url is assumed not to have a #fragment suffix.
+// (Web browsers strip #fragment before sending the URL to a web server.)
+func ParseRequestURI(rawURL string) (*URL, error) {
+ url, err := parse(rawURL, true)
+ if err != nil {
+ return nil, &Error{"parse", rawURL, err}
+ }
+ return url, nil
+}
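+
+// parseRequestURIDemo is an editor-added sketch of the stricter contract of
+// ParseRequestURI: absolute URIs and absolute paths are accepted, while a
+// bare relative path is rejected.
+func parseRequestURIDemo() error {
+ if _, err := ParseRequestURI("/index.html"); err != nil { // accepted: absolute path
+ return err
+ }
+ if _, err := ParseRequestURI("http://example.com/"); err != nil { // accepted: absolute URI
+ return err
+ }
+ _, err := ParseRequestURI("index.html") // rejected: "invalid URI for request"
+ return err
+}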
+
+// parse parses a URL from a string in one of two contexts. If
+// viaRequest is true, the URL is assumed to have arrived via an HTTP request,
+// in which case only absolute URLs or path-absolute relative URLs are allowed.
+// If viaRequest is false, all forms of relative URLs are allowed.
+func parse(rawURL string, viaRequest bool) (*URL, error) {
+ var rest string
+ var err error
+
+ if stringContainsCTLByte(rawURL) {
+ return nil, errors.New("net/url: invalid control character in URL")
+ }
+
+ if rawURL == "" && viaRequest {
+ return nil, errors.New("empty url")
+ }
+ url := new(URL)
+
+ if rawURL == "*" {
+ url.Path = "*"
+ return url, nil
+ }
+
+ // Split off possible leading "http:", "mailto:", etc.
+ // Cannot contain escaped characters.
+ if url.Scheme, rest, err = getScheme(rawURL); err != nil {
+ return nil, err
+ }
+ url.Scheme = strings.ToLower(url.Scheme)
+
+ if strings.HasSuffix(rest, "?") && strings.Count(rest, "?") == 1 {
+ url.ForceQuery = true
+ rest = rest[:len(rest)-1]
+ } else {
+ rest, url.RawQuery, _ = strings.Cut(rest, "?")
+ }
+
+ if !strings.HasPrefix(rest, "/") {
+ if url.Scheme != "" {
+ // We consider rootless paths per RFC 3986 as opaque.
+ url.Opaque = rest
+ return url, nil
+ }
+ if viaRequest {
+ return nil, errors.New("invalid URI for request")
+ }
+
+ // Avoid confusion with malformed schemes, like cache_object:foo/bar.
+ // See golang.org/issue/16822.
+ //
+ // RFC 3986, §3.3:
+ // In addition, a URI reference (Section 4.1) may be a relative-path reference,
+ // in which case the first path segment cannot contain a colon (":") character.
+ if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") {
+ // First path segment has colon. Not allowed in relative URL.
+ return nil, errors.New("first path segment in URL cannot contain colon")
+ }
+ }
+
+ if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") {
+ var authority string
+ authority, rest = rest[2:], ""
+ if i := strings.Index(authority, "/"); i >= 0 {
+ authority, rest = authority[:i], authority[i:]
+ }
+ url.User, url.Host, err = parseAuthority(authority)
+ if err != nil {
+ return nil, err
+ }
+ } else if url.Scheme != "" && strings.HasPrefix(rest, "/") {
+ // OmitHost is set to true when rawURL has an empty host (authority).
+ // See golang.org/issue/46059.
+ url.OmitHost = true
+ }
+
+ // Set Path and, optionally, RawPath.
+ // RawPath is a hint of the encoding of Path. We don't want to set it if
+ // the default escaping of Path is equivalent, to help make sure that people
+ // don't rely on it in general.
+ if err := url.setPath(rest); err != nil {
+ return nil, err
+ }
+ return url, nil
+}
+
+func parseAuthority(authority string) (user *Userinfo, host string, err error) {
+ i := strings.LastIndex(authority, "@")
+ if i < 0 {
+ host, err = parseHost(authority)
+ } else {
+ host, err = parseHost(authority[i+1:])
+ }
+ if err != nil {
+ return nil, "", err
+ }
+ if i < 0 {
+ return nil, host, nil
+ }
+ userinfo := authority[:i]
+ if !validUserinfo(userinfo) {
+ return nil, "", errors.New("net/url: invalid userinfo")
+ }
+ if !strings.Contains(userinfo, ":") {
+ if userinfo, err = unescape(userinfo, encodeUserPassword); err != nil {
+ return nil, "", err
+ }
+ user = User(userinfo)
+ } else {
+ username, password, _ := strings.Cut(userinfo, ":")
+ if username, err = unescape(username, encodeUserPassword); err != nil {
+ return nil, "", err
+ }
+ if password, err = unescape(password, encodeUserPassword); err != nil {
+ return nil, "", err
+ }
+ user = UserPassword(username, password)
+ }
+ return user, host, nil
+}
+
+// parseHost parses host as an authority without user
+// information. That is, as host[:port].
+func parseHost(host string) (string, error) {
+ if strings.HasPrefix(host, "[") {
+ // Parse an IP-Literal in RFC 3986 and RFC 6874.
+ // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80".
+ i := strings.LastIndex(host, "]")
+ if i < 0 {
+ return "", errors.New("missing ']' in host")
+ }
+ colonPort := host[i+1:]
+ if !validOptionalPort(colonPort) {
+ return "", fmt.Errorf("invalid port %q after host", colonPort)
+ }
+
+ // RFC 6874 defines that %25 (%-encoded percent) introduces
+ // the zone identifier, and the zone identifier can use basically
+ // any %-encoding it likes. That's different from the host, which
+ // can only %-encode non-ASCII bytes.
+ // We do impose some restrictions on the zone, to avoid stupidity
+ // like newlines.
+ zone := strings.Index(host[:i], "%25")
+ if zone >= 0 {
+ host1, err := unescape(host[:zone], encodeHost)
+ if err != nil {
+ return "", err
+ }
+ host2, err := unescape(host[zone:i], encodeZone)
+ if err != nil {
+ return "", err
+ }
+ host3, err := unescape(host[i:], encodeHost)
+ if err != nil {
+ return "", err
+ }
+ return host1 + host2 + host3, nil
+ }
+ } else if i := strings.LastIndex(host, ":"); i != -1 {
+ colonPort := host[i:]
+ if !validOptionalPort(colonPort) {
+ return "", fmt.Errorf("invalid port %q after host", colonPort)
+ }
+ }
+
+ var err error
+ if host, err = unescape(host, encodeHost); err != nil {
+ return "", err
+ }
+ return host, nil
+}
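+
+// parseHostDemo is an editor-added sketch of the host forms parseHost
+// accepts: a plain host:port, a bracketed IPv6 literal, and an RFC 6874
+// zone identifier introduced by "%25".
+func parseHostDemo() []string {
+ var hosts []string
+ for _, in := range []string{"example.com:80", "[fe80::1]:80", "[fe80::1%25en0]"} {
+ h, err := parseHost(in)
+ if err != nil {
+ continue
+ }
+ hosts = append(hosts, h) // "example.com:80", "[fe80::1]:80", "[fe80::1%en0]"
+ }
+ return hosts
+}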
+
+// setPath sets the Path and RawPath fields of the URL based on the provided
+// escaped path p. It maintains the invariant that RawPath is only specified
+// when it differs from the default encoding of the path.
+// For example:
+// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath=""
+// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar"
+// setPath will return an error only if the provided path contains an invalid
+// escaping.
+func (u *URL) setPath(p string) error {
+ path, err := unescape(p, encodePath)
+ if err != nil {
+ return err
+ }
+ u.Path = path
+ if escp := escape(path, encodePath); p == escp {
+ // Default encoding is fine.
+ u.RawPath = ""
+ } else {
+ u.RawPath = p
+ }
+ return nil
+}
+
+// EscapedPath returns the escaped form of u.Path.
+// In general there are multiple possible escaped forms of any path.
+// EscapedPath returns u.RawPath when it is a valid escaping of u.Path.
+// Otherwise EscapedPath ignores u.RawPath and computes an escaped
+// form on its own.
+// The String and RequestURI methods use EscapedPath to construct
+// their results.
+// In general, code should call EscapedPath instead of
+// reading u.RawPath directly.
+func (u *URL) EscapedPath() string {
+ if u.RawPath != "" && validEncoded(u.RawPath, encodePath) {
+ p, err := unescape(u.RawPath, encodePath)
+ if err == nil && p == u.Path {
+ return u.RawPath
+ }
+ }
+ if u.Path == "*" {
+ return "*" // don't escape (Issue 11202)
+ }
+ return escape(u.Path, encodePath)
+}
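+
+// escapedPathDemo is an illustrative, editor-added sketch: when a raw URL
+// uses a non-default escaping such as %2f for '/', Path holds the decoded
+// form while EscapedPath (via RawPath) preserves the original spelling.
+func escapedPathDemo() (decoded, escaped string) {
+ u, err := Parse("http://example.com/foo%2fbar")
+ if err != nil {
+ return "", ""
+ }
+ return u.Path, u.EscapedPath() // "/foo/bar", "/foo%2fbar"
+}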
+
+// validEncoded reports whether s is a valid encoded path or fragment,
+// according to mode.
+// It must not contain any bytes that require escaping during encoding.
+func validEncoded(s string, mode encoding) bool {
+ for i := 0; i < len(s); i++ {
+ // RFC 3986, Appendix A.
+ // pchar = unreserved / pct-encoded / sub-delims / ":" / "@".
+ // shouldEscape is not quite compliant with the RFC,
+ // so we check the sub-delims ourselves and let
+ // shouldEscape handle the others.
+ switch s[i] {
+ case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '@':
+ // ok
+ case '[', ']':
+ // ok - not specified in RFC 3986 but left alone by modern browsers
+ case '%':
+ // ok - percent encoded, will decode
+ default:
+ if shouldEscape(s[i], mode) {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// setFragment is like setPath but for Fragment/RawFragment.
+func (u *URL) setFragment(f string) error {
+ frag, err := unescape(f, encodeFragment)
+ if err != nil {
+ return err
+ }
+ u.Fragment = frag
+ if escf := escape(frag, encodeFragment); f == escf {
+ // Default encoding is fine.
+ u.RawFragment = ""
+ } else {
+ u.RawFragment = f
+ }
+ return nil
+}
+
+// EscapedFragment returns the escaped form of u.Fragment.
+// In general there are multiple possible escaped forms of any fragment.
+// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment.
+// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped
+// form on its own.
+// The String method uses EscapedFragment to construct its result.
+// In general, code should call EscapedFragment instead of
+// reading u.RawFragment directly.
+func (u *URL) EscapedFragment() string {
+ if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) {
+ f, err := unescape(u.RawFragment, encodeFragment)
+ if err == nil && f == u.Fragment {
+ return u.RawFragment
+ }
+ }
+ return escape(u.Fragment, encodeFragment)
+}
+
+// validOptionalPort reports whether port is either an empty string
+// or matches /^:\d*$/
+func validOptionalPort(port string) bool {
+ if port == "" {
+ return true
+ }
+ if port[0] != ':' {
+ return false
+ }
+ for _, b := range port[1:] {
+ if b < '0' || b > '9' {
+ return false
+ }
+ }
+ return true
+}
+
+// String reassembles the URL into a valid URL string.
+// The general form of the result is one of:
+//
+// scheme:opaque?query#fragment
+// scheme://userinfo@host/path?query#fragment
+//
+// If u.Opaque is non-empty, String uses the first form;
+// otherwise it uses the second form.
+// Any non-ASCII characters in host are escaped.
+// To obtain the path, String uses u.EscapedPath().
+//
+// In the second form, the following rules apply:
+// - if u.Scheme is empty, scheme: is omitted.
+// - if u.User is nil, userinfo@ is omitted.
+// - if u.Host is empty, host/ is omitted.
+// - if u.Scheme and u.Host are empty and u.User is nil,
+// the entire scheme://userinfo@host/ is omitted.
+// - if u.Host is non-empty and u.Path begins with a /,
+// the form host/path does not add its own /.
+// - if u.RawQuery is empty, ?query is omitted.
+// - if u.Fragment is empty, #fragment is omitted.
+func (u *URL) String() string {
+ var buf strings.Builder
+ if u.Scheme != "" {
+ buf.WriteString(u.Scheme)
+ buf.WriteByte(':')
+ }
+ if u.Opaque != "" {
+ buf.WriteString(u.Opaque)
+ } else {
+ if u.Scheme != "" || u.Host != "" || u.User != nil {
+ if u.OmitHost && u.Host == "" && u.User == nil {
+ // omit empty host
+ } else {
+ if u.Host != "" || u.Path != "" || u.User != nil {
+ buf.WriteString("//")
+ }
+ if ui := u.User; ui != nil {
+ buf.WriteString(ui.String())
+ buf.WriteByte('@')
+ }
+ if h := u.Host; h != "" {
+ buf.WriteString(escape(h, encodeHost))
+ }
+ }
+ }
+ path := u.EscapedPath()
+ if path != "" && path[0] != '/' && u.Host != "" {
+ buf.WriteByte('/')
+ }
+ if buf.Len() == 0 {
+ // RFC 3986 §4.2
+ // A path segment that contains a colon character (e.g., "this:that")
+ // cannot be used as the first segment of a relative-path reference, as
+ // it would be mistaken for a scheme name. Such a segment must be
+ // preceded by a dot-segment (e.g., "./this:that") to make a relative-
+ // path reference.
+ if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") {
+ buf.WriteString("./")
+ }
+ }
+ buf.WriteString(path)
+ }
+ if u.ForceQuery || u.RawQuery != "" {
+ buf.WriteByte('?')
+ buf.WriteString(u.RawQuery)
+ }
+ if u.Fragment != "" {
+ buf.WriteByte('#')
+ buf.WriteString(u.EscapedFragment())
+ }
+ return buf.String()
+}
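+
+// stringDemo is an editor-added sketch of two behaviors documented above:
+// the path is written via EscapedPath, and a relative first segment that
+// contains ':' is prefixed with "./" so it cannot be mistaken for a scheme.
+func stringDemo() [2]string {
+ return [2]string{
+ (&URL{Scheme: "https", Host: "example.com", Path: "/a b"}).String(), // "https://example.com/a%20b"
+ (&URL{Path: "this:that"}).String(), // "./this:that"
+ }
+}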
+
+// Redacted is like String but replaces any password with "xxxxx".
+// Only the password in u.User is redacted.
+func (u *URL) Redacted() string {
+ if u == nil {
+ return ""
+ }
+
+ ru := *u
+ if _, has := ru.User.Password(); has {
+ ru.User = UserPassword(ru.User.Username(), "xxxxx")
+ }
+ return ru.String()
+}
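+
+// redactedDemo is an editor-added sketch: only the password portion of the
+// userinfo is replaced; everything else is rendered as by String.
+func redactedDemo() string {
+ u := &URL{Scheme: "https", Host: "example.com", User: UserPassword("gopher", "hunter2")}
+ return u.Redacted() // "https://gopher:xxxxx@example.com"
+}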
+
+// Values maps a string key to a list of values.
+// It is typically used for query parameters and form values.
+// Unlike in the http.Header map, the keys in a Values map
+// are case-sensitive.
+type Values map[string][]string
+
+// Get gets the first value associated with the given key.
+// If there are no values associated with the key, Get returns
+// the empty string. To access multiple values, use the map
+// directly.
+func (v Values) Get(key string) string {
+ vs := v[key]
+ if len(vs) == 0 {
+ return ""
+ }
+ return vs[0]
+}
+
+// Set sets the key to value. It replaces any existing
+// values.
+func (v Values) Set(key, value string) {
+ v[key] = []string{value}
+}
+
+// Add adds the value to key. It appends to any existing
+// values associated with key.
+func (v Values) Add(key, value string) {
+ v[key] = append(v[key], value)
+}
+
+// Del deletes the values associated with key.
+func (v Values) Del(key string) {
+ delete(v, key)
+}
+
+// Has checks whether a given key is set.
+func (v Values) Has(key string) bool {
+ _, ok := v[key]
+ return ok
+}
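+
+// valuesDemo is an editor-added sketch of the basic Values operations.
+func valuesDemo() Values {
+ v := Values{}
+ v.Set("q", "gopher") // replaces any existing values of "q"
+ v.Add("q", "go") // appends; "q" now holds ["gopher", "go"]
+ v.Del("unused") // deleting an absent key is a no-op
+ _ = v.Get("q") // "gopher" (first value only)
+ _ = v.Has("q") // true
+ return v
+}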
+
+// ParseQuery parses the URL-encoded query string and returns
+// a map listing the values specified for each key.
+// ParseQuery always returns a non-nil map containing all the
+// valid query parameters found; err describes the first decoding error
+// encountered, if any.
+//
+// Query is expected to be a list of key=value settings separated by ampersands.
+// A setting without an equals sign is interpreted as a key set to an empty
+// value.
+// Settings containing a non-URL-encoded semicolon are considered invalid.
+func ParseQuery(query string) (Values, error) {
+ m := make(Values)
+ err := parseQuery(m, query)
+ return m, err
+}
+
+func parseQuery(m Values, query string) (err error) {
+ for query != "" {
+ var key string
+ key, query, _ = strings.Cut(query, "&")
+ if strings.Contains(key, ";") {
+ err = fmt.Errorf("invalid semicolon separator in query")
+ continue
+ }
+ if key == "" {
+ continue
+ }
+ key, value, _ := strings.Cut(key, "=")
+ key, err1 := QueryUnescape(key)
+ if err1 != nil {
+ if err == nil {
+ err = err1
+ }
+ continue
+ }
+ value, err1 = QueryUnescape(value)
+ if err1 != nil {
+ if err == nil {
+ err = err1
+ }
+ continue
+ }
+ m[key] = append(m[key], value)
+ }
+ return err
+}
+
+// Encode encodes the values into “URL encoded” form
+// ("bar=baz&foo=quux") sorted by key.
+func (v Values) Encode() string {
+ if v == nil {
+ return ""
+ }
+ var buf strings.Builder
+ keys := make([]string, 0, len(v))
+ for k := range v {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ for _, k := range keys {
+ vs := v[k]
+ keyEscaped := QueryEscape(k)
+ for _, v := range vs {
+ if buf.Len() > 0 {
+ buf.WriteByte('&')
+ }
+ buf.WriteString(keyEscaped)
+ buf.WriteByte('=')
+ buf.WriteString(QueryEscape(v))
+ }
+ }
+ return buf.String()
+}
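+
+// encodeDemo is an editor-added sketch: Encode writes the pairs sorted by
+// key, so the output order is deterministic regardless of input order.
+func encodeDemo() (string, error) {
+ v, err := ParseQuery("b=2&a=1&a=3")
+ if err != nil {
+ return "", err
+ }
+ return v.Encode(), nil // "a=1&a=3&b=2"
+}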
+
+// resolvePath applies the special path segments ("." and "..") from ref
+// to base and returns the resolved path, per RFC 3986.
+func resolvePath(base, ref string) string {
+ var full string
+ if ref == "" {
+ full = base
+ } else if ref[0] != '/' {
+ i := strings.LastIndex(base, "/")
+ full = base[:i+1] + ref
+ } else {
+ full = ref
+ }
+ if full == "" {
+ return ""
+ }
+
+ var (
+ elem string
+ dst strings.Builder
+ )
+ first := true
+ remaining := full
+ // We want to return a leading '/', so write it now.
+ dst.WriteByte('/')
+ found := true
+ for found {
+ elem, remaining, found = strings.Cut(remaining, "/")
+ if elem == "." {
+ first = false
+ // drop
+ continue
+ }
+
+ if elem == ".." {
+ // Ignore the leading '/' we already wrote.
+ str := dst.String()[1:]
+ index := strings.LastIndexByte(str, '/')
+
+ dst.Reset()
+ dst.WriteByte('/')
+ if index == -1 {
+ first = true
+ } else {
+ dst.WriteString(str[:index])
+ }
+ } else {
+ if !first {
+ dst.WriteByte('/')
+ }
+ dst.WriteString(elem)
+ first = false
+ }
+ }
+
+ if elem == "." || elem == ".." {
+ dst.WriteByte('/')
+ }
+
+ // We wrote an initial '/', but we don't want two.
+ r := dst.String()
+ if len(r) > 1 && r[1] == '/' {
+ r = r[1:]
+ }
+ return r
+}
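+
+// resolvePathDemo is an editor-added sketch: "." and ".." segments in ref
+// are resolved against the directory of base, and the result always begins
+// with a single '/'.
+func resolvePathDemo() []string {
+ return []string{
+ resolvePath("/a/b/c", "d"), // "/a/b/d"
+ resolvePath("/a/b/c", "../d"), // "/a/d"
+ resolvePath("/a/b/c", "."), // "/a/b/"
+ }
+}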
+
+// IsAbs reports whether the URL is absolute.
+// Absolute means that it has a non-empty scheme.
+func (u *URL) IsAbs() bool {
+ return u.Scheme != ""
+}
+
+// Parse parses a URL in the context of the receiver. The provided URL
+// may be relative or absolute. Parse returns nil, err on parse
+// failure, otherwise its return value is the same as ResolveReference.
+func (u *URL) Parse(ref string) (*URL, error) {
+ refURL, err := Parse(ref)
+ if err != nil {
+ return nil, err
+ }
+ return u.ResolveReference(refURL), nil
+}
+
+// ResolveReference resolves a URI reference to an absolute URI from
+// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference
+// may be relative or absolute. ResolveReference always returns a new
+// URL instance, even if the returned URL is identical to either the
+// base or reference. If ref is an absolute URL, then ResolveReference
+// ignores base and returns a copy of ref.
+func (u *URL) ResolveReference(ref *URL) *URL {
+ url := *ref
+ if ref.Scheme == "" {
+ url.Scheme = u.Scheme
+ }
+ if ref.Scheme != "" || ref.Host != "" || ref.User != nil {
+ // The "absoluteURI" or "net_path" cases.
+ // We can ignore the error from setPath since we know we provided a
+ // validly-escaped path.
+ url.setPath(resolvePath(ref.EscapedPath(), ""))
+ return &url
+ }
+ if ref.Opaque != "" {
+ url.User = nil
+ url.Host = ""
+ url.Path = ""
+ return &url
+ }
+ if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" {
+ url.RawQuery = u.RawQuery
+ if ref.Fragment == "" {
+ url.Fragment = u.Fragment
+ url.RawFragment = u.RawFragment
+ }
+ }
+ // The "abs_path" or "rel_path" cases.
+ url.Host = u.Host
+ url.User = u.User
+ url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath()))
+ return &url
+}
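+
+// resolveReferenceDemo is an editor-added sketch: a relative reference is
+// resolved against the base per RFC 3986, while an absolute reference
+// replaces the base entirely.
+func resolveReferenceDemo() []string {
+ base, err := Parse("http://example.com/a/b?q=1")
+ if err != nil {
+ return nil
+ }
+ rel, _ := Parse("../c")
+ abs, _ := Parse("https://other.example/x")
+ return []string{
+ base.ResolveReference(rel).String(), // "http://example.com/c"
+ base.ResolveReference(abs).String(), // "https://other.example/x"
+ }
+}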
+
+// Query parses RawQuery and returns the corresponding values.
+// It silently discards malformed value pairs.
+// To check errors use ParseQuery.
+func (u *URL) Query() Values {
+ v, _ := ParseQuery(u.RawQuery)
+ return v
+}
+
+// RequestURI returns the encoded path?query or opaque?query
+// string that would be used in an HTTP request for u.
+func (u *URL) RequestURI() string {
+ result := u.Opaque
+ if result == "" {
+ result = u.EscapedPath()
+ if result == "" {
+ result = "/"
+ }
+ } else {
+ if strings.HasPrefix(result, "//") {
+ result = u.Scheme + ":" + result
+ }
+ }
+ if u.ForceQuery || u.RawQuery != "" {
+ result += "?" + u.RawQuery
+ }
+ return result
+}
+
+// Hostname returns u.Host, stripping any valid port number if present.
+//
+// If the result is enclosed in square brackets, as literal IPv6 addresses are,
+// the square brackets are removed from the result.
+func (u *URL) Hostname() string {
+ host, _ := splitHostPort(u.Host)
+ return host
+}
+
+// Port returns the port part of u.Host, without the leading colon.
+//
+// If u.Host doesn't contain a valid numeric port, Port returns an empty string.
+func (u *URL) Port() string {
+ _, port := splitHostPort(u.Host)
+ return port
+}
+
+// splitHostPort separates host and port. If the port is not valid, it returns
+// the entire input as host, and it doesn't check the validity of the host.
+// Unlike net.SplitHostPort, it requires ports to be numeric, per RFC 3986.
+func splitHostPort(hostPort string) (host, port string) {
+ host = hostPort
+
+ colon := strings.LastIndexByte(host, ':')
+ if colon != -1 && validOptionalPort(host[colon:]) {
+ host, port = host[:colon], host[colon+1:]
+ }
+
+ if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
+ host = host[1 : len(host)-1]
+ }
+
+ return
+}
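+
+// hostPortDemo is an editor-added sketch: Hostname strips the brackets from
+// an IPv6 literal and Port returns only a valid numeric port.
+func hostPortDemo() [2]string {
+ u := &URL{Host: "[fe80::1]:8080"}
+ return [2]string{u.Hostname(), u.Port()} // "fe80::1", "8080"
+}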
+
+// Marshaling interface implementations.
+// We would like to implement MarshalText/UnmarshalText, but that would change the JSON representation of URLs.
+
+func (u *URL) MarshalBinary() (text []byte, err error) {
+ return []byte(u.String()), nil
+}
+
+func (u *URL) UnmarshalBinary(text []byte) error {
+ u1, err := Parse(string(text))
+ if err != nil {
+ return err
+ }
+ *u = *u1
+ return nil
+}
+
+// JoinPath returns a new URL with the provided path elements joined to
+// any existing path and the resulting path cleaned of any ./ or ../ elements.
+// Any sequences of multiple / characters will be reduced to a single /.
+func (u *URL) JoinPath(elem ...string) *URL {
+ elem = append([]string{u.EscapedPath()}, elem...)
+ var p string
+ if !strings.HasPrefix(elem[0], "/") {
+ // Return a relative path if u is relative,
+ // but ensure that it contains no ../ elements.
+ elem[0] = "/" + elem[0]
+ p = path.Join(elem...)[1:]
+ } else {
+ p = path.Join(elem...)
+ }
+ // path.Join will remove any trailing slashes.
+ // Preserve at least one.
+ if strings.HasSuffix(elem[len(elem)-1], "/") && !strings.HasSuffix(p, "/") {
+ p += "/"
+ }
+ url := *u
+ url.setPath(p)
+ return &url
+}
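+
+// joinPathDemo is an editor-added sketch: elements are joined and cleaned of
+// "./" and "../" segments, and a trailing slash on the last element is kept.
+func joinPathDemo() string {
+ u, err := Parse("https://example.com/api")
+ if err != nil {
+ return ""
+ }
+ return u.JoinPath("v1", "../v2", "users/").String() // "https://example.com/api/v2/users/"
+}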
+
+// validUserinfo reports whether s is a valid userinfo string per RFC 3986
+// Section 3.2.1:
+//
+// userinfo = *( unreserved / pct-encoded / sub-delims / ":" )
+// unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
+// sub-delims = "!" / "$" / "&" / "'" / "(" / ")"
+// / "*" / "+" / "," / ";" / "="
+//
+// It doesn't validate pct-encoded. The caller does that via func unescape.
+func validUserinfo(s string) bool {
+ for _, r := range s {
+ if 'A' <= r && r <= 'Z' {
+ continue
+ }
+ if 'a' <= r && r <= 'z' {
+ continue
+ }
+ if '0' <= r && r <= '9' {
+ continue
+ }
+ switch r {
+ case '-', '.', '_', ':', '~', '!', '$', '&', '\'',
+ '(', ')', '*', '+', ',', ';', '=', '%', '@':
+ continue
+ default:
+ return false
+ }
+ }
+ return true
+}
+
+// stringContainsCTLByte reports whether s contains any ASCII control character.
+func stringContainsCTLByte(s string) bool {
+ for i := 0; i < len(s); i++ {
+ b := s[i]
+ if b < ' ' || b == 0x7f {
+ return true
+ }
+ }
+ return false
+}
+
+// JoinPath returns a URL string with the provided path elements joined to
+// the existing path of base and the resulting path cleaned of any ./ or ../ elements.
+func JoinPath(base string, elem ...string) (result string, err error) {
+ url, err := Parse(base)
+ if err != nil {
+ return
+ }
+ result = url.JoinPath(elem...).String()
+ return
+}
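+
+// joinPathFuncDemo is an editor-added sketch: the package-level convenience
+// wrapper parses base and then delegates to URL.JoinPath.
+func joinPathFuncDemo() (string, error) {
+ return JoinPath("https://example.com", "a", "b") // "https://example.com/a/b", <nil>
+}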
diff --git a/src/net/url/url_test.go b/src/net/url/url_test.go
new file mode 100644
index 0000000..23c5c58
--- /dev/null
+++ b/src/net/url/url_test.go
@@ -0,0 +1,2210 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package url
+
+import (
+ "bytes"
+ encodingPkg "encoding"
+ "encoding/gob"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+type URLTest struct {
+ in string
+ out *URL // expected parse
+ roundtrip string // expected result of reserializing the URL; empty means same as "in".
+}
+
+var urltests = []URLTest{
+ // no path
+ {
+ "http://www.google.com",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ },
+ "",
+ },
+ // path
+ {
+ "http://www.google.com/",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "",
+ },
+ // path with hex escaping
+ {
+ "http://www.google.com/file%20one%26two",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/file one&two",
+ RawPath: "/file%20one%26two",
+ },
+ "",
+ },
+ // fragment with hex escaping
+ {
+ "http://www.google.com/#file%20one%26two",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ Fragment: "file one&two",
+ RawFragment: "file%20one%26two",
+ },
+ "",
+ },
+ // user
+ {
+ "ftp://webmaster@www.google.com/",
+ &URL{
+ Scheme: "ftp",
+ User: User("webmaster"),
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "",
+ },
+ // escape sequence in username
+ {
+ "ftp://john%20doe@www.google.com/",
+ &URL{
+ Scheme: "ftp",
+ User: User("john doe"),
+ Host: "www.google.com",
+ Path: "/",
+ },
+ "ftp://john%20doe@www.google.com/",
+ },
+ // empty query
+ {
+ "http://www.google.com/?",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ ForceQuery: true,
+ },
+ "",
+ },
+ // query ending in question mark (Issue 14573)
+ {
+ "http://www.google.com/?foo=bar?",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "foo=bar?",
+ },
+ "",
+ },
+ // query
+ {
+ "http://www.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ },
+ "",
+ },
+ // query with hex escaping: NOT parsed
+ {
+ "http://www.google.com/?q=go%20language",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go%20language",
+ },
+ "",
+ },
+ // %20 outside query
+ {
+ "http://www.google.com/a%20b?q=c+d",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/a b",
+ RawQuery: "q=c+d",
+ },
+ "",
+ },
+ // path without leading /, so no parsing
+ {
+ "http:www.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Opaque: "www.google.com/",
+ RawQuery: "q=go+language",
+ },
+ "http:www.google.com/?q=go+language",
+ },
+ // path without leading /, so no parsing
+ {
+ "http:%2f%2fwww.google.com/?q=go+language",
+ &URL{
+ Scheme: "http",
+ Opaque: "%2f%2fwww.google.com/",
+ RawQuery: "q=go+language",
+ },
+ "http:%2f%2fwww.google.com/?q=go+language",
+ },
+ // non-authority with path; see golang.org/issue/46059
+ {
+ "mailto:/webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Path: "/webmaster@golang.org",
+ OmitHost: true,
+ },
+ "",
+ },
+ // non-authority
+ {
+ "mailto:webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Opaque: "webmaster@golang.org",
+ },
+ "",
+ },
+ // unescaped :// in query should not create a scheme
+ {
+ "/foo?query=http://bad",
+ &URL{
+ Path: "/foo",
+ RawQuery: "query=http://bad",
+ },
+ "",
+ },
+ // leading // without scheme should create an authority
+ {
+ "//foo",
+ &URL{
+ Host: "foo",
+ },
+ "",
+ },
+ // leading // without scheme, with userinfo, path, and query
+ {
+ "//user@foo/path?a=b",
+ &URL{
+ User: User("user"),
+ Host: "foo",
+ Path: "/path",
+ RawQuery: "a=b",
+ },
+ "",
+ },
+ // Three leading slashes don't form an authority, but parsing doesn't return an error.
+ // (We can't return an error, as this code is also used via
+ // ServeHTTP -> ReadRequest -> Parse, which is arguably a
+ // different URL parsing context, but currently shares the
+ // same codepath)
+ {
+ "///threeslashes",
+ &URL{
+ Path: "///threeslashes",
+ },
+ "",
+ },
+ {
+ "http://user:password@google.com",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("user", "password"),
+ Host: "google.com",
+ },
+ "http://user:password@google.com",
+ },
+ // unescaped @ in username should not confuse host
+ {
+ "http://j@ne:password@google.com",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("j@ne", "password"),
+ Host: "google.com",
+ },
+ "http://j%40ne:password@google.com",
+ },
+ // unescaped @ in password should not confuse host
+ {
+ "http://jane:p@ssword@google.com",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("jane", "p@ssword"),
+ Host: "google.com",
+ },
+ "http://jane:p%40ssword@google.com",
+ },
+ {
+ "http://j@ne:password@google.com/p@th?q=@go",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("j@ne", "password"),
+ Host: "google.com",
+ Path: "/p@th",
+ RawQuery: "q=@go",
+ },
+ "http://j%40ne:password@google.com/p@th?q=@go",
+ },
+ {
+ "http://www.google.com/?q=go+language#foo",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ Fragment: "foo",
+ },
+ "",
+ },
+ {
+ "http://www.google.com/?q=go+language#foo&bar",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ Fragment: "foo&bar",
+ },
+ "http://www.google.com/?q=go+language#foo&bar",
+ },
+ {
+ "http://www.google.com/?q=go+language#foo%26bar",
+ &URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/",
+ RawQuery: "q=go+language",
+ Fragment: "foo&bar",
+ RawFragment: "foo%26bar",
+ },
+ "http://www.google.com/?q=go+language#foo%26bar",
+ },
+ {
+ "file:///home/adg/rabbits",
+ &URL{
+ Scheme: "file",
+ Host: "",
+ Path: "/home/adg/rabbits",
+ },
+ "file:///home/adg/rabbits",
+ },
+ // "Windows" paths are no exception to the rule.
+ // See golang.org/issue/6027, especially comment #9.
+ {
+ "file:///C:/FooBar/Baz.txt",
+ &URL{
+ Scheme: "file",
+ Host: "",
+ Path: "/C:/FooBar/Baz.txt",
+ },
+ "file:///C:/FooBar/Baz.txt",
+ },
+ // case-insensitive scheme
+ {
+ "MaIlTo:webmaster@golang.org",
+ &URL{
+ Scheme: "mailto",
+ Opaque: "webmaster@golang.org",
+ },
+ "mailto:webmaster@golang.org",
+ },
+ // Relative path
+ {
+ "a/b/c",
+ &URL{
+ Path: "a/b/c",
+ },
+ "a/b/c",
+ },
+ // escaped '?' in username and password
+ {
+ "http://%3Fam:pa%3Fsword@google.com",
+ &URL{
+ Scheme: "http",
+ User: UserPassword("?am", "pa?sword"),
+ Host: "google.com",
+ },
+ "",
+ },
+ // host subcomponent; IPv4 address in RFC 3986
+ {
+ "http://192.168.0.1/",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.1",
+ Path: "/",
+ },
+ "",
+ },
+ // host and port subcomponents; IPv4 address in RFC 3986
+ {
+ "http://192.168.0.1:8080/",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.1:8080",
+ Path: "/",
+ },
+ "",
+ },
+ // host subcomponent; IPv6 address in RFC 3986
+ {
+ "http://[fe80::1]/",
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1]",
+ Path: "/",
+ },
+ "",
+ },
+ // host and port subcomponents; IPv6 address in RFC 3986
+ {
+ "http://[fe80::1]:8080/",
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1]:8080",
+ Path: "/",
+ },
+ "",
+ },
+ // host subcomponent; IPv6 address with zone identifier in RFC 6874
+ {
+ "http://[fe80::1%25en0]/", // alphanum zone identifier
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1%en0]",
+ Path: "/",
+ },
+ "",
+ },
+ // host and port subcomponents; IPv6 address with zone identifier in RFC 6874
+ {
+ "http://[fe80::1%25en0]:8080/", // alphanum zone identifier
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1%en0]:8080",
+ Path: "/",
+ },
+ "",
+ },
+ // host subcomponent; IPv6 address with zone identifier in RFC 6874
+ {
+ "http://[fe80::1%25%65%6e%301-._~]/", // percent-encoded+unreserved zone identifier
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1%en01-._~]",
+ Path: "/",
+ },
+ "http://[fe80::1%25en01-._~]/",
+ },
+ // host and port subcomponents; IPv6 address with zone identifier in RFC 6874
+ {
+ "http://[fe80::1%25%65%6e%301-._~]:8080/", // percent-encoded+unreserved zone identifier
+ &URL{
+ Scheme: "http",
+ Host: "[fe80::1%en01-._~]:8080",
+ Path: "/",
+ },
+ "http://[fe80::1%25en01-._~]:8080/",
+ },
+ // alternate escapings of path survive round trip
+ {
+ "http://rest.rsc.io/foo%2fbar/baz%2Fquux?alt=media",
+ &URL{
+ Scheme: "http",
+ Host: "rest.rsc.io",
+ Path: "/foo/bar/baz/quux",
+ RawPath: "/foo%2fbar/baz%2Fquux",
+ RawQuery: "alt=media",
+ },
+ "",
+ },
+ // issue 12036
+ {
+ "mysql://a,b,c/bar",
+ &URL{
+ Scheme: "mysql",
+ Host: "a,b,c",
+ Path: "/bar",
+ },
+ "",
+ },
+ // worst case host, still round trips
+ {
+ "scheme://!$&'()*+,;=hello!:1/path",
+ &URL{
+ Scheme: "scheme",
+ Host: "!$&'()*+,;=hello!:1",
+ Path: "/path",
+ },
+ "",
+ },
+ // worst case path, still round trips
+ {
+ "http://host/!$&'()*+,;=:@[hello]",
+ &URL{
+ Scheme: "http",
+ Host: "host",
+ Path: "/!$&'()*+,;=:@[hello]",
+ RawPath: "/!$&'()*+,;=:@[hello]",
+ },
+ "",
+ },
+ // golang.org/issue/5684
+ {
+ "http://example.com/oid/[order_id]",
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/oid/[order_id]",
+ RawPath: "/oid/[order_id]",
+ },
+ "",
+ },
+ // golang.org/issue/12200 (colon with empty port)
+ {
+ "http://192.168.0.2:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.2:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://192.168.0.2:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.2:",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ // Malformed IPv6 but still accepted.
+ "http://2b01:e34:ef40:7730:8e70:5aff:fefe:edac:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "2b01:e34:ef40:7730:8e70:5aff:fefe:edac:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ // Malformed IPv6 but still accepted.
+ "http://2b01:e34:ef40:7730:8e70:5aff:fefe:edac:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "2b01:e34:ef40:7730:8e70:5aff:fefe:edac:",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:",
+ Path: "/foo",
+ },
+ "",
+ },
+ // golang.org/issue/7991 and golang.org/issue/12719 (non-ascii %-encoded in host)
+ {
+ "http://hello.世界.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ },
+ {
+ "http://hello.%e4%b8%96%e7%95%8c.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ },
+ {
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "",
+ },
+ // golang.org/issue/10433 (path beginning with //)
+ {
+ "http://example.com//foo",
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "//foo",
+ },
+ "",
+ },
+ // test that we can reparse the host names we accept.
+ {
+ "myscheme://authority<\"hi\">/foo",
+ &URL{
+ Scheme: "myscheme",
+ Host: "authority<\"hi\">",
+ Path: "/foo",
+ },
+ "",
+ },
+ // spaces in hosts are disallowed but escaped spaces in IPv6 scope IDs are grudgingly OK.
+ // This happens on Windows.
+ // golang.org/issue/14002
+ {
+ "tcp://[2020::2020:20:2020:2020%25Windows%20Loves%20Spaces]:2020",
+ &URL{
+ Scheme: "tcp",
+ Host: "[2020::2020:20:2020:2020%Windows Loves Spaces]:2020",
+ },
+ "",
+ },
+ // test we can roundtrip magnet url
+ // fix issue https://golang.org/issue/20054
+ {
+ "magnet:?xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn",
+ &URL{
+ Scheme: "magnet",
+ Host: "",
+ Path: "",
+ RawQuery: "xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn",
+ },
+ "magnet:?xt=urn:btih:c12fe1c06bba254a9dc9f519b335aa7c1367a88a&dn",
+ },
+ {
+ "mailto:?subject=hi",
+ &URL{
+ Scheme: "mailto",
+ Host: "",
+ Path: "",
+ RawQuery: "subject=hi",
+ },
+ "mailto:?subject=hi",
+ },
+}
+
+// more useful string for debugging than fmt's struct printer
+func ufmt(u *URL) string {
+ var user, pass any
+ if u.User != nil {
+ user = u.User.Username()
+ if p, ok := u.User.Password(); ok {
+ pass = p
+ }
+ }
+ return fmt.Sprintf("opaque=%q, scheme=%q, user=%#v, pass=%#v, host=%q, path=%q, rawpath=%q, rawq=%q, frag=%q, rawfrag=%q, forcequery=%v, omithost=%t",
+ u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawPath, u.RawQuery, u.Fragment, u.RawFragment, u.ForceQuery, u.OmitHost)
+}
+
+func BenchmarkString(b *testing.B) {
+ b.StopTimer()
+ b.ReportAllocs()
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
+ if err != nil {
+ b.Errorf("Parse(%q) returned error %s", tt.in, err)
+ continue
+ }
+ if tt.roundtrip == "" {
+ continue
+ }
+ b.StartTimer()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g = u.String()
+ }
+ b.StopTimer()
+ if w := tt.roundtrip; b.N > 0 && g != w {
+ b.Errorf("Parse(%q).String() == %q, want %q", tt.in, g, w)
+ }
+ }
+}
+
+func TestParse(t *testing.T) {
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
+ if err != nil {
+ t.Errorf("Parse(%q) returned error %v", tt.in, err)
+ continue
+ }
+ if !reflect.DeepEqual(u, tt.out) {
+ t.Errorf("Parse(%q):\n\tgot %v\n\twant %v\n", tt.in, ufmt(u), ufmt(tt.out))
+ }
+ }
+}
+
+const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path"
+
+var parseRequestURLTests = []struct {
+ url string
+ expectedValid bool
+}{
+ {"http://foo.com", true},
+ {"http://foo.com/", true},
+ {"http://foo.com/path", true},
+ {"/", true},
+ {pathThatLooksSchemeRelative, true},
+ {"//not.a.user@%66%6f%6f.com/just/a/path/also", true},
+ {"*", true},
+ {"http://192.168.0.1/", true},
+ {"http://192.168.0.1:8080/", true},
+ {"http://[fe80::1]/", true},
+ {"http://[fe80::1]:8080/", true},
+
+ // Tests exercising RFC 6874 compliance:
+ {"http://[fe80::1%25en0]/", true}, // with alphanum zone identifier
+ {"http://[fe80::1%25en0]:8080/", true}, // with alphanum zone identifier
+ {"http://[fe80::1%25%65%6e%301-._~]/", true}, // with percent-encoded+unreserved zone identifier
+ {"http://[fe80::1%25%65%6e%301-._~]:8080/", true}, // with percent-encoded+unreserved zone identifier
+
+ {"foo.html", false},
+ {"../dir/", false},
+ {" http://foo.com", false},
+ {"http://192.168.0.%31/", false},
+ {"http://192.168.0.%31:8080/", false},
+ {"http://[fe80::%31]/", false},
+ {"http://[fe80::%31]:8080/", false},
+ {"http://[fe80::%31%25en0]/", false},
+ {"http://[fe80::%31%25en0]:8080/", false},
+
+ // These two cases are valid as textual representations as
+ // described in RFC 4007, but are not valid as address
+ // literals with IPv6 zone identifiers in URIs as described in
+ // RFC 6874.
+ {"http://[fe80::1%en0]/", false},
+ {"http://[fe80::1%en0]:8080/", false},
+}
+
+func TestParseRequestURI(t *testing.T) {
+ for _, test := range parseRequestURLTests {
+ _, err := ParseRequestURI(test.url)
+ if test.expectedValid && err != nil {
+ t.Errorf("ParseRequestURI(%q) gave err %v; want no error", test.url, err)
+ } else if !test.expectedValid && err == nil {
+ t.Errorf("ParseRequestURI(%q) gave nil error; want some error", test.url)
+ }
+ }
+
+ url, err := ParseRequestURI(pathThatLooksSchemeRelative)
+ if err != nil {
+ t.Fatalf("Unexpected error %v", err)
+ }
+ if url.Path != pathThatLooksSchemeRelative {
+ t.Errorf("ParseRequestURI path:\ngot %q\nwant %q", url.Path, pathThatLooksSchemeRelative)
+ }
+}
+
+var stringURLTests = []struct {
+ url URL
+ want string
+}{
+ // No leading slash on path should prepend slash on String() call
+ {
+ url: URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "search",
+ },
+ want: "http://www.google.com/search",
+ },
+ // Relative path with first element containing ":" should be prepended with "./", golang.org/issue/17184
+ {
+ url: URL{
+ Path: "this:that",
+ },
+ want: "./this:that",
+ },
+ // Relative path with second element containing ":" should not be prepended with "./"
+ {
+ url: URL{
+ Path: "here/this:that",
+ },
+ want: "here/this:that",
+ },
+ // Non-relative path with first element containing ":" should not be prepended with "./"
+ {
+ url: URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "this:that",
+ },
+ want: "http://www.google.com/this:that",
+ },
+}
+
+func TestURLString(t *testing.T) {
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
+ if err != nil {
+ t.Errorf("Parse(%q) returned error %s", tt.in, err)
+ continue
+ }
+ expected := tt.in
+ if tt.roundtrip != "" {
+ expected = tt.roundtrip
+ }
+ s := u.String()
+ if s != expected {
+ t.Errorf("Parse(%q).String() == %q (expected %q)", tt.in, s, expected)
+ }
+ }
+
+ for _, tt := range stringURLTests {
+ if got := tt.url.String(); got != tt.want {
+ t.Errorf("%+v.String() = %q; want %q", tt.url, got, tt.want)
+ }
+ }
+}
+
+func TestURLRedacted(t *testing.T) {
+ cases := []struct {
+ name string
+ url *URL
+ want string
+ }{
+ {
+ name: "non-blank Password",
+ url: &URL{
+ Scheme: "http",
+ Host: "host.tld",
+ Path: "this:that",
+ User: UserPassword("user", "password"),
+ },
+ want: "http://user:xxxxx@host.tld/this:that",
+ },
+ {
+ name: "blank Password",
+ url: &URL{
+ Scheme: "http",
+ Host: "host.tld",
+ Path: "this:that",
+ User: User("user"),
+ },
+ want: "http://user@host.tld/this:that",
+ },
+ {
+ name: "nil User",
+ url: &URL{
+ Scheme: "http",
+ Host: "host.tld",
+ Path: "this:that",
+ User: UserPassword("", "password"),
+ },
+ want: "http://:xxxxx@host.tld/this:that",
+ },
+ {
+ name: "blank Username, blank Password",
+ url: &URL{
+ Scheme: "http",
+ Host: "host.tld",
+ Path: "this:that",
+ },
+ want: "http://host.tld/this:that",
+ },
+ {
+ name: "empty URL",
+ url: &URL{},
+ want: "",
+ },
+ {
+ name: "nil URL",
+ url: nil,
+ want: "",
+ },
+ }
+
+ for _, tt := range cases {
+ t := t
+ t.Run(tt.name, func(t *testing.T) {
+ if g, w := tt.url.Redacted(), tt.want; g != w {
+ t.Fatalf("got: %q\nwant: %q", g, w)
+ }
+ })
+ }
+}
+
+type EscapeTest struct {
+ in string
+ out string
+ err error
+}
+
+var unescapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "1%41",
+ "1A",
+ nil,
+ },
+ {
+ "1%41%42%43",
+ "1ABC",
+ nil,
+ },
+ {
+ "%4a",
+ "J",
+ nil,
+ },
+ {
+ "%6F",
+ "o",
+ nil,
+ },
+ {
+ "%", // not enough characters after %
+ "",
+ EscapeError("%"),
+ },
+ {
+ "%a", // not enough characters after %
+ "",
+ EscapeError("%a"),
+ },
+ {
+ "%1", // not enough characters after %
+ "",
+ EscapeError("%1"),
+ },
+ {
+ "123%45%6", // not enough characters after %
+ "",
+ EscapeError("%6"),
+ },
+ {
+ "%zzzzz", // invalid hex digits
+ "",
+ EscapeError("%zz"),
+ },
+ {
+ "a+b",
+ "a b",
+ nil,
+ },
+ {
+ "a%20b",
+ "a b",
+ nil,
+ },
+}
+
+func TestUnescape(t *testing.T) {
+ for _, tt := range unescapeTests {
+ actual, err := QueryUnescape(tt.in)
+ if actual != tt.out || (err != nil) != (tt.err != nil) {
+ t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", tt.in, actual, err, tt.out, tt.err)
+ }
+
+ in := tt.in
+ out := tt.out
+ if strings.Contains(tt.in, "+") {
+ in = strings.ReplaceAll(tt.in, "+", "%20")
+ actual, err := PathUnescape(in)
+ if actual != tt.out || (err != nil) != (tt.err != nil) {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, tt.out, tt.err)
+ }
+ if tt.err == nil {
+ s, err := QueryUnescape(strings.ReplaceAll(tt.in, "+", "XXX"))
+ if err != nil {
+ continue
+ }
+ in = tt.in
+ out = strings.ReplaceAll(s, "XXX", "+")
+ }
+ }
+
+ actual, err = PathUnescape(in)
+ if actual != out || (err != nil) != (tt.err != nil) {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, out, tt.err)
+ }
+ }
+}
+
+var queryEscapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "one two",
+ "one+two",
+ nil,
+ },
+ {
+ "10%",
+ "10%25",
+ nil,
+ },
+ {
+ " ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;",
+ "+%3F%26%3D%23%2B%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09%3A%2F%40%24%27%28%29%2A%2C%3B",
+ nil,
+ },
+}
+
+func TestQueryEscape(t *testing.T) {
+ for _, tt := range queryEscapeTests {
+ actual := QueryEscape(tt.in)
+ if tt.out != actual {
+ t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out)
+ }
+
+ // for bonus points, verify that escape:unescape is an identity.
+ roundtrip, err := QueryUnescape(actual)
+ if roundtrip != tt.in || err != nil {
+ t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]")
+ }
+ }
+}
+
+var pathEscapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "abc+def",
+ "abc+def",
+ nil,
+ },
+ {
+ "a/b",
+ "a%2Fb",
+ nil,
+ },
+ {
+ "one two",
+ "one%20two",
+ nil,
+ },
+ {
+ "10%",
+ "10%25",
+ nil,
+ },
+ {
+ " ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;",
+ "%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B",
+ nil,
+ },
+}
+
+func TestPathEscape(t *testing.T) {
+ for _, tt := range pathEscapeTests {
+ actual := PathEscape(tt.in)
+ if tt.out != actual {
+ t.Errorf("PathEscape(%q) = %q, want %q", tt.in, actual, tt.out)
+ }
+
+ // for bonus points, verify that escape:unescape is an identity.
+ roundtrip, err := PathUnescape(actual)
+ if roundtrip != tt.in || err != nil {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]")
+ }
+ }
+}
+
+//var userinfoTests = []UserinfoTest{
+// {"user", "password", "user:password"},
+// {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./",
+// "foo%3Abar:~!%40%23$%25%5E&*()_+%7B%7D%7C%5B%5D%5C-=%60%3A;'%22%3C%3E?,.%2F"},
+//}
+
+type EncodeQueryTest struct {
+ m Values
+ expected string
+}
+
+var encodeQueryTests = []EncodeQueryTest{
+ {nil, ""},
+ {Values{"q": {"puppies"}, "oe": {"utf8"}}, "oe=utf8&q=puppies"},
+ {Values{"q": {"dogs", "&", "7"}}, "q=dogs&q=%26&q=7"},
+ {Values{
+ "a": {"a1", "a2", "a3"},
+ "b": {"b1", "b2", "b3"},
+ "c": {"c1", "c2", "c3"},
+ }, "a=a1&a=a2&a=a3&b=b1&b=b2&b=b3&c=c1&c=c2&c=c3"},
+}
+
+func TestEncodeQuery(t *testing.T) {
+ for _, tt := range encodeQueryTests {
+ if q := tt.m.Encode(); q != tt.expected {
+ t.Errorf(`EncodeQuery(%+v) = %q, want %q`, tt.m, q, tt.expected)
+ }
+ }
+}
+
+var resolvePathTests = []struct {
+ base, ref, expected string
+}{
+ {"a/b", ".", "/a/"},
+ {"a/b", "c", "/a/c"},
+ {"a/b", "..", "/"},
+ {"a/", "..", "/"},
+ {"a/", "../..", "/"},
+ {"a/b/c", "..", "/a/"},
+ {"a/b/c", "../d", "/a/d"},
+ {"a/b/c", ".././d", "/a/d"},
+ {"a/b", "./..", "/"},
+ {"a/./b", ".", "/a/"},
+ {"a/../", ".", "/"},
+ {"a/.././b", "c", "/c"},
+}
+
+func TestResolvePath(t *testing.T) {
+ for _, test := range resolvePathTests {
+ got := resolvePath(test.base, test.ref)
+ if got != test.expected {
+ t.Errorf("For %q + %q got %q; expected %q", test.base, test.ref, got, test.expected)
+ }
+ }
+}
+
+func BenchmarkResolvePath(b *testing.B) {
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ resolvePath("a/b/c", ".././d")
+ }
+}
+
+var resolveReferenceTests = []struct {
+ base, rel, expected string
+}{
+ // Absolute URL references
+ {"http://foo.com?a=b", "https://bar.com/", "https://bar.com/"},
+ {"http://foo.com/", "https://bar.com/?a=b", "https://bar.com/?a=b"},
+ {"http://foo.com/", "https://bar.com/?", "https://bar.com/?"},
+ {"http://foo.com/bar", "mailto:foo@example.com", "mailto:foo@example.com"},
+
+ // Path-absolute references
+ {"http://foo.com/bar", "/baz", "http://foo.com/baz"},
+ {"http://foo.com/bar?a=b#f", "/baz", "http://foo.com/baz"},
+ {"http://foo.com/bar?a=b", "/baz?", "http://foo.com/baz?"},
+ {"http://foo.com/bar?a=b", "/baz?c=d", "http://foo.com/baz?c=d"},
+
+ // Multiple slashes
+ {"http://foo.com/bar", "http://foo.com//baz", "http://foo.com//baz"},
+ {"http://foo.com/bar", "http://foo.com///baz/quux", "http://foo.com///baz/quux"},
+
+ // Scheme-relative
+ {"https://foo.com/bar?a=b", "//bar.com/quux", "https://bar.com/quux"},
+
+ // Path-relative references:
+
+ // ... current directory
+ {"http://foo.com", ".", "http://foo.com/"},
+ {"http://foo.com/bar", ".", "http://foo.com/"},
+ {"http://foo.com/bar/", ".", "http://foo.com/bar/"},
+
+ // ... going down
+ {"http://foo.com", "bar", "http://foo.com/bar"},
+ {"http://foo.com/", "bar", "http://foo.com/bar"},
+ {"http://foo.com/bar/baz", "quux", "http://foo.com/bar/quux"},
+
+ // ... going up
+ {"http://foo.com/bar/baz", "../quux", "http://foo.com/quux"},
+ {"http://foo.com/bar/baz", "../../../../../quux", "http://foo.com/quux"},
+ {"http://foo.com/bar", "..", "http://foo.com/"},
+ {"http://foo.com/bar/baz", "./..", "http://foo.com/"},
+ // ".." in the middle (issue 3560)
+ {"http://foo.com/bar/baz", "quux/dotdot/../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/.././tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/./../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"},
+ {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"},
+
+ // Remove any dot-segments prior to forming the target URI.
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4
+ {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"},
+
+ // Triple dot isn't special
+ {"http://foo.com/bar", "...", "http://foo.com/..."},
+
+ // Fragment
+ {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"},
+ {"http://example.org/", "#!$&%27()*+,;=", "http://example.org/#!$&%27()*+,;="},
+
+ // Paths with escaping (issue 16947).
+ {"http://foo.com/foo%2fbar/", "../baz", "http://foo.com/baz"},
+ {"http://foo.com/1/2%2f/3%2f4/5", "../../a/b/c", "http://foo.com/1/a/b/c"},
+ {"http://foo.com/1/2/3", "./a%2f../../b/..%2fc", "http://foo.com/1/2/b/..%2fc"},
+ {"http://foo.com/1/2%2f/3%2f4/5", "./a%2f../b/../c", "http://foo.com/1/2%2f/3%2f4/a%2f../c"},
+ {"http://foo.com/foo%20bar/", "../baz", "http://foo.com/baz"},
+ {"http://foo.com/foo", "../bar%2fbaz", "http://foo.com/bar%2fbaz"},
+ {"http://foo.com/foo%2dbar/", "./baz-quux", "http://foo.com/foo%2dbar/baz-quux"},
+
+ // RFC 3986: Normal Examples
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.1
+ {"http://a/b/c/d;p?q", "g:h", "g:h"},
+ {"http://a/b/c/d;p?q", "g", "http://a/b/c/g"},
+ {"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"},
+ {"http://a/b/c/d;p?q", "g/", "http://a/b/c/g/"},
+ {"http://a/b/c/d;p?q", "/g", "http://a/g"},
+ {"http://a/b/c/d;p?q", "//g", "http://g"},
+ {"http://a/b/c/d;p?q", "?y", "http://a/b/c/d;p?y"},
+ {"http://a/b/c/d;p?q", "g?y", "http://a/b/c/g?y"},
+ {"http://a/b/c/d;p?q", "#s", "http://a/b/c/d;p?q#s"},
+ {"http://a/b/c/d;p?q", "g#s", "http://a/b/c/g#s"},
+ {"http://a/b/c/d;p?q", "g?y#s", "http://a/b/c/g?y#s"},
+ {"http://a/b/c/d;p?q", ";x", "http://a/b/c/;x"},
+ {"http://a/b/c/d;p?q", "g;x", "http://a/b/c/g;x"},
+ {"http://a/b/c/d;p?q", "g;x?y#s", "http://a/b/c/g;x?y#s"},
+ {"http://a/b/c/d;p?q", "", "http://a/b/c/d;p?q"},
+ {"http://a/b/c/d;p?q", ".", "http://a/b/c/"},
+ {"http://a/b/c/d;p?q", "./", "http://a/b/c/"},
+ {"http://a/b/c/d;p?q", "..", "http://a/b/"},
+ {"http://a/b/c/d;p?q", "../", "http://a/b/"},
+ {"http://a/b/c/d;p?q", "../g", "http://a/b/g"},
+ {"http://a/b/c/d;p?q", "../..", "http://a/"},
+ {"http://a/b/c/d;p?q", "../../", "http://a/"},
+ {"http://a/b/c/d;p?q", "../../g", "http://a/g"},
+
+ // RFC 3986: Abnormal Examples
+ // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.2
+ {"http://a/b/c/d;p?q", "../../../g", "http://a/g"},
+ {"http://a/b/c/d;p?q", "../../../../g", "http://a/g"},
+ {"http://a/b/c/d;p?q", "/./g", "http://a/g"},
+ {"http://a/b/c/d;p?q", "/../g", "http://a/g"},
+ {"http://a/b/c/d;p?q", "g.", "http://a/b/c/g."},
+ {"http://a/b/c/d;p?q", ".g", "http://a/b/c/.g"},
+ {"http://a/b/c/d;p?q", "g..", "http://a/b/c/g.."},
+ {"http://a/b/c/d;p?q", "..g", "http://a/b/c/..g"},
+ {"http://a/b/c/d;p?q", "./../g", "http://a/b/g"},
+ {"http://a/b/c/d;p?q", "./g/.", "http://a/b/c/g/"},
+ {"http://a/b/c/d;p?q", "g/./h", "http://a/b/c/g/h"},
+ {"http://a/b/c/d;p?q", "g/../h", "http://a/b/c/h"},
+ {"http://a/b/c/d;p?q", "g;x=1/./y", "http://a/b/c/g;x=1/y"},
+ {"http://a/b/c/d;p?q", "g;x=1/../y", "http://a/b/c/y"},
+ {"http://a/b/c/d;p?q", "g?y/./x", "http://a/b/c/g?y/./x"},
+ {"http://a/b/c/d;p?q", "g?y/../x", "http://a/b/c/g?y/../x"},
+ {"http://a/b/c/d;p?q", "g#s/./x", "http://a/b/c/g#s/./x"},
+ {"http://a/b/c/d;p?q", "g#s/../x", "http://a/b/c/g#s/../x"},
+
+ // Extras.
+ {"https://a/b/c/d;p?q", "//g?q", "https://g?q"},
+ {"https://a/b/c/d;p?q", "//g#s", "https://g#s"},
+ {"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"},
+ {"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"},
+ {"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"},
+
+ // Empty path and query but with ForceQuery (issue 46033).
+ {"https://a/b/c/d;p?q#s", "?", "https://a/b/c/d;p?"},
+}
+
+func TestResolveReference(t *testing.T) {
+ mustParse := func(url string) *URL {
+ u, err := Parse(url)
+ if err != nil {
+ t.Fatalf("Parse(%q) got err %v", url, err)
+ }
+ return u
+ }
+ opaque := &URL{Scheme: "scheme", Opaque: "opaque"}
+ for _, test := range resolveReferenceTests {
+ base := mustParse(test.base)
+ rel := mustParse(test.rel)
+ url := base.ResolveReference(rel)
+ if got := url.String(); got != test.expected {
+ t.Errorf("URL(%q).ResolveReference(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected)
+ }
+ // Ensure that new instances are returned.
+ if base == url {
+ t.Errorf("Expected URL.ResolveReference to return new URL instance.")
+ }
+ // Test the convenience wrapper too.
+ url, err := base.Parse(test.rel)
+ if err != nil {
+ t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err)
+ } else if got := url.String(); got != test.expected {
+ t.Errorf("URL(%q).Parse(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected)
+ } else if base == url {
+ // Ensure that new instances are returned for the wrapper too.
+ t.Errorf("Expected URL.Parse to return new URL instance.")
+ }
+ // Ensure Opaque resets the URL.
+ url = base.ResolveReference(opaque)
+ if *url != *opaque {
+ t.Errorf("ResolveReference failed to resolve opaque URL:\ngot %#v\nwant %#v", url, opaque)
+ }
+ // Test the convenience wrapper with an opaque URL too.
+ url, err = base.Parse("scheme:opaque")
+ if err != nil {
+ t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err)
+ } else if *url != *opaque {
+ t.Errorf("Parse failed to resolve opaque URL:\ngot %#v\nwant %#v", opaque, url)
+ } else if base == url {
+ // Ensure that new instances are returned, again.
+ t.Errorf("Expected URL.Parse to return new URL instance.")
+ }
+ }
+}
+
+func TestQueryValues(t *testing.T) {
+ u, _ := Parse("http://x.com?foo=bar&bar=1&bar=2&baz")
+ v := u.Query()
+ if len(v) != 3 {
+ t.Errorf("got %d keys in Query values, want 3", len(v))
+ }
+ if g, e := v.Get("foo"), "bar"; g != e {
+ t.Errorf("Get(foo) = %q, want %q", g, e)
+ }
+ // Case sensitive:
+ if g, e := v.Get("Foo"), ""; g != e {
+ t.Errorf("Get(Foo) = %q, want %q", g, e)
+ }
+ if g, e := v.Get("bar"), "1"; g != e {
+ t.Errorf("Get(bar) = %q, want %q", g, e)
+ }
+ if g, e := v.Get("baz"), ""; g != e {
+ t.Errorf("Get(baz) = %q, want %q", g, e)
+ }
+ if h, e := v.Has("foo"), true; h != e {
+ t.Errorf("Has(foo) = %t, want %t", h, e)
+ }
+ if h, e := v.Has("bar"), true; h != e {
+ t.Errorf("Has(bar) = %t, want %t", h, e)
+ }
+ if h, e := v.Has("baz"), true; h != e {
+ t.Errorf("Has(baz) = %t, want %t", h, e)
+ }
+ if h, e := v.Has("noexist"), false; h != e {
+ t.Errorf("Has(noexist) = %t, want %t", h, e)
+ }
+ v.Del("bar")
+ if g, e := v.Get("bar"), ""; g != e {
+ t.Errorf("second Get(bar) = %q, want %q", g, e)
+ }
+}
+
+type parseTest struct {
+ query string
+ out Values
+ ok bool
+}
+
+var parseTests = []parseTest{
+ {
+ query: "a=1",
+ out: Values{"a": []string{"1"}},
+ ok: true,
+ },
+ {
+ query: "a=1&b=2",
+ out: Values{"a": []string{"1"}, "b": []string{"2"}},
+ ok: true,
+ },
+ {
+ query: "a=1&a=2&a=banana",
+ out: Values{"a": []string{"1", "2", "banana"}},
+ ok: true,
+ },
+ {
+ query: "ascii=%3Ckey%3A+0x90%3E",
+ out: Values{"ascii": []string{"<key: 0x90>"}},
+ ok: true,
+ }, {
+ query: "a=1;b=2",
+ out: Values{},
+ ok: false,
+ }, {
+ query: "a;b=1",
+ out: Values{},
+ ok: false,
+ }, {
+ query: "a=%3B", // hex encoding for semicolon
+ out: Values{"a": []string{";"}},
+ ok: true,
+ },
+ {
+ query: "a%3Bb=1",
+ out: Values{"a;b": []string{"1"}},
+ ok: true,
+ },
+ {
+ query: "a=1&a=2;a=banana",
+ out: Values{"a": []string{"1"}},
+ ok: false,
+ },
+ {
+ query: "a;b&c=1",
+ out: Values{"c": []string{"1"}},
+ ok: false,
+ },
+ {
+ query: "a=1&b=2;a=3&c=4",
+ out: Values{"a": []string{"1"}, "c": []string{"4"}},
+ ok: false,
+ },
+ {
+ query: "a=1&b=2;c=3",
+ out: Values{"a": []string{"1"}},
+ ok: false,
+ },
+ {
+ query: ";",
+ out: Values{},
+ ok: false,
+ },
+ {
+ query: "a=1;",
+ out: Values{},
+ ok: false,
+ },
+ {
+ query: "a=1&;",
+ out: Values{"a": []string{"1"}},
+ ok: false,
+ },
+ {
+ query: ";a=1&b=2",
+ out: Values{"b": []string{"2"}},
+ ok: false,
+ },
+ {
+ query: "a=1&b=2;",
+ out: Values{"a": []string{"1"}},
+ ok: false,
+ },
+}
+
+func TestParseQuery(t *testing.T) {
+ for _, test := range parseTests {
+ t.Run(test.query, func(t *testing.T) {
+ form, err := ParseQuery(test.query)
+ if test.ok != (err == nil) {
+ want := "<error>"
+ if test.ok {
+ want = "<nil>"
+ }
+ t.Errorf("Unexpected error: %v, want %v", err, want)
+ }
+ if len(form) != len(test.out) {
+ t.Errorf("len(form) = %d, want %d", len(form), len(test.out))
+ }
+ for k, evs := range test.out {
+ vs, ok := form[k]
+ if !ok {
+ t.Errorf("Missing key %q", k)
+ continue
+ }
+ if len(vs) != len(evs) {
+ t.Errorf("len(form[%q]) = %d, want %d", k, len(vs), len(evs))
+ continue
+ }
+ for j, ev := range evs {
+ if v := vs[j]; v != ev {
+ t.Errorf("form[%q][%d] = %q, want %q", k, j, v, ev)
+ }
+ }
+ }
+ })
+ }
+}
+
+type RequestURITest struct {
+ url *URL
+ out string
+}
+
+var requritests = []RequestURITest{
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "",
+ },
+ "/",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a b",
+ },
+ "/a%20b",
+ },
+ // golang.org/issue/4860 variant 1
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Opaque: "/%2F/%2F/",
+ },
+ "/%2F/%2F/",
+ },
+ // golang.org/issue/4860 variant 2
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Opaque: "//other.example.com/%2F/%2F/",
+ },
+ "http://other.example.com/%2F/%2F/",
+ },
+ // better fix for issue 4860
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/////",
+ RawPath: "/%2F/%2F/",
+ },
+ "/%2F/%2F/",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/////",
+ RawPath: "/WRONG/", // ignored because it doesn't match Path
+ },
+ "/////",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a b",
+ RawQuery: "q=go+language",
+ },
+ "/a%20b?q=go+language",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a b",
+ RawPath: "/a b", // ignored because invalid
+ RawQuery: "q=go+language",
+ },
+ "/a%20b?q=go+language",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/a?b",
+ RawPath: "/a?b", // ignored because invalid
+ RawQuery: "q=go+language",
+ },
+ "/a%3Fb?q=go+language",
+ },
+ {
+ &URL{
+ Scheme: "myschema",
+ Opaque: "opaque",
+ },
+ "opaque",
+ },
+ {
+ &URL{
+ Scheme: "myschema",
+ Opaque: "opaque",
+ RawQuery: "q=go+language",
+ },
+ "opaque?q=go+language",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "//foo",
+ },
+ "//foo",
+ },
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "/foo",
+ ForceQuery: true,
+ },
+ "/foo?",
+ },
+}
+
+func TestRequestURI(t *testing.T) {
+ for _, tt := range requritests {
+ s := tt.url.RequestURI()
+ if s != tt.out {
+ t.Errorf("%#v.RequestURI() == %q (expected %q)", tt.url, s, tt.out)
+ }
+ }
+}
+
+func TestParseFailure(t *testing.T) {
+ // Test that the first parse error is returned.
+ const url = "%gh&%ij"
+ _, err := ParseQuery(url)
+ errStr := fmt.Sprint(err)
+ if !strings.Contains(errStr, "%gh") {
+ t.Errorf(`ParseQuery(%q) returned error %q, want something containing %q`, url, errStr, "%gh")
+ }
+}
+
+func TestParseErrors(t *testing.T) {
+ tests := []struct {
+ in string
+ wantErr bool
+ }{
+ {"http://[::1]", false},
+ {"http://[::1]:80", false},
+ {"http://[::1]:namedport", true}, // rfc3986 3.2.3
+ {"http://x:namedport", true}, // rfc3986 3.2.3
+ {"http://[::1]/", false},
+ {"http://[::1]a", true},
+ {"http://[::1]%23", true},
+ {"http://[::1%25en0]", false}, // valid zone id
+ {"http://[::1]:", false}, // colon, but no port OK
+ {"http://x:", false}, // colon, but no port OK
+ {"http://[::1]:%38%30", true}, // not allowed: % encoding only for non-ASCII
+ {"http://[::1%25%41]", false}, // RFC 6874 allows over-escaping in zone
+ {"http://[%10::1]", true}, // no %xx escapes in IP address
+ {"http://[::1]/%48", false}, // %xx in path is fine
+ {"http://%41:8080/", true}, // not allowed: % encoding only for non-ASCII
+ {"mysql://x@y(z:123)/foo", true}, // not well-formed per RFC 3986, golang.org/issue/33646
+ {"mysql://x@y(1.2.3.4:123)/foo", true},
+
+ {" http://foo.com", true}, // invalid character in scheme
+ {"ht tp://foo.com", true}, // invalid character in scheme
+ {"ahttp://foo.com", false}, // valid scheme characters
+ {"1http://foo.com", true}, // scheme must begin with a letter
+
+ {"http://[]%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a/", true}, // golang.org/issue/11208
+ {"http://a b.com/", true}, // no space in host name please
+ {"cache_object://foo", true}, // scheme cannot have _, relative path cannot have : in first segment
+ {"cache_object:foo", true},
+ {"cache_object:foo/bar", true},
+ {"cache_object/:foo/bar", false},
+ }
+ for _, tt := range tests {
+ u, err := Parse(tt.in)
+ if tt.wantErr {
+ if err == nil {
+ t.Errorf("Parse(%q) = %#v; want an error", tt.in, u)
+ }
+ continue
+ }
+ if err != nil {
+ t.Errorf("Parse(%q) = %v; want no error", tt.in, err)
+ }
+ }
+}
+
+// Issue 11202
+func TestStarRequest(t *testing.T) {
+ u, err := Parse("*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := u.RequestURI(), "*"; got != want {
+ t.Errorf("RequestURI = %q; want %q", got, want)
+ }
+}
+
+type shouldEscapeTest struct {
+ in byte
+ mode encoding
+ escape bool
+}
+
+var shouldEscapeTests = []shouldEscapeTest{
+ // Unreserved characters (§2.3)
+ {'a', encodePath, false},
+ {'a', encodeUserPassword, false},
+ {'a', encodeQueryComponent, false},
+ {'a', encodeFragment, false},
+ {'a', encodeHost, false},
+ {'z', encodePath, false},
+ {'A', encodePath, false},
+ {'Z', encodePath, false},
+ {'0', encodePath, false},
+ {'9', encodePath, false},
+ {'-', encodePath, false},
+ {'-', encodeUserPassword, false},
+ {'-', encodeQueryComponent, false},
+ {'-', encodeFragment, false},
+ {'.', encodePath, false},
+ {'_', encodePath, false},
+ {'~', encodePath, false},
+
+ // User information (§3.2.1)
+ {':', encodeUserPassword, true},
+ {'/', encodeUserPassword, true},
+ {'?', encodeUserPassword, true},
+ {'@', encodeUserPassword, true},
+ {'$', encodeUserPassword, false},
+ {'&', encodeUserPassword, false},
+ {'+', encodeUserPassword, false},
+ {',', encodeUserPassword, false},
+ {';', encodeUserPassword, false},
+ {'=', encodeUserPassword, false},
+
+ // Host (IP address, IPv6 address, registered name, port suffix; §3.2.2)
+ {'!', encodeHost, false},
+ {'$', encodeHost, false},
+ {'&', encodeHost, false},
+ {'\'', encodeHost, false},
+ {'(', encodeHost, false},
+ {')', encodeHost, false},
+ {'*', encodeHost, false},
+ {'+', encodeHost, false},
+ {',', encodeHost, false},
+ {';', encodeHost, false},
+ {'=', encodeHost, false},
+ {':', encodeHost, false},
+ {'[', encodeHost, false},
+ {']', encodeHost, false},
+ {'0', encodeHost, false},
+ {'9', encodeHost, false},
+ {'A', encodeHost, false},
+ {'z', encodeHost, false},
+ {'_', encodeHost, false},
+ {'-', encodeHost, false},
+ {'.', encodeHost, false},
+}
+
+func TestShouldEscape(t *testing.T) {
+ for _, tt := range shouldEscapeTests {
+ if shouldEscape(tt.in, tt.mode) != tt.escape {
+ t.Errorf("shouldEscape(%q, %v) returned %v; expected %v", tt.in, tt.mode, !tt.escape, tt.escape)
+ }
+ }
+}
+
+type timeoutError struct {
+ timeout bool
+}
+
+func (e *timeoutError) Error() string { return "timeout error" }
+func (e *timeoutError) Timeout() bool { return e.timeout }
+
+type temporaryError struct {
+ temporary bool
+}
+
+func (e *temporaryError) Error() string { return "temporary error" }
+func (e *temporaryError) Temporary() bool { return e.temporary }
+
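+// timeoutTemporaryError embeds both helpers so that it satisfies Timeout and
+// Temporary at the same time; the explicit Error method below resolves the
+// otherwise ambiguous promotion of Error from the two embedded types.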
+type timeoutTemporaryError struct {
+ timeoutError
+ temporaryError
+}
+
+func (e *timeoutTemporaryError) Error() string { return "timeout/temporary error" }
+
+var netErrorTests = []struct {
+ err error
+ timeout bool
+ temporary bool
+}{{
+ err: &Error{"Get", "http://google.com/", &timeoutError{timeout: true}},
+ timeout: true,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutError{timeout: false}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &temporaryError{temporary: true}},
+ timeout: false,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &temporaryError{temporary: false}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: true}, temporaryError{temporary: true}}},
+ timeout: true,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: false}, temporaryError{temporary: true}}},
+ timeout: false,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: true}, temporaryError{temporary: false}}},
+ timeout: true,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: false}, temporaryError{temporary: false}}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", io.EOF},
+ timeout: false,
+ temporary: false,
+}}
+
+// Test that url.Error implements net.Error and that it forwards the
+// Timeout and Temporary calls to the wrapped error.
+func TestURLErrorImplementsNetError(t *testing.T) {
+ for i, tt := range netErrorTests {
+ err, ok := tt.err.(net.Error)
+ if !ok {
+ t.Errorf("%d: %T does not implement net.Error", i+1, tt.err)
+ continue
+ }
+ if err.Timeout() != tt.timeout {
+ t.Errorf("%d: err.Timeout(): got %v, want %v", i+1, err.Timeout(), tt.timeout)
+ continue
+ }
+ if err.Temporary() != tt.temporary {
+ t.Errorf("%d: err.Temporary(): got %v, want %v", i+1, err.Temporary(), tt.temporary)
+ }
+ }
+}
+
+func TestURLHostnameAndPort(t *testing.T) {
+ tests := []struct {
+ in string // URL.Host field
+ host string
+ port string
+ }{
+ {"foo.com:80", "foo.com", "80"},
+ {"foo.com", "foo.com", ""},
+ {"foo.com:", "foo.com", ""},
+ {"FOO.COM", "FOO.COM", ""}, // no canonicalization
+ {"1.2.3.4", "1.2.3.4", ""},
+ {"1.2.3.4:80", "1.2.3.4", "80"},
+ {"[1:2:3:4]", "1:2:3:4", ""},
+ {"[1:2:3:4]:80", "1:2:3:4", "80"},
+ {"[::1]:80", "::1", "80"},
+ {"[::1]", "::1", ""},
+ {"[::1]:", "::1", ""},
+ {"localhost", "localhost", ""},
+ {"localhost:443", "localhost", "443"},
+ {"some.super.long.domain.example.org:8080", "some.super.long.domain.example.org", "8080"},
+ {"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:17000", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", "17000"},
+ {"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", ""},
+
+ // Ensure that even when not valid, Host is one of "Hostname",
+ // "Hostname:Port", "[Hostname]" or "[Hostname]:Port".
+ // See https://golang.org/issue/29098.
+ {"[google.com]:80", "google.com", "80"},
+ {"google.com]:80", "google.com]", "80"},
+ {"google.com:80_invalid_port", "google.com:80_invalid_port", ""},
+ {"[::1]extra]:80", "::1]extra", "80"},
+ {"google.com]extra:extra", "google.com]extra:extra", ""},
+ }
+ for _, tt := range tests {
+ u := &URL{Host: tt.in}
+ host, port := u.Hostname(), u.Port()
+ if host != tt.host {
+ t.Errorf("Hostname for Host %q = %q; want %q", tt.in, host, tt.host)
+ }
+ if port != tt.port {
+ t.Errorf("Port for Host %q = %q; want %q", tt.in, port, tt.port)
+ }
+ }
+}
+
+var _ encodingPkg.BinaryMarshaler = (*URL)(nil)
+var _ encodingPkg.BinaryUnmarshaler = (*URL)(nil)
+
+func TestJSON(t *testing.T) {
+ u, err := Parse("https://www.google.com/x?y=z")
+ if err != nil {
+ t.Fatal(err)
+ }
+ js, err := json.Marshal(u)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If only we could implement TextMarshaler/TextUnmarshaler,
+ // this would work:
+ //
+ // if string(js) != strconv.Quote(u.String()) {
+ // t.Errorf("json encoding: %s\nwant: %s\n", js, strconv.Quote(u.String()))
+ // }
+
+ u1 := new(URL)
+ err = json.Unmarshal(js, u1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u1.String() != u.String() {
+ t.Errorf("json decoded to: %s\nwant: %s\n", u1, u)
+ }
+}
+
+func TestGob(t *testing.T) {
+ u, err := Parse("https://www.google.com/x?y=z")
+ if err != nil {
+ t.Fatal(err)
+ }
+ var w bytes.Buffer
+ err = gob.NewEncoder(&w).Encode(u)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ u1 := new(URL)
+ err = gob.NewDecoder(&w).Decode(u1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u1.String() != u.String() {
+ t.Errorf("gob decoded to: %s\nwant: %s\n", u1, u)
+ }
+}
+
+func TestNilUser(t *testing.T) {
+ defer func() {
+ if v := recover(); v != nil {
+ t.Fatalf("unexpected panic: %v", v)
+ }
+ }()
+
+ u, err := Parse("http://foo.com/")
+
+ if err != nil {
+ t.Fatalf("parse err: %v", err)
+ }
+
+ if v := u.User.Username(); v != "" {
+ t.Fatalf("expected empty username, got %s", v)
+ }
+
+ if v, ok := u.User.Password(); v != "" || ok {
+ t.Fatalf("expected empty password, got %s (%v)", v, ok)
+ }
+
+ if v := u.User.String(); v != "" {
+ t.Fatalf("expected empty string, got %s", v)
+ }
+}
+
+func TestInvalidUserPassword(t *testing.T) {
+ _, err := Parse("http://user^:passwo^rd@foo.com/")
+ if got, wantsub := fmt.Sprint(err), "net/url: invalid userinfo"; !strings.Contains(got, wantsub) {
+ t.Errorf("error = %q; want substring %q", got, wantsub)
+ }
+}
+
+func TestRejectControlCharacters(t *testing.T) {
+ tests := []string{
+ "http://foo.com/?foo\nbar",
+ "http\r://foo.com/",
+ "http://foo\x7f.com/",
+ }
+ for _, s := range tests {
+ _, err := Parse(s)
+ const wantSub = "net/url: invalid control character in URL"
+ if got := fmt.Sprint(err); !strings.Contains(got, wantSub) {
+ t.Errorf("Parse(%q) error = %q; want substring %q", s, got, wantSub)
+ }
+ }
+
+ // But don't reject non-ASCII CTLs, at least for now:
+ if _, err := Parse("http://foo.com/ctl\x80"); err != nil {
+ t.Errorf("error parsing URL with non-ASCII control byte: %v", err)
+ }
+
+}
+
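+// Each benchmark case records the expected output of both encoders: QueryEscape
+// encodes a space as "+", while PathEscape uses "%20" and leaves characters
+// such as '+', '=', '$' and '@' unescaped.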
+var escapeBenchmarks = []struct {
+ unescaped string
+ query string
+ path string
+}{
+ {
+ unescaped: "one two",
+ query: "one+two",
+ path: "one%20two",
+ },
+ {
+ unescaped: "Фотки собак",
+ query: "%D0%A4%D0%BE%D1%82%D0%BA%D0%B8+%D1%81%D0%BE%D0%B1%D0%B0%D0%BA",
+ path: "%D0%A4%D0%BE%D1%82%D0%BA%D0%B8%20%D1%81%D0%BE%D0%B1%D0%B0%D0%BA",
+ },
+
+ {
+ unescaped: "shortrun(break)shortrun",
+ query: "shortrun%28break%29shortrun",
+ path: "shortrun%28break%29shortrun",
+ },
+
+ {
+ unescaped: "longerrunofcharacters(break)anotherlongerrunofcharacters",
+ query: "longerrunofcharacters%28break%29anotherlongerrunofcharacters",
+ path: "longerrunofcharacters%28break%29anotherlongerrunofcharacters",
+ },
+
+ {
+ unescaped: strings.Repeat("padded/with+various%characters?that=need$some@escaping+paddedsowebreak/256bytes", 4),
+ query: strings.Repeat("padded%2Fwith%2Bvarious%25characters%3Fthat%3Dneed%24some%40escaping%2Bpaddedsowebreak%2F256bytes", 4),
+ path: strings.Repeat("padded%2Fwith+various%25characters%3Fthat=need$some@escaping+paddedsowebreak%2F256bytes", 4),
+ },
+}
+
+func BenchmarkQueryEscape(b *testing.B) {
+ for _, tc := range escapeBenchmarks {
+ b.Run("", func(b *testing.B) {
+ b.ReportAllocs()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g = QueryEscape(tc.unescaped)
+ }
+ b.StopTimer()
+ if g != tc.query {
+ b.Errorf("QueryEscape(%q) == %q, want %q", tc.unescaped, g, tc.query)
+ }
+
+ })
+ }
+}
+
+func BenchmarkPathEscape(b *testing.B) {
+ for _, tc := range escapeBenchmarks {
+ b.Run("", func(b *testing.B) {
+ b.ReportAllocs()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g = PathEscape(tc.unescaped)
+ }
+ b.StopTimer()
+ if g != tc.path {
+ b.Errorf("PathEscape(%q) == %q, want %q", tc.unescaped, g, tc.path)
+ }
+
+ })
+ }
+}
+
+func BenchmarkQueryUnescape(b *testing.B) {
+ for _, tc := range escapeBenchmarks {
+ b.Run("", func(b *testing.B) {
+ b.ReportAllocs()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g, _ = QueryUnescape(tc.query)
+ }
+ b.StopTimer()
+ if g != tc.unescaped {
+ b.Errorf("QueryUnescape(%q) == %q, want %q", tc.query, g, tc.unescaped)
+ }
+
+ })
+ }
+}
+
+func BenchmarkPathUnescape(b *testing.B) {
+ for _, tc := range escapeBenchmarks {
+ b.Run("", func(b *testing.B) {
+ b.ReportAllocs()
+ var g string
+ for i := 0; i < b.N; i++ {
+ g, _ = PathUnescape(tc.path)
+ }
+ b.StopTimer()
+ if g != tc.unescaped {
+ b.Errorf("PathUnescape(%q) == %q, want %q", tc.path, g, tc.unescaped)
+ }
+
+ })
+ }
+}
+
+func TestJoinPath(t *testing.T) {
+ tests := []struct {
+ base string
+ elem []string
+ out string
+ }{
+ {
+ base: "https://go.googlesource.com",
+ elem: []string{"go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com/a/b/c",
+ elem: []string{"../../../go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com/",
+ elem: []string{"../go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com",
+ elem: []string{"../go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com",
+ elem: []string{"../go", "../../go", "../../../go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com/../go",
+ elem: nil,
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com/",
+ elem: []string{"./go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com//",
+ elem: []string{"/go"},
+ out: "https://go.googlesource.com/go",
+ },
+ {
+ base: "https://go.googlesource.com//",
+ elem: []string{"/go", "a", "b", "c"},
+ out: "https://go.googlesource.com/go/a/b/c",
+ },
+ {
+ base: "http://[fe80::1%en0]:8080/",
+ elem: []string{"/go"},
+ },
+ {
+ base: "https://go.googlesource.com",
+ elem: []string{"go/"},
+ out: "https://go.googlesource.com/go/",
+ },
+ {
+ base: "https://go.googlesource.com",
+ elem: []string{"go//"},
+ out: "https://go.googlesource.com/go/",
+ },
+ {
+ base: "https://go.googlesource.com",
+ elem: nil,
+ out: "https://go.googlesource.com/",
+ },
+ {
+ base: "https://go.googlesource.com/",
+ elem: nil,
+ out: "https://go.googlesource.com/",
+ },
+ {
+ base: "https://go.googlesource.com/a%2fb",
+ elem: []string{"c"},
+ out: "https://go.googlesource.com/a%2fb/c",
+ },
+ {
+ base: "https://go.googlesource.com/a%2fb",
+ elem: []string{"c%2fd"},
+ out: "https://go.googlesource.com/a%2fb/c%2fd",
+ },
+ {
+ base: "https://go.googlesource.com/a/b",
+ elem: []string{"/go"},
+ out: "https://go.googlesource.com/a/b/go",
+ },
+ {
+ base: "/",
+ elem: nil,
+ out: "/",
+ },
+ {
+ base: "a",
+ elem: nil,
+ out: "a",
+ },
+ {
+ base: "a",
+ elem: []string{"b"},
+ out: "a/b",
+ },
+ {
+ base: "a",
+ elem: []string{"../b"},
+ out: "b",
+ },
+ {
+ base: "a",
+ elem: []string{"../../b"},
+ out: "b",
+ },
+ {
+ base: "",
+ elem: []string{"a"},
+ out: "a",
+ },
+ {
+ base: "",
+ elem: []string{"../a"},
+ out: "a",
+ },
+ }
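+ // An empty out marks a failure case: JoinPath must return a non-nil
+ // error and an empty string (here because the base URL does not parse).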
+ for _, tt := range tests {
+ wantErr := "nil"
+ if tt.out == "" {
+ wantErr = "non-nil error"
+ }
+ if out, err := JoinPath(tt.base, tt.elem...); out != tt.out || (err == nil) != (tt.out != "") {
+ t.Errorf("JoinPath(%q, %q) = %q, %v, want %q, %v", tt.base, tt.elem, out, err, tt.out, wantErr)
+ }
+ var out string
+ u, err := Parse(tt.base)
+ if err == nil {
+ u = u.JoinPath(tt.elem...)
+ out = u.String()
+ }
+ if out != tt.out || (err == nil) != (tt.out != "") {
+ t.Errorf("Parse(%q).JoinPath(%q) = %q, %v, want %q, %v", tt.base, tt.elem, out, err, tt.out, wantErr)
+ }
+ }
+}
diff --git a/src/net/write_unix_test.go b/src/net/write_unix_test.go
new file mode 100644
index 0000000..23e8bef
--- /dev/null
+++ b/src/net/write_unix_test.go
@@ -0,0 +1,66 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package net
+
+import (
+ "bytes"
+ "syscall"
+ "testing"
+ "time"
+)
+
+// Test that a client can't trigger an endless loop of write system
+// calls on the server by shutting down the write side on the client.
+// Possibility raised in the discussion of https://golang.org/cl/71973.
+func TestEndlessWrite(t *testing.T) {
+ t.Parallel()
+ c := make(chan bool)
+ server := func(cs *TCPConn) error {
+ cs.CloseWrite()
+ <-c
+ return nil
+ }
+ client := func(ss *TCPConn) error {
+ // Tell the server to return when we return.
+ defer close(c)
+
+ // Loop writing to the server. The server is not reading
+ // anything, so this will eventually block, and then time out.
+ b := bytes.Repeat([]byte{'a'}, 8192)
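+ // cagain counts consecutive EAGAIN results; it is reset whenever any
+ // bytes are written, and the test fails if the loop spins on EAGAIN.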
+ cagain := 0
+ for {
+ n, err := ss.conn.fd.pfd.WriteOnce(b)
+ if n > 0 {
+ cagain = 0
+ }
+ switch err {
+ case nil:
+ case syscall.EAGAIN:
+ if cagain == 0 {
+ // We've written enough data to
+ // start blocking. Set a deadline
+ // so that we will stop.
+ ss.SetWriteDeadline(time.Now().Add(5 * time.Millisecond))
+ }
+ cagain++
+ if cagain > 20 {
+ t.Error("looping on EAGAIN")
+ return nil
+ }
+ if err = ss.conn.fd.pfd.WaitWrite(); err != nil {
+ t.Logf("client WaitWrite: %v", err)
+ return nil
+ }
+ default:
+ // We expect to eventually get an error.
+ t.Logf("client WriteOnce: %v", err)
+ return nil
+ }
+ }
+ }
+ withTCPConnPair(t, client, server)
+}
diff --git a/src/net/writev_test.go b/src/net/writev_test.go
new file mode 100644
index 0000000..8722c0f
--- /dev/null
+++ b/src/net/writev_test.go
@@ -0,0 +1,224 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !wasip1
+
+package net
+
+import (
+ "bytes"
+ "fmt"
+ "internal/poll"
+ "io"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestBuffers_read(t *testing.T) {
+ const story = "once upon a time in Gopherland ... "
+ buffers := Buffers{
+ []byte("once "),
+ []byte("upon "),
+ []byte("a "),
+ []byte("time "),
+ []byte("in "),
+ []byte("Gopherland ... "),
+ }
+ got, err := io.ReadAll(&buffers)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != story {
+ t.Errorf("read %q; want %q", got, story)
+ }
+ if len(buffers) != 0 {
+ t.Errorf("len(buffers) = %d; want 0", len(buffers))
+ }
+}
+
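+// consume advances the Buffers by n bytes, dropping chunks (including leading
+// nil entries) once they have been fully consumed.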
+func TestBuffers_consume(t *testing.T) {
+ tests := []struct {
+ in Buffers
+ consume int64
+ want Buffers
+ }{
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 0,
+ want: Buffers{[]byte("foo"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 2,
+ want: Buffers{[]byte("o"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 3,
+ want: Buffers{[]byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 4,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("bar")},
+ consume: 1,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("foo")},
+ consume: 0,
+ want: Buffers{[]byte("foo")},
+ },
+ {
+ in: Buffers{nil, nil, nil},
+ consume: 0,
+ want: Buffers{},
+ },
+ }
+ for i, tt := range tests {
+ in := tt.in
+ in.consume(tt.consume)
+ if !reflect.DeepEqual(in, tt.want) {
+ t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
+ }
+ }
+}
+
+func TestBuffers_WriteTo(t *testing.T) {
+ for _, name := range []string{"WriteTo", "Copy"} {
+ for _, size := range []int{0, 10, 1023, 1024, 1025} {
+ t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
+ testBuffer_writeTo(t, size, name == "Copy")
+ })
+ }
+ }
+}
+
+func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
+ oldHook := poll.TestHookDidWritev
+ defer func() { poll.TestHookDidWritev = oldHook }()
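+ // Record how many bytes each writev call wrote so the reading side can
+ // check both the number of system calls and the total bytes covered.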
+ var writeLog struct {
+ sync.Mutex
+ log []int
+ }
+ poll.TestHookDidWritev = func(size int) {
+ writeLog.Lock()
+ writeLog.log = append(writeLog.log, size)
+ writeLog.Unlock()
+ }
+ var want bytes.Buffer
+ for i := 0; i < chunks; i++ {
+ want.WriteByte(byte(i))
+ }
+
+ withTCPConnPair(t, func(c *TCPConn) error {
+ buffers := make(Buffers, chunks)
+ for i := range buffers {
+ buffers[i] = want.Bytes()[i : i+1]
+ }
+ var n int64
+ var err error
+ if useCopy {
+ n, err = io.Copy(c, &buffers)
+ } else {
+ n, err = buffers.WriteTo(c)
+ }
+ if err != nil {
+ return err
+ }
+ if len(buffers) != 0 {
+ return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
+ }
+ if n != int64(want.Len()) {
+ return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
+ }
+ return nil
+ }, func(c *TCPConn) error {
+ all, err := io.ReadAll(c)
+ if !bytes.Equal(all, want.Bytes()) || err != nil {
+ return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
+ }
+
+ writeLog.Lock() // no need to unlock
+ var gotSum int
+ for _, v := range writeLog.log {
+ gotSum += v
+ }
+
+ var wantSum int
+ switch runtime.GOOS {
+ case "aix", "android", "darwin", "ios", "dragonfly", "freebsd", "illumos", "linux", "netbsd", "openbsd", "solaris":
+ var wantMinCalls int
+ wantSum = want.Len()
+ v := chunks
+ for v > 0 {
+ wantMinCalls++
+ v -= 1024
+ }
+ if len(writeLog.log) < wantMinCalls {
+ t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
+ }
+ case "windows":
+ var wantCalls int
+ wantSum = want.Len()
+ if wantSum > 0 {
+ wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
+ }
+ if len(writeLog.log) != wantCalls {
+ t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
+ }
+ }
+ if gotSum != wantSum {
+ t.Errorf("writev call sum = %v; want %v", gotSum, wantSum)
+ }
+ return nil
+ })
+}
+
+func TestWritevError(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("skipping test: Windows does not have a problem sending large chunks of data")
+ }
+
+ ln := newLocalListener(t, "tcp")
+ defer ln.Close()
+
+ ch := make(chan Conn, 1)
+ go func() {
+ defer close(ch)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ ch <- c
+ }()
+ c1, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+ c2 := <-ch
+ if c2 == nil {
+ t.Fatal("no server side connection")
+ }
+ c2.Close()
+
+ // 1 GB of data should be enough to notice the connection is gone.
+ // Just a few bytes is not enough.
+ // Arrange to reuse the same 1 MB buffer so that we don't allocate much.
+ buf := make([]byte, 1<<20)
+ buffers := make(Buffers, 1<<10)
+ for i := range buffers {
+ buffers[i] = buf
+ }
+ if _, err := buffers.WriteTo(c1); err == nil {
+ t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
+ }
+}
diff --git a/src/net/writev_unix.go b/src/net/writev_unix.go
new file mode 100644
index 0000000..3b0325b
--- /dev/null
+++ b/src/net/writev_unix.go
@@ -0,0 +1,29 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix
+
+package net
+
+import (
+ "runtime"
+ "syscall"
+)
+
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.writeBuffers(v)
+ if err != nil {
+ return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, nil
+}
+
+func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
+ n, err = fd.pfd.Writev((*[][]byte)(v))
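+ // Keep fd reachable until Writev has returned so the finalizer cannot
+ // close the descriptor while the system call is still using it.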
+ runtime.KeepAlive(fd)
+ return n, wrapSyscallError("writev", err)
+}