From 43a123c1ae6613b3efeed291fa552ecd909d3acf Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 16 Apr 2024 21:23:18 +0200 Subject: Adding upstream version 1.20.14. Signed-off-by: Daniel Baumann --- src/net/internal/socktest/main_test.go | 56 +++++++ src/net/internal/socktest/main_unix_test.go | 24 +++ src/net/internal/socktest/main_windows_test.go | 22 +++ src/net/internal/socktest/switch.go | 169 +++++++++++++++++++ src/net/internal/socktest/switch_posix.go | 58 +++++++ src/net/internal/socktest/switch_stub.go | 16 ++ src/net/internal/socktest/switch_unix.go | 29 ++++ src/net/internal/socktest/switch_windows.go | 29 ++++ src/net/internal/socktest/sys_cloexec.go | 42 +++++ src/net/internal/socktest/sys_unix.go | 193 +++++++++++++++++++++ src/net/internal/socktest/sys_windows.go | 221 +++++++++++++++++++++++++ 11 files changed, 859 insertions(+) create mode 100644 src/net/internal/socktest/main_test.go create mode 100644 src/net/internal/socktest/main_unix_test.go create mode 100644 src/net/internal/socktest/main_windows_test.go create mode 100644 src/net/internal/socktest/switch.go create mode 100644 src/net/internal/socktest/switch_posix.go create mode 100644 src/net/internal/socktest/switch_stub.go create mode 100644 src/net/internal/socktest/switch_unix.go create mode 100644 src/net/internal/socktest/switch_windows.go create mode 100644 src/net/internal/socktest/sys_cloexec.go create mode 100644 src/net/internal/socktest/sys_unix.go create mode 100644 src/net/internal/socktest/sys_windows.go (limited to 'src/net/internal/socktest') diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go new file mode 100644 index 0000000..c7c8d16 --- /dev/null +++ b/src/net/internal/socktest/main_test.go @@ -0,0 +1,56 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js && !plan9 + +package socktest_test + +import ( + "net/internal/socktest" + "os" + "sync" + "syscall" + "testing" +) + +var sw socktest.Switch + +func TestMain(m *testing.M) { + installTestHooks() + + st := m.Run() + + for s := range sw.Sockets() { + closeFunc(s) + } + uninstallTestHooks() + os.Exit(st) +} + +func TestSwitch(t *testing.T) { + const N = 10 + var wg sync.WaitGroup + wg.Add(N) + for i := 0; i < N; i++ { + go func() { + defer wg.Done() + for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { + socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + } + }() + } + wg.Wait() +} + +func TestSocket(t *testing.T) { + for _, f := range []socktest.Filter{ + func(st *socktest.Status) (socktest.AfterFilter, error) { return nil, nil }, + nil, + } { + sw.Set(socktest.FilterSocket, f) + for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { + socketFunc(family, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) + } + } +} diff --git a/src/net/internal/socktest/main_unix_test.go b/src/net/internal/socktest/main_unix_test.go new file mode 100644 index 0000000..7d21f6f --- /dev/null +++ b/src/net/internal/socktest/main_unix_test.go @@ -0,0 +1,24 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js && !plan9 && !windows + +package socktest_test + +import "syscall" + +var ( + socketFunc func(int, int, int) (int, error) + closeFunc func(int) error +) + +func installTestHooks() { + socketFunc = sw.Socket + closeFunc = sw.Close +} + +func uninstallTestHooks() { + socketFunc = syscall.Socket + closeFunc = syscall.Close +} diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go new file mode 100644 index 0000000..df1cb97 --- /dev/null +++ b/src/net/internal/socktest/main_windows_test.go @@ -0,0 +1,22 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socktest_test + +import "syscall" + +var ( + socketFunc func(int, int, int) (syscall.Handle, error) + closeFunc func(syscall.Handle) error +) + +func installTestHooks() { + socketFunc = sw.Socket + closeFunc = sw.Closesocket +} + +func uninstallTestHooks() { + socketFunc = syscall.Socket + closeFunc = syscall.Closesocket +} diff --git a/src/net/internal/socktest/switch.go b/src/net/internal/socktest/switch.go new file mode 100644 index 0000000..3c37b6f --- /dev/null +++ b/src/net/internal/socktest/switch.go @@ -0,0 +1,169 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package socktest provides utilities for socket testing. +package socktest + +import ( + "fmt" + "sync" +) + +// A Switch represents a callpath point switch for socket system +// calls. +type Switch struct { + once sync.Once + + fmu sync.RWMutex + fltab map[FilterType]Filter + + smu sync.RWMutex + sotab Sockets + stats stats +} + +func (sw *Switch) init() { + sw.fltab = make(map[FilterType]Filter) + sw.sotab = make(Sockets) + sw.stats = make(stats) +} + +// Stats returns a list of per-cookie socket statistics. +func (sw *Switch) Stats() []Stat { + var st []Stat + sw.smu.RLock() + for _, s := range sw.stats { + ns := *s + st = append(st, ns) + } + sw.smu.RUnlock() + return st +} + +// Sockets returns mappings of socket descriptor to socket status. +func (sw *Switch) Sockets() Sockets { + sw.smu.RLock() + tab := make(Sockets, len(sw.sotab)) + for i, s := range sw.sotab { + tab[i] = s + } + sw.smu.RUnlock() + return tab +} + +// A Cookie represents a 3-tuple of a socket; address family, socket +// type and protocol number. +type Cookie uint64 + +// Family returns an address family. +func (c Cookie) Family() int { return int(c >> 48) } + +// Type returns a socket type. +func (c Cookie) Type() int { return int(c << 16 >> 32) } + +// Protocol returns a protocol number. +func (c Cookie) Protocol() int { return int(c & 0xff) } + +func cookie(family, sotype, proto int) Cookie { + return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xff +} + +// A Status represents the status of a socket. +type Status struct { + Cookie Cookie + Err error // error status of socket system call + SocketErr error // error status of socket by SO_ERROR +} + +func (so Status) String() string { + return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr) +} + +// A Stat represents a per-cookie socket statistics. +type Stat struct { + Family int // address family + Type int // socket type + Protocol int // protocol number + + Opened uint64 // number of sockets opened + Connected uint64 // number of sockets connected + Listened uint64 // number of sockets listened + Accepted uint64 // number of sockets accepted + Closed uint64 // number of sockets closed + + OpenFailed uint64 // number of sockets open failed + ConnectFailed uint64 // number of sockets connect failed + ListenFailed uint64 // number of sockets listen failed + AcceptFailed uint64 // number of sockets accept failed + CloseFailed uint64 // number of sockets close failed +} + +func (st Stat) String() string { + return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed) +} + +type stats map[Cookie]*Stat + +func (st stats) getLocked(c Cookie) *Stat { + s, ok := st[c] + if !ok { + s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()} + st[c] = s + } + return s +} + +// A FilterType represents a filter type. +type FilterType int + +const ( + FilterSocket FilterType = iota // for Socket + FilterConnect // for Connect or ConnectEx + FilterListen // for Listen + FilterAccept // for Accept, Accept4 or AcceptEx + FilterGetsockoptInt // for GetsockoptInt + FilterClose // for Close or Closesocket +) + +// A Filter represents a socket system call filter. +// +// It will only be executed before a system call for a socket that has +// an entry in internal table. +// If the filter returns a non-nil error, the execution of system call +// will be canceled and the system call function returns the non-nil +// error. +// It can return a non-nil AfterFilter for filtering after the +// execution of the system call. +type Filter func(*Status) (AfterFilter, error) + +func (f Filter) apply(st *Status) (AfterFilter, error) { + if f == nil { + return nil, nil + } + return f(st) +} + +// An AfterFilter represents a socket system call filter after an +// execution of a system call. +// +// It will only be executed after a system call for a socket that has +// an entry in internal table. +// If the filter returns a non-nil error, the system call function +// returns the non-nil error. +type AfterFilter func(*Status) error + +func (f AfterFilter) apply(st *Status) error { + if f == nil { + return nil + } + return f(st) +} + +// Set deploys the socket system call filter f for the filter type t. +func (sw *Switch) Set(t FilterType, f Filter) { + sw.once.Do(sw.init) + sw.fmu.Lock() + sw.fltab[t] = f + sw.fmu.Unlock() +} diff --git a/src/net/internal/socktest/switch_posix.go b/src/net/internal/socktest/switch_posix.go new file mode 100644 index 0000000..fcad4ce --- /dev/null +++ b/src/net/internal/socktest/switch_posix.go @@ -0,0 +1,58 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !plan9 + +package socktest + +import ( + "fmt" + "syscall" +) + +func familyString(family int) string { + switch family { + case syscall.AF_INET: + return "inet4" + case syscall.AF_INET6: + return "inet6" + case syscall.AF_UNIX: + return "local" + default: + return fmt.Sprintf("%d", family) + } +} + +func typeString(sotype int) string { + var s string + switch sotype & 0xff { + case syscall.SOCK_STREAM: + s = "stream" + case syscall.SOCK_DGRAM: + s = "datagram" + case syscall.SOCK_RAW: + s = "raw" + case syscall.SOCK_SEQPACKET: + s = "seqpacket" + default: + s = fmt.Sprintf("%d", sotype&0xff) + } + if flags := uint(sotype) & ^uint(0xff); flags != 0 { + s += fmt.Sprintf("|%#x", flags) + } + return s +} + +func protocolString(proto int) string { + switch proto { + case 0: + return "default" + case syscall.IPPROTO_TCP: + return "tcp" + case syscall.IPPROTO_UDP: + return "udp" + default: + return fmt.Sprintf("%d", proto) + } +} diff --git a/src/net/internal/socktest/switch_stub.go b/src/net/internal/socktest/switch_stub.go new file mode 100644 index 0000000..8a2fc35 --- /dev/null +++ b/src/net/internal/socktest/switch_stub.go @@ -0,0 +1,16 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build plan9 + +package socktest + +// Sockets maps a socket descriptor to the status of socket. +type Sockets map[int]Status + +func familyString(family int) string { return "" } + +func typeString(sotype int) string { return "" } + +func protocolString(proto int) string { return "" } diff --git a/src/net/internal/socktest/switch_unix.go b/src/net/internal/socktest/switch_unix.go new file mode 100644 index 0000000..f2e95d6 --- /dev/null +++ b/src/net/internal/socktest/switch_unix.go @@ -0,0 +1,29 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix || (js && wasm) + +package socktest + +// Sockets maps a socket descriptor to the status of socket. +type Sockets map[int]Status + +func (sw *Switch) sockso(s int) *Status { + sw.smu.RLock() + defer sw.smu.RUnlock() + so, ok := sw.sotab[s] + if !ok { + return nil + } + return &so +} + +// addLocked returns a new Status without locking. +// sw.smu must be held before call. +func (sw *Switch) addLocked(s, family, sotype, proto int) *Status { + sw.once.Do(sw.init) + so := Status{Cookie: cookie(family, sotype, proto)} + sw.sotab[s] = so + return &so +} diff --git a/src/net/internal/socktest/switch_windows.go b/src/net/internal/socktest/switch_windows.go new file mode 100644 index 0000000..4f1d597 --- /dev/null +++ b/src/net/internal/socktest/switch_windows.go @@ -0,0 +1,29 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socktest + +import "syscall" + +// Sockets maps a socket descriptor to the status of socket. +type Sockets map[syscall.Handle]Status + +func (sw *Switch) sockso(s syscall.Handle) *Status { + sw.smu.RLock() + defer sw.smu.RUnlock() + so, ok := sw.sotab[s] + if !ok { + return nil + } + return &so +} + +// addLocked returns a new Status without locking. +// sw.smu must be held before call. +func (sw *Switch) addLocked(s syscall.Handle, family, sotype, proto int) *Status { + sw.once.Do(sw.init) + so := Status{Cookie: cookie(family, sotype, proto)} + sw.sotab[s] = so + return &so +} diff --git a/src/net/internal/socktest/sys_cloexec.go b/src/net/internal/socktest/sys_cloexec.go new file mode 100644 index 0000000..d57f44d --- /dev/null +++ b/src/net/internal/socktest/sys_cloexec.go @@ -0,0 +1,42 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build dragonfly || freebsd || linux || netbsd || openbsd || solaris + +package socktest + +import "syscall" + +// Accept4 wraps syscall.Accept4. +func (sw *Switch) Accept4(s, flags int) (ns int, sa syscall.Sockaddr, err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Accept4(s, flags) + } + sw.fmu.RLock() + f := sw.fltab[FilterAccept] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return -1, nil, err + } + ns, sa, so.Err = syscall.Accept4(s, flags) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Close(ns) + } + return -1, nil, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).AcceptFailed++ + return -1, nil, so.Err + } + nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol()) + sw.stats.getLocked(nso.Cookie).Accepted++ + return ns, sa, nil +} diff --git a/src/net/internal/socktest/sys_unix.go b/src/net/internal/socktest/sys_unix.go new file mode 100644 index 0000000..e1040d3 --- /dev/null +++ b/src/net/internal/socktest/sys_unix.go @@ -0,0 +1,193 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix || (js && wasm) + +package socktest + +import "syscall" + +// Socket wraps syscall.Socket. +func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) { + sw.once.Do(sw.init) + + so := &Status{Cookie: cookie(family, sotype, proto)} + sw.fmu.RLock() + f := sw.fltab[FilterSocket] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return -1, err + } + s, so.Err = syscall.Socket(family, sotype, proto) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Close(s) + } + return -1, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).OpenFailed++ + return -1, so.Err + } + nso := sw.addLocked(s, family, sotype, proto) + sw.stats.getLocked(nso.Cookie).Opened++ + return s, nil +} + +// Close wraps syscall.Close. +func (sw *Switch) Close(s int) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Close(s) + } + sw.fmu.RLock() + f := sw.fltab[FilterClose] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Close(s) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).CloseFailed++ + return so.Err + } + delete(sw.sotab, s) + sw.stats.getLocked(so.Cookie).Closed++ + return nil +} + +// Connect wraps syscall.Connect. +func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Connect(s, sa) + } + sw.fmu.RLock() + f := sw.fltab[FilterConnect] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Connect(s, sa) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).ConnectFailed++ + return so.Err + } + sw.stats.getLocked(so.Cookie).Connected++ + return nil +} + +// Listen wraps syscall.Listen. +func (sw *Switch) Listen(s, backlog int) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Listen(s, backlog) + } + sw.fmu.RLock() + f := sw.fltab[FilterListen] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Listen(s, backlog) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).ListenFailed++ + return so.Err + } + sw.stats.getLocked(so.Cookie).Listened++ + return nil +} + +// Accept wraps syscall.Accept. +func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Accept(s) + } + sw.fmu.RLock() + f := sw.fltab[FilterAccept] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return -1, nil, err + } + ns, sa, so.Err = syscall.Accept(s) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Close(ns) + } + return -1, nil, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).AcceptFailed++ + return -1, nil, so.Err + } + nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol()) + sw.stats.getLocked(nso.Cookie).Accepted++ + return ns, sa, nil +} + +// GetsockoptInt wraps syscall.GetsockoptInt. +func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) { + so := sw.sockso(s) + if so == nil { + return syscall.GetsockoptInt(s, level, opt) + } + sw.fmu.RLock() + f := sw.fltab[FilterGetsockoptInt] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return -1, err + } + soerr, so.Err = syscall.GetsockoptInt(s, level, opt) + so.SocketErr = syscall.Errno(soerr) + if err = af.apply(so); err != nil { + return -1, err + } + + if so.Err != nil { + return -1, so.Err + } + if opt == syscall.SO_ERROR && (so.SocketErr == syscall.Errno(0) || so.SocketErr == syscall.EISCONN) { + sw.smu.Lock() + sw.stats.getLocked(so.Cookie).Connected++ + sw.smu.Unlock() + } + return soerr, nil +} diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go new file mode 100644 index 0000000..8c1c862 --- /dev/null +++ b/src/net/internal/socktest/sys_windows.go @@ -0,0 +1,221 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socktest + +import ( + "internal/syscall/windows" + "syscall" +) + +// Socket wraps syscall.Socket. +func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) { + sw.once.Do(sw.init) + + so := &Status{Cookie: cookie(family, sotype, proto)} + sw.fmu.RLock() + f, _ := sw.fltab[FilterSocket] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return syscall.InvalidHandle, err + } + s, so.Err = syscall.Socket(family, sotype, proto) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Closesocket(s) + } + return syscall.InvalidHandle, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).OpenFailed++ + return syscall.InvalidHandle, so.Err + } + nso := sw.addLocked(s, family, sotype, proto) + sw.stats.getLocked(nso.Cookie).Opened++ + return s, nil +} + +// WSASocket wraps syscall.WSASocket. +func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) { + sw.once.Do(sw.init) + + so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))} + sw.fmu.RLock() + f, _ := sw.fltab[FilterSocket] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return syscall.InvalidHandle, err + } + s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags) + if err = af.apply(so); err != nil { + if so.Err == nil { + syscall.Closesocket(s) + } + return syscall.InvalidHandle, err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).OpenFailed++ + return syscall.InvalidHandle, so.Err + } + nso := sw.addLocked(s, int(family), int(sotype), int(proto)) + sw.stats.getLocked(nso.Cookie).Opened++ + return s, nil +} + +// Closesocket wraps syscall.Closesocket. +func (sw *Switch) Closesocket(s syscall.Handle) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Closesocket(s) + } + sw.fmu.RLock() + f, _ := sw.fltab[FilterClose] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Closesocket(s) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).CloseFailed++ + return so.Err + } + delete(sw.sotab, s) + sw.stats.getLocked(so.Cookie).Closed++ + return nil +} + +// Connect wraps syscall.Connect. +func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Connect(s, sa) + } + sw.fmu.RLock() + f, _ := sw.fltab[FilterConnect] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Connect(s, sa) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).ConnectFailed++ + return so.Err + } + sw.stats.getLocked(so.Cookie).Connected++ + return nil +} + +// ConnectEx wraps syscall.ConnectEx. +func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.ConnectEx(s, sa, b, n, nwr, o) + } + sw.fmu.RLock() + f, _ := sw.fltab[FilterConnect] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.ConnectEx(s, sa, b, n, nwr, o) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).ConnectFailed++ + return so.Err + } + sw.stats.getLocked(so.Cookie).Connected++ + return nil +} + +// Listen wraps syscall.Listen. +func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) { + so := sw.sockso(s) + if so == nil { + return syscall.Listen(s, backlog) + } + sw.fmu.RLock() + f, _ := sw.fltab[FilterListen] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.Listen(s, backlog) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).ListenFailed++ + return so.Err + } + sw.stats.getLocked(so.Cookie).Listened++ + return nil +} + +// AcceptEx wraps syscall.AcceptEx. +func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error { + so := sw.sockso(ls) + if so == nil { + return syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped) + } + sw.fmu.RLock() + f, _ := sw.fltab[FilterAccept] + sw.fmu.RUnlock() + + af, err := f.apply(so) + if err != nil { + return err + } + so.Err = syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped) + if err = af.apply(so); err != nil { + return err + } + + sw.smu.Lock() + defer sw.smu.Unlock() + if so.Err != nil { + sw.stats.getLocked(so.Cookie).AcceptFailed++ + return so.Err + } + nso := sw.addLocked(as, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol()) + sw.stats.getLocked(nso.Cookie).Accepted++ + return nil +} -- cgit v1.2.3