summaryrefslogtreecommitdiffstats
path: root/src/net/internal/socktest
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
commit43a123c1ae6613b3efeed291fa552ecd909d3acf (patch)
treefd92518b7024bc74031f78a1cf9e454b65e73665 /src/net/internal/socktest
parentInitial commit. (diff)
downloadgolang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.tar.xz
golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.zip
Adding upstream version 1.20.14.upstream/1.20.14upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/net/internal/socktest')
-rw-r--r--src/net/internal/socktest/main_test.go56
-rw-r--r--src/net/internal/socktest/main_unix_test.go24
-rw-r--r--src/net/internal/socktest/main_windows_test.go22
-rw-r--r--src/net/internal/socktest/switch.go169
-rw-r--r--src/net/internal/socktest/switch_posix.go58
-rw-r--r--src/net/internal/socktest/switch_stub.go16
-rw-r--r--src/net/internal/socktest/switch_unix.go29
-rw-r--r--src/net/internal/socktest/switch_windows.go29
-rw-r--r--src/net/internal/socktest/sys_cloexec.go42
-rw-r--r--src/net/internal/socktest/sys_unix.go193
-rw-r--r--src/net/internal/socktest/sys_windows.go221
11 files changed, 859 insertions, 0 deletions
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
new file mode 100644
index 0000000..c7c8d16
--- /dev/null
+++ b/src/net/internal/socktest/main_test.go
@@ -0,0 +1,56 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9
+
+package socktest_test
+
+import (
+ "net/internal/socktest"
+ "os"
+ "sync"
+ "syscall"
+ "testing"
+)
+
+var sw socktest.Switch
+
+func TestMain(m *testing.M) {
+ installTestHooks()
+
+ st := m.Run()
+
+ for s := range sw.Sockets() {
+ closeFunc(s)
+ }
+ uninstallTestHooks()
+ os.Exit(st)
+}
+
+func TestSwitch(t *testing.T) {
+ const N = 10
+ var wg sync.WaitGroup
+ wg.Add(N)
+ for i := 0; i < N; i++ {
+ go func() {
+ defer wg.Done()
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestSocket(t *testing.T) {
+ for _, f := range []socktest.Filter{
+ func(st *socktest.Status) (socktest.AfterFilter, error) { return nil, nil },
+ nil,
+ } {
+ sw.Set(socktest.FilterSocket, f)
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
+ }
+ }
+}
diff --git a/src/net/internal/socktest/main_unix_test.go b/src/net/internal/socktest/main_unix_test.go
new file mode 100644
index 0000000..7d21f6f
--- /dev/null
+++ b/src/net/internal/socktest/main_unix_test.go
@@ -0,0 +1,24 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js && !plan9 && !windows
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (int, error)
+ closeFunc func(int) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Close
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Close
+}
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
index 0000000..df1cb97
--- /dev/null
+++ b/src/net/internal/socktest/main_windows_test.go
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
diff --git a/src/net/internal/socktest/switch.go b/src/net/internal/socktest/switch.go
new file mode 100644
index 0000000..3c37b6f
--- /dev/null
+++ b/src/net/internal/socktest/switch.go
@@ -0,0 +1,169 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package socktest provides utilities for socket testing.
+package socktest
+
+import (
+ "fmt"
+ "sync"
+)
+
+// A Switch represents a callpath point switch for socket system
+// calls.
+type Switch struct {
+ once sync.Once
+
+ fmu sync.RWMutex
+ fltab map[FilterType]Filter
+
+ smu sync.RWMutex
+ sotab Sockets
+ stats stats
+}
+
+func (sw *Switch) init() {
+ sw.fltab = make(map[FilterType]Filter)
+ sw.sotab = make(Sockets)
+ sw.stats = make(stats)
+}
+
+// Stats returns a list of per-cookie socket statistics.
+func (sw *Switch) Stats() []Stat {
+ var st []Stat
+ sw.smu.RLock()
+ for _, s := range sw.stats {
+ ns := *s
+ st = append(st, ns)
+ }
+ sw.smu.RUnlock()
+ return st
+}
+
+// Sockets returns mappings of socket descriptor to socket status.
+func (sw *Switch) Sockets() Sockets {
+ sw.smu.RLock()
+ tab := make(Sockets, len(sw.sotab))
+ for i, s := range sw.sotab {
+ tab[i] = s
+ }
+ sw.smu.RUnlock()
+ return tab
+}
+
+// A Cookie represents a 3-tuple of a socket; address family, socket
+// type and protocol number.
+type Cookie uint64
+
+// Family returns an address family.
+func (c Cookie) Family() int { return int(c >> 48) }
+
+// Type returns a socket type.
+func (c Cookie) Type() int { return int(c << 16 >> 32) }
+
+// Protocol returns a protocol number.
+func (c Cookie) Protocol() int { return int(c & 0xff) }
+
+func cookie(family, sotype, proto int) Cookie {
+ return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xff
+}
+
+// A Status represents the status of a socket.
+type Status struct {
+ Cookie Cookie
+ Err error // error status of socket system call
+ SocketErr error // error status of socket by SO_ERROR
+}
+
+func (so Status) String() string {
+ return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
+}
+
+// A Stat represents a per-cookie socket statistics.
+type Stat struct {
+ Family int // address family
+ Type int // socket type
+ Protocol int // protocol number
+
+ Opened uint64 // number of sockets opened
+ Connected uint64 // number of sockets connected
+ Listened uint64 // number of sockets listened
+ Accepted uint64 // number of sockets accepted
+ Closed uint64 // number of sockets closed
+
+ OpenFailed uint64 // number of sockets open failed
+ ConnectFailed uint64 // number of sockets connect failed
+ ListenFailed uint64 // number of sockets listen failed
+ AcceptFailed uint64 // number of sockets accept failed
+ CloseFailed uint64 // number of sockets close failed
+}
+
+func (st Stat) String() string {
+ return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
+}
+
+type stats map[Cookie]*Stat
+
+func (st stats) getLocked(c Cookie) *Stat {
+ s, ok := st[c]
+ if !ok {
+ s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()}
+ st[c] = s
+ }
+ return s
+}
+
+// A FilterType represents a filter type.
+type FilterType int
+
+const (
+ FilterSocket FilterType = iota // for Socket
+ FilterConnect // for Connect or ConnectEx
+ FilterListen // for Listen
+ FilterAccept // for Accept, Accept4 or AcceptEx
+ FilterGetsockoptInt // for GetsockoptInt
+ FilterClose // for Close or Closesocket
+)
+
+// A Filter represents a socket system call filter.
+//
+// It will only be executed before a system call for a socket that has
+// an entry in internal table.
+// If the filter returns a non-nil error, the execution of system call
+// will be canceled and the system call function returns the non-nil
+// error.
+// It can return a non-nil AfterFilter for filtering after the
+// execution of the system call.
+type Filter func(*Status) (AfterFilter, error)
+
+func (f Filter) apply(st *Status) (AfterFilter, error) {
+ if f == nil {
+ return nil, nil
+ }
+ return f(st)
+}
+
+// An AfterFilter represents a socket system call filter after an
+// execution of a system call.
+//
+// It will only be executed after a system call for a socket that has
+// an entry in internal table.
+// If the filter returns a non-nil error, the system call function
+// returns the non-nil error.
+type AfterFilter func(*Status) error
+
+func (f AfterFilter) apply(st *Status) error {
+ if f == nil {
+ return nil
+ }
+ return f(st)
+}
+
+// Set deploys the socket system call filter f for the filter type t.
+func (sw *Switch) Set(t FilterType, f Filter) {
+ sw.once.Do(sw.init)
+ sw.fmu.Lock()
+ sw.fltab[t] = f
+ sw.fmu.Unlock()
+}
diff --git a/src/net/internal/socktest/switch_posix.go b/src/net/internal/socktest/switch_posix.go
new file mode 100644
index 0000000..fcad4ce
--- /dev/null
+++ b/src/net/internal/socktest/switch_posix.go
@@ -0,0 +1,58 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !plan9
+
+package socktest
+
+import (
+ "fmt"
+ "syscall"
+)
+
+func familyString(family int) string {
+ switch family {
+ case syscall.AF_INET:
+ return "inet4"
+ case syscall.AF_INET6:
+ return "inet6"
+ case syscall.AF_UNIX:
+ return "local"
+ default:
+ return fmt.Sprintf("%d", family)
+ }
+}
+
+func typeString(sotype int) string {
+ var s string
+ switch sotype & 0xff {
+ case syscall.SOCK_STREAM:
+ s = "stream"
+ case syscall.SOCK_DGRAM:
+ s = "datagram"
+ case syscall.SOCK_RAW:
+ s = "raw"
+ case syscall.SOCK_SEQPACKET:
+ s = "seqpacket"
+ default:
+ s = fmt.Sprintf("%d", sotype&0xff)
+ }
+ if flags := uint(sotype) & ^uint(0xff); flags != 0 {
+ s += fmt.Sprintf("|%#x", flags)
+ }
+ return s
+}
+
+func protocolString(proto int) string {
+ switch proto {
+ case 0:
+ return "default"
+ case syscall.IPPROTO_TCP:
+ return "tcp"
+ case syscall.IPPROTO_UDP:
+ return "udp"
+ default:
+ return fmt.Sprintf("%d", proto)
+ }
+}
diff --git a/src/net/internal/socktest/switch_stub.go b/src/net/internal/socktest/switch_stub.go
new file mode 100644
index 0000000..8a2fc35
--- /dev/null
+++ b/src/net/internal/socktest/switch_stub.go
@@ -0,0 +1,16 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build plan9
+
+package socktest
+
+// Sockets maps a socket descriptor to the status of socket.
+type Sockets map[int]Status
+
+func familyString(family int) string { return "<nil>" }
+
+func typeString(sotype int) string { return "<nil>" }
+
+func protocolString(proto int) string { return "<nil>" }
diff --git a/src/net/internal/socktest/switch_unix.go b/src/net/internal/socktest/switch_unix.go
new file mode 100644
index 0000000..f2e95d6
--- /dev/null
+++ b/src/net/internal/socktest/switch_unix.go
@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm)
+
+package socktest
+
+// Sockets maps a socket descriptor to the status of socket.
+type Sockets map[int]Status
+
+func (sw *Switch) sockso(s int) *Status {
+ sw.smu.RLock()
+ defer sw.smu.RUnlock()
+ so, ok := sw.sotab[s]
+ if !ok {
+ return nil
+ }
+ return &so
+}
+
+// addLocked returns a new Status without locking.
+// sw.smu must be held before call.
+func (sw *Switch) addLocked(s, family, sotype, proto int) *Status {
+ sw.once.Do(sw.init)
+ so := Status{Cookie: cookie(family, sotype, proto)}
+ sw.sotab[s] = so
+ return &so
+}
diff --git a/src/net/internal/socktest/switch_windows.go b/src/net/internal/socktest/switch_windows.go
new file mode 100644
index 0000000..4f1d597
--- /dev/null
+++ b/src/net/internal/socktest/switch_windows.go
@@ -0,0 +1,29 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest
+
+import "syscall"
+
+// Sockets maps a socket descriptor to the status of socket.
+type Sockets map[syscall.Handle]Status
+
+func (sw *Switch) sockso(s syscall.Handle) *Status {
+ sw.smu.RLock()
+ defer sw.smu.RUnlock()
+ so, ok := sw.sotab[s]
+ if !ok {
+ return nil
+ }
+ return &so
+}
+
+// addLocked returns a new Status without locking.
+// sw.smu must be held before call.
+func (sw *Switch) addLocked(s syscall.Handle, family, sotype, proto int) *Status {
+ sw.once.Do(sw.init)
+ so := Status{Cookie: cookie(family, sotype, proto)}
+ sw.sotab[s] = so
+ return &so
+}
diff --git a/src/net/internal/socktest/sys_cloexec.go b/src/net/internal/socktest/sys_cloexec.go
new file mode 100644
index 0000000..d57f44d
--- /dev/null
+++ b/src/net/internal/socktest/sys_cloexec.go
@@ -0,0 +1,42 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris
+
+package socktest
+
+import "syscall"
+
+// Accept4 wraps syscall.Accept4.
+func (sw *Switch) Accept4(s, flags int) (ns int, sa syscall.Sockaddr, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Accept4(s, flags)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, nil, err
+ }
+ ns, sa, so.Err = syscall.Accept4(s, flags)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(ns)
+ }
+ return -1, nil, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return -1, nil, so.Err
+ }
+ nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return ns, sa, nil
+}
diff --git a/src/net/internal/socktest/sys_unix.go b/src/net/internal/socktest/sys_unix.go
new file mode 100644
index 0000000..e1040d3
--- /dev/null
+++ b/src/net/internal/socktest/sys_unix.go
@@ -0,0 +1,193 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build unix || (js && wasm)
+
+package socktest
+
+import "syscall"
+
+// Socket wraps syscall.Socket.
+func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(s)
+ }
+ return -1, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return -1, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// Close wraps syscall.Close.
+func (sw *Switch) Close(s int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Close(s)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterClose]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Close(s)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).CloseFailed++
+ return so.Err
+ }
+ delete(sw.sotab, s)
+ sw.stats.getLocked(so.Cookie).Closed++
+ return nil
+}
+
+// Connect wraps syscall.Connect.
+func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Connect(s, sa)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Connect(s, sa)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// Listen wraps syscall.Listen.
+func (sw *Switch) Listen(s, backlog int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Listen(s, backlog)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterListen]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Listen(s, backlog)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ListenFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Listened++
+ return nil
+}
+
+// Accept wraps syscall.Accept.
+func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Accept(s)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, nil, err
+ }
+ ns, sa, so.Err = syscall.Accept(s)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Close(ns)
+ }
+ return -1, nil, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return -1, nil, so.Err
+ }
+ nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return ns, sa, nil
+}
+
+// GetsockoptInt wraps syscall.GetsockoptInt.
+func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.GetsockoptInt(s, level, opt)
+ }
+ sw.fmu.RLock()
+ f := sw.fltab[FilterGetsockoptInt]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return -1, err
+ }
+ soerr, so.Err = syscall.GetsockoptInt(s, level, opt)
+ so.SocketErr = syscall.Errno(soerr)
+ if err = af.apply(so); err != nil {
+ return -1, err
+ }
+
+ if so.Err != nil {
+ return -1, so.Err
+ }
+ if opt == syscall.SO_ERROR && (so.SocketErr == syscall.Errno(0) || so.SocketErr == syscall.EISCONN) {
+ sw.smu.Lock()
+ sw.stats.getLocked(so.Cookie).Connected++
+ sw.smu.Unlock()
+ }
+ return soerr, nil
+}
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
new file mode 100644
index 0000000..8c1c862
--- /dev/null
+++ b/src/net/internal/socktest/sys_windows.go
@@ -0,0 +1,221 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest
+
+import (
+ "internal/syscall/windows"
+ "syscall"
+)
+
+// Socket wraps syscall.Socket.
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// WSASocket wraps syscall.WSASocket.
+func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, int(family), int(sotype), int(proto))
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
+// Closesocket wraps syscall.Closesocket.
+func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Closesocket(s)
+ }
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterClose]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Closesocket(s)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).CloseFailed++
+ return so.Err
+ }
+ delete(sw.sotab, s)
+ sw.stats.getLocked(so.Cookie).Closed++
+ return nil
+}
+
+// Connect wraps syscall.Connect.
+func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Connect(s, sa)
+ }
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Connect(s, sa)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// ConnectEx wraps syscall.ConnectEx.
+func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.ConnectEx(s, sa, b, n, nwr, o)
+ }
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterConnect]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.ConnectEx(s, sa, b, n, nwr, o)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ConnectFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Connected++
+ return nil
+}
+
+// Listen wraps syscall.Listen.
+func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) {
+ so := sw.sockso(s)
+ if so == nil {
+ return syscall.Listen(s, backlog)
+ }
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterListen]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.Listen(s, backlog)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).ListenFailed++
+ return so.Err
+ }
+ sw.stats.getLocked(so.Cookie).Listened++
+ return nil
+}
+
+// AcceptEx wraps syscall.AcceptEx.
+func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error {
+ so := sw.sockso(ls)
+ if so == nil {
+ return syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
+ }
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterAccept]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return err
+ }
+ so.Err = syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
+ if err = af.apply(so); err != nil {
+ return err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).AcceptFailed++
+ return so.Err
+ }
+ nso := sw.addLocked(as, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
+ sw.stats.getLocked(nso.Cookie).Accepted++
+ return nil
+}