commit 65aa53fc52ff15efe54df4147564828d535837f8 (HEAD, upstream/8.0.3, upstream, debian)
Author:    Daniel Baumann <daniel.baumann@progress-linux.org>  2024-10-11 10:27:00 +0000
Committer: Daniel Baumann <daniel.baumann@progress-linux.org>  2024-10-11 10:27:00 +0000
tree       31c51dad04fdcca80e6d3043c8bd49d2f1a51f83 /modules/graceful/server.go
parent     Initial commit.

    Adding upstream version 8.0.3.

    Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'modules/graceful/server.go')
 modules/graceful/server.go | 284 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 284 insertions(+), 0 deletions(-)
diff --git a/modules/graceful/server.go b/modules/graceful/server.go
new file mode 100644
index 00000000..2525a83e
--- /dev/null
+++ b/modules/graceful/server.go
@@ -0,0 +1,284 @@
+// Copyright 2019 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+// This code is heavily inspired by the endless Go package
+
+package graceful
+
+import (
+ "crypto/tls"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "code.gitea.io/gitea/modules/log"
+ "code.gitea.io/gitea/modules/proxyprotocol"
+ "code.gitea.io/gitea/modules/setting"
+)
+
+// GetListener returns a net.Listener for the given network and address.
+// It determines which net.Listener implementation the server will use, so
+// that downstream users can substitute their own, e.g. for a hidden service
+// or a p2p network.
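+//
+// A downstream override might look like the following sketch (the body is a
+// hypothetical illustration, not part of this package):
+//
+//	GetListener = func(network, address string) (net.Listener, error) {
+//		// e.g. listen on an onion service or overlay network instead
+//		return net.Listen(network, address)
+//	}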
+var GetListener = DefaultGetListener
+
+// ServeFunction represents an Accept loop over a net.Listener
+type ServeFunction = func(net.Listener) error
+
+// Server represents our graceful server
+type Server struct {
+ network string
+ address string
+ listener net.Listener
+ wg sync.WaitGroup
+ state state
+ lock *sync.RWMutex
+ BeforeBegin func(network, address string)
+ OnShutdown func()
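+	// PerWriteTimeout is the base deadline applied to each Write on an
+	// accepted connection; PerWritePerKbTimeout extends that deadline by a
+	// further amount per KiB written (see wrappedConn.Write below).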
+ PerWriteTimeout time.Duration
+ PerWritePerKbTimeout time.Duration
+}
+
+// NewServer creates a server on network at provided address
+func NewServer(network, address, name string) *Server {
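+	// A "child" here is a process forked by a running instance for a graceful
+	// (zero-downtime) restart; in that case the listening socket is inherited
+	// from the parent via DefaultGetListener rather than opened anew.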
+ if GetManager().IsChild() {
+ log.Info("Restarting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
+ } else {
+ log.Info("Starting new %s server: %s:%s on PID: %d", name, network, address, os.Getpid())
+ }
+ srv := &Server{
+ wg: sync.WaitGroup{},
+ state: stateInit,
+ lock: &sync.RWMutex{},
+ network: network,
+ address: address,
+ PerWriteTimeout: setting.PerWriteTimeout,
+ PerWritePerKbTimeout: setting.PerWritePerKbTimeout,
+ }
+
+ srv.BeforeBegin = func(network, addr string) {
+ log.Debug("Starting server on %s:%s (PID: %d)", network, addr, syscall.Getpid())
+ }
+
+ return srv
+}
+
+// ListenAndServe listens on the provided network address and then calls Serve
+// to handle requests on incoming connections.
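+// The serve argument is typically the Serve method of an *http.Server, which
+// loops on Accept until the listener is closed.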
+func (srv *Server) ListenAndServe(serve ServeFunction, useProxyProtocol bool) error {
+ go srv.awaitShutdown()
+
+ listener, err := GetListener(srv.network, srv.address)
+ if err != nil {
+ log.Error("Unable to GetListener: %v", err)
+ return err
+ }
+
+ // we need to wrap the listener to take account of our lifecycle
+ listener = newWrappedListener(listener, srv)
+
+ // Now we need to take account of ProxyProtocol settings...
+ if useProxyProtocol {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+ srv.listener = listener
+
+ srv.BeforeBegin(srv.network, srv.address)
+
+ return srv.Serve(serve)
+}
+
+// ListenAndServeTLSConfig listens on the provided network address and then calls
+// Serve to handle requests on incoming TLS connections.
+func (srv *Server) ListenAndServeTLSConfig(tlsConfig *tls.Config, serve ServeFunction, useProxyProtocol, proxyProtocolTLSBridging bool) error {
+ go srv.awaitShutdown()
+
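+	// Enforce a TLS 1.2 floor unless the caller configured a minimum version.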
+ if tlsConfig.MinVersion == 0 {
+ tlsConfig.MinVersion = tls.VersionTLS12
+ }
+
+ listener, err := GetListener(srv.network, srv.address)
+ if err != nil {
+ log.Error("Unable to get Listener: %v", err)
+ return err
+ }
+
+ // we need to wrap the listener to take account of our lifecycle
+ listener = newWrappedListener(listener, srv)
+
+	// Now take account of the ProxyProtocol settings... If we're not bridging,
+	// the proxy sends its header on the raw TCP stream before the TLS
+	// handshake, so the proxyprotocol listener must wrap below TLS.
+ if useProxyProtocol && !proxyProtocolTLSBridging {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+
+	// Now layer the TLS protocol on top
+ listener = tls.NewListener(listener, tlsConfig)
+
+	// If we're bridging, the proxy speaks its header inside the TLS stream,
+	// so wrap above TLS to learn who we're bridging for.
+ if useProxyProtocol && proxyProtocolTLSBridging {
+ listener = &proxyprotocol.Listener{
+ Listener: listener,
+ ProxyHeaderTimeout: setting.ProxyProtocolHeaderTimeout,
+ AcceptUnknown: setting.ProxyProtocolAcceptUnknown,
+ }
+ }
+
+ srv.listener = listener
+ srv.BeforeBegin(srv.network, srv.address)
+
+ return srv.Serve(serve)
+}
+
+// Serve accepts incoming connections on the wrapped listener by handing the
+// listener to the provided serve function, which runs its own Accept loop
+// (for example http.Server.Serve) until the listener is closed.
+//
+// In addition to the standard Serve behaviour, each accepted connection is
+// added to a sync.WaitGroup so that all outstanding connections can be
+// drained before the server shuts down.
+func (srv *Server) Serve(serve ServeFunction) error {
+ defer log.Debug("Serve() returning... (PID: %d)", syscall.Getpid())
+ srv.setState(stateRunning)
+ GetManager().RegisterServer()
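+	// RegisterServer/ServerDone bracket the accept loop so the manager knows
+	// how many servers are still draining connections during shutdown.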
+ err := serve(srv.listener)
+ log.Debug("Waiting for connections to finish... (PID: %d)", syscall.Getpid())
+ srv.wg.Wait()
+ srv.setState(stateTerminate)
+ GetManager().ServerDone()
+	// "use of closed" means the listener has been closed - i.e. we should be shutting down - so return nil
+ if err == nil || strings.Contains(err.Error(), "use of closed") || strings.Contains(err.Error(), "http: Server closed") {
+ return nil
+ }
+ return err
+}
+
+func (srv *Server) getState() state {
+ srv.lock.RLock()
+ defer srv.lock.RUnlock()
+
+ return srv.state
+}
+
+func (srv *Server) setState(st state) {
+ srv.lock.Lock()
+ defer srv.lock.Unlock()
+
+ srv.state = st
+}
+
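+// filer is implemented by listeners (such as *net.TCPListener) that can
+// expose their underlying *os.File; this is how the listening socket is
+// handed to the forked child during a graceful restart.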
+type filer interface {
+ File() (*os.File, error)
+}
+
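+// wrappedListener registers every accepted connection with the server's
+// WaitGroup and refuses to be closed twice.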
+type wrappedListener struct {
+ net.Listener
+ stopped bool
+ server *Server
+}
+
+func newWrappedListener(l net.Listener, srv *Server) *wrappedListener {
+ return &wrappedListener{
+ Listener: l,
+ server: srv,
+ }
+}
+
+func (wl *wrappedListener) Accept() (net.Conn, error) {
+ var c net.Conn
+	// Set keep-alive on connections accepted from a TCPListener.
+ if tcl, ok := wl.Listener.(*net.TCPListener); ok {
+ tc, err := tcl.AcceptTCP()
+ if err != nil {
+ return nil, err
+ }
+ _ = tc.SetKeepAlive(true) // see http.tcpKeepAliveListener
+ _ = tc.SetKeepAlivePeriod(3 * time.Minute) // see http.tcpKeepAliveListener
+ c = tc
+ } else {
+ var err error
+ c, err = wl.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ closed := int32(0)
+
+ c = &wrappedConn{
+ Conn: c,
+ server: wl.server,
+ closed: &closed,
+ perWriteTimeout: wl.server.PerWriteTimeout,
+ perWritePerKbTimeout: wl.server.PerWritePerKbTimeout,
+ }
+
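+	// Account for the connection in the server's WaitGroup; the matching
+	// wg.Done happens in wrappedConn.Close, guarded by the atomic closed flag
+	// so a double Close cannot decrement twice.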
+ wl.server.wg.Add(1)
+ return c, nil
+}
+
+func (wl *wrappedListener) Close() error {
+ if wl.stopped {
+ return syscall.EINVAL
+ }
+
+ wl.stopped = true
+ return wl.Listener.Close()
+}
+
+func (wl *wrappedListener) File() (*os.File, error) {
+	// returns a dup(2) of the underlying fd - the FD_CLOEXEC flag is *not* set,
+	// so the listening socket can be passed to child processes
+ return wl.Listener.(filer).File()
+}
+
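+// wrappedConn ties an accepted connection back to its Server so that Close
+// can release the WaitGroup slot taken in Accept, and maintains a cumulative
+// write deadline (see Write).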
+type wrappedConn struct {
+ net.Conn
+ server *Server
+ closed *int32
+ deadline time.Time
+ perWriteTimeout time.Duration
+ perWritePerKbTimeout time.Duration
+}
+
+func (w *wrappedConn) Write(p []byte) (n int, err error) {
+ if w.perWriteTimeout > 0 {
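+		// The new deadline is the later of (a) the previous deadline extended
+		// by a per-KiB allowance for this write and (b) now plus the base
+		// timeout plus that allowance, so slow-but-steady writers stay alive
+		// while stalled ones are bounded.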
+ minTimeout := time.Duration(len(p)/1024) * w.perWritePerKbTimeout
+ minDeadline := time.Now().Add(minTimeout).Add(w.perWriteTimeout)
+
+ w.deadline = w.deadline.Add(minTimeout)
+ if minDeadline.After(w.deadline) {
+ w.deadline = minDeadline
+ }
+ _ = w.Conn.SetWriteDeadline(w.deadline)
+ }
+ return w.Conn.Write(p)
+}
+
+func (w *wrappedConn) Close() error {
+ if atomic.CompareAndSwapInt32(w.closed, 0, 1) {
+ defer func() {
+ if err := recover(); err != nil {
+ select {
+ case <-GetManager().IsHammer():
+ // Likely deadlocked request released at hammertime
+ log.Warn("Panic during connection close! %v. Likely there has been a deadlocked request which has been released by forced shutdown.", err)
+ default:
+ log.Error("Panic during connection close! %v", err)
+ }
+ }
+ }()
+ w.server.wg.Done()
+ }
+ return w.Conn.Close()
+}
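
For orientation, a minimal caller of this API might look like the sketch
below. The address, handler, and server name are placeholders, and it assumes
the graceful Manager has already been initialized the way Forgejo does at
startup; it is not how Forgejo itself wires its servers.

	package main

	import (
		"log"
		"net/http"

		"code.gitea.io/gitea/modules/graceful"
	)

	func main() {
		handler := http.NewServeMux() // placeholder handler

		// "web" is just a display name used in the startup log lines.
		server := graceful.NewServer("tcp", "127.0.0.1:3000", "web")

		// (&http.Server{}).Serve satisfies ServeFunction: it runs the Accept
		// loop until the wrapped listener is closed during shutdown.
		if err := server.ListenAndServe((&http.Server{Handler: handler}).Serve, false); err != nil {
			log.Fatal(err)
		}
	}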