// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !js

package net

import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"
)

// testUnixAddr returns a fresh, unique path suitable for a Unix-domain socket.
// The path is a file named "sock" inside a newly created temporary directory,
// and that directory is removed when the test finishes.
func testUnixAddr(t testing.TB) string {
	// An empty pattern keeps the directory name as short as possible: a path
	// longer than the sun_path field of sockaddr_un cannot be used to open
	// the socket.
	dir, err := os.MkdirTemp("", "")
	if err != nil {
		t.Fatal(err)
	}

	removeDir := func() {
		if rmErr := os.RemoveAll(dir); rmErr != nil {
			t.Error(rmErr)
		}
	}
	t.Cleanup(removeDir)

	return filepath.Join(dir, "sock")
}

// newLocalListener returns a stream Listener for network, bound to a local
// address chosen by this helper, and fails the test if none can be created.
//
// network may be "tcp", "tcp4", "tcp6", "unix" or "unixpacket". TCP listeners
// bind to the loopback address on port 0; Unix listeners bind to a path
// obtained from testUnixAddr. For "tcp", IPv4 loopback is tried first and
// IPv6 loopback is used when IPv4 is unsupported or the IPv4 listen fails.
//
// lcOpt may carry at most one *ListenConfig; when present it is used for every
// listen attempt. Passing more than one fails the test.
func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) Listener {
	var lc *ListenConfig
	switch len(lcOpt) {
	case 0:
		lc = new(ListenConfig)
	case 1:
		lc = lcOpt[0]
	default:
		t.Helper()
		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
	}

	// listen creates a listener through lc and fails the test on error.
	listen := func(net, addr string) Listener {
		ln, err := lc.Listen(context.Background(), net, addr)
		if err != nil {
			t.Helper()
			t.Fatal(err)
		}
		return ln
	}

	switch network {
	case "tcp":
		if supportsIPv4() {
			if !supportsIPv6() {
				return listen("tcp4", "127.0.0.1:0")
			}
			// Both families are available: try IPv4 without failing the
			// test so that IPv6 can serve as a fallback. The attempt goes
			// through lc so a caller-supplied ListenConfig is honored on
			// this path as well.
			if ln, err := lc.Listen(context.Background(), "tcp4", "127.0.0.1:0"); err == nil {
				return ln
			}
		}
		if supportsIPv6() {
			return listen("tcp6", "[::1]:0")
		}
	case "tcp4":
		if supportsIPv4() {
			return listen("tcp4", "127.0.0.1:0")
		}
	case "tcp6":
		if supportsIPv6() {
			return listen("tcp6", "[::1]:0")
		}
	case "unix", "unixpacket":
		return listen(network, testUnixAddr(t))
	}

	// Unknown network, or the required address family is unavailable.
	t.Helper()
	t.Fatalf("%s is not supported", network)
	return nil
}

// newDualStackListener opens a pair of TCP listeners, one on the IPv4 loopback
// and one on the IPv6 loopback, in that order. The port obtained by a
// successful listen is requested for the following one. It makes up to 64
// attempts and returns an error if no attempt yields both listeners.
func newDualStackListener() (lns []*TCPListener, err error) {
	endpoints := []struct {
		network string
		TCPAddr
	}{
		{"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
		{"tcp6", TCPAddr{IP: IPv6loopback}},
	}

	for attempt := 0; attempt < 64; attempt++ {
		var (
			port   int
			opened []*TCPListener
		)
		for _, ep := range endpoints {
			// Port is 0 for the first listen, letting the system choose.
			addr := ep.TCPAddr
			addr.Port = port
			l, lerr := ListenTCP(ep.network, &addr)
			if lerr != nil {
				continue
			}
			port = l.Addr().(*TCPAddr).Port
			opened = append(opened, l)
		}
		if len(opened) == len(endpoints) {
			return opened, nil
		}
		// Incomplete pair: release what was opened and try again.
		for _, l := range opened {
			l.Close()
		}
	}
	return nil, errors.New("no dualstack port available")
}

// localServer couples a Listener with the bookkeeping needed to run a test
// handler against it and shut it down afterwards.
//
// lnmu is held by teardown while it closes and clears the embedded Listener.
// done is closed by the goroutine started in buildup once the handler returns.
// cl is appended to by transponder and its connections are closed by teardown.
type localServer struct {
	lnmu sync.RWMutex
	Listener
	done chan bool // signal that indicates server stopped
	cl   []Conn    // accepted connection list
}

// buildup runs handler on a new goroutine, passing it the server and its
// listener, and closes ls.done after handler returns. It always returns nil.
func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
	serve := func() {
		handler(ls, ls.Listener)
		close(ls.done)
	}
	go serve()
	return nil
}

// teardown shuts the server down: it closes the listener and every connection
// recorded in ls.cl, waits for the handler goroutine to signal completion on
// ls.done, clears ls.Listener, and removes the socket file for "unix" and
// "unixpacket" listeners. Calling it again after a completed teardown is a
// no-op.
//
// It returns the first error reported while closing an accepted connection,
// or nil. A close error does not cut the shutdown short: the remaining
// connections are still closed and the rest of the cleanup still runs, so a
// failed Close cannot leave the listener marked as live or the socket file
// behind.
func (ls *localServer) teardown() error {
	ls.lnmu.Lock()
	defer ls.lnmu.Unlock()
	if ls.Listener == nil {
		return nil
	}

	// Capture the address before closing; it is needed for file cleanup.
	network := ls.Listener.Addr().Network()
	address := ls.Listener.Addr().String()
	ls.Listener.Close()

	var firstErr error
	for _, c := range ls.cl {
		if err := c.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}

	// Wait for the handler started by buildup to finish.
	<-ls.done
	ls.Listener = nil
	switch network {
	case "unix", "unixpacket":
		os.Remove(address)
	}
	return firstErr
}

// newLocalServer creates a localServer around a fresh listener for network,
// obtained from newLocalListener, with an open done channel.
func newLocalServer(t testing.TB, network string) *localServer {
	t.Helper()
	listener := newLocalListener(t, network)
	srv := &localServer{
		Listener: listener,
		done:     make(chan bool),
	}
	return srv
}

// streamListener is a Listener tagged with the network and address it was
// created for. dualStackServer keeps one per address family; its done channel
// is closed by dualStackServer.buildup when the handler for this listener
// returns.
type streamListener struct {
	network, address string
	Listener
	done chan bool // signal that indicates server stopped
}

// newLocalServer wraps sl's listener in a new localServer that has its own,
// freshly made done channel.
func (sl *streamListener) newLocalServer() *localServer {
	srv := new(localServer)
	srv.Listener = sl.Listener
	srv.done = make(chan bool)
	return srv
}

// dualStackServer groups the stream listeners of a dual-stack test server
// together with the connections it has established.
//
// lnmu is held by teardown and teardownNetwork while they close listeners in
// lns. port is the port string split out of the first listener's address by
// newDualStackServer. cmu is held by teardown while it closes and clears cs.
type dualStackServer struct {
	lnmu sync.RWMutex
	lns  []streamListener
	port string

	cmu sync.RWMutex
	cs  []Conn // established connections at the passive open side
}

// buildup starts one goroutine per listener in dss.lns. Each goroutine runs
// handler with the server and that listener, then closes the listener's done
// channel. It always returns nil.
func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
	serve := func(idx int) {
		handler(dss, dss.lns[idx].Listener)
		close(dss.lns[idx].done)
	}
	for idx := range dss.lns {
		go serve(idx)
	}
	return nil
}

// teardownNetwork shuts down only the listeners whose network equals network:
// each one that is still set is closed, its handler is awaited via done, and
// its Listener field is cleared. It always returns nil.
func (dss *dualStackServer) teardownNetwork(network string) error {
	dss.lnmu.Lock()
	for idx := range dss.lns {
		sl := &dss.lns[idx]
		if sl.network != network || sl.Listener == nil {
			continue
		}
		sl.Listener.Close()
		<-sl.done
		sl.Listener = nil
	}
	dss.lnmu.Unlock()
	return nil
}

// teardown shuts the whole server down. Under lnmu it closes every listener
// that is still set, waits for its handler via done, and empties dss.lns;
// then, under cmu, it closes every connection in dss.cs and empties that list.
// It always returns nil.
func (dss *dualStackServer) teardown() error {
	dss.lnmu.Lock()
	for idx := range dss.lns {
		sl := &dss.lns[idx]
		if sl.Listener == nil {
			continue
		}
		sl.Listener.Close()
		<-sl.done
	}
	dss.lns = dss.lns[:0]
	dss.lnmu.Unlock()

	dss.cmu.Lock()
	for idx := range dss.cs {
		dss.cs[idx].Close()
	}
	dss.cs = dss.cs[:0]
	dss.cmu.Unlock()

	return nil
}

// newDualStackServer builds a dualStackServer on top of the listener pair
// returned by newDualStackListener. The server's port is taken from the first
// listener's address. If that address cannot be split into host and port,
// both listeners are closed and the error is returned.
func newDualStackServer() (*dualStackServer, error) {
	lns, err := newDualStackListener()
	if err != nil {
		return nil, err
	}
	ln4, ln6 := lns[0], lns[1]

	_, port, err := SplitHostPort(ln4.Addr().String())
	if err != nil {
		ln4.Close()
		ln6.Close()
		return nil, err
	}

	dss := &dualStackServer{port: port}
	dss.lns = []streamListener{
		{
			network:  "tcp4",
			address:  ln4.Addr().String(),
			Listener: ln4,
			done:     make(chan bool),
		},
		{
			network:  "tcp6",
			address:  ln6.Addr().String(),
			Listener: ln6,
			done:     make(chan bool),
		},
	}
	return dss, nil
}

// transponder accepts a single connection on ln, records it in ls.cl, reads
// up to 256 bytes from it and writes the bytes read back to the peer. Problems
// are reported on ch, which is closed when transponder returns.
//
// For an accept, read or write failure, the result of the matching parse
// helper is sent first when it is non-nil, followed by the original error.
func (ls *localServer) transponder(ln Listener, ch chan<- error) {
	defer close(ch)

	// Bound the wait for an incoming connection on listener types that
	// support a deadline.
	switch l := ln.(type) {
	case *TCPListener:
		l.SetDeadline(time.Now().Add(someTimeout))
	case *UnixListener:
		l.SetDeadline(time.Now().Add(someTimeout))
	}

	conn, acceptErr := ln.Accept()
	if acceptErr != nil {
		if perr := parseAcceptError(acceptErr); perr != nil {
			ch <- perr
		}
		ch <- acceptErr
		return
	}
	ls.cl = append(ls.cl, conn)

	// Both ends of the accepted connection must report the listener's network.
	want := ln.Addr().Network()
	gotLocal := conn.LocalAddr().Network()
	gotRemote := conn.RemoteAddr().Network()
	if gotLocal != want || gotRemote != want {
		ch <- fmt.Errorf("got %v->%v; expected %v->%v", gotLocal, gotRemote, want, want)
		return
	}

	conn.SetDeadline(time.Now().Add(someTimeout))
	conn.SetReadDeadline(time.Now().Add(someTimeout))
	conn.SetWriteDeadline(time.Now().Add(someTimeout))

	buf := make([]byte, 256)
	n, readErr := conn.Read(buf)
	if readErr != nil {
		if perr := parseReadError(readErr); perr != nil {
			ch <- perr
		}
		ch <- readErr
		return
	}

	_, writeErr := conn.Write(buf[:n])
	if writeErr == nil {
		return
	}
	if perr := parseWriteError(writeErr); perr != nil {
		ch <- perr
	}
	ch <- writeErr
}

// transceiver writes wb to c and then reads a reply of up to len(wb) bytes.
// Problems are reported on ch, which is closed when transceiver returns.
//
// For a write or read failure, the result of the matching parse helper is
// sent first when it is non-nil, followed by the original error. A short
// write or short read is reported but does not stop the exchange.
func transceiver(c Conn, wb []byte, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	want := len(wb)

	wrote, writeErr := c.Write(wb)
	if writeErr != nil {
		if perr := parseWriteError(writeErr); perr != nil {
			ch <- perr
		}
		ch <- writeErr
		return
	}
	if wrote != want {
		ch <- fmt.Errorf("wrote %d; want %d", wrote, want)
	}

	reply := make([]byte, want)
	got, readErr := c.Read(reply)
	if readErr != nil {
		if perr := parseReadError(readErr); perr != nil {
			ch <- perr
		}
		ch <- readErr
		return
	}
	if got != want {
		ch <- fmt.Errorf("read %d; want %d", got, want)
	}
}

// newLocalPacketListener returns a PacketConn for network, bound to a local
// address chosen by this helper, and fails the test if none can be created.
//
// network may be "udp", "udp4", "udp6" or "unixgram". UDP sockets bind to the
// loopback address on port 0; "unixgram" binds to a path obtained from
// testUnixAddr. For "udp", IPv4 loopback is used when IPv4 is supported and
// IPv6 loopback otherwise.
//
// lcOpt may carry at most one *ListenConfig; when present it is used to create
// the socket. Passing more than one fails the test.
func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig) PacketConn {
	var lc *ListenConfig
	switch len(lcOpt) {
	case 0:
		lc = new(ListenConfig)
	case 1:
		lc = lcOpt[0]
	default:
		t.Helper()
		// Name this function, not newLocalListener, so the failure points
		// at the right call site.
		t.Fatal("too many ListenConfigs passed to newLocalPacketListener: want 0 or 1")
	}

	// listenPacket creates a packet socket through lc and fails the test on
	// error.
	listenPacket := func(net, addr string) PacketConn {
		c, err := lc.ListenPacket(context.Background(), net, addr)
		if err != nil {
			t.Helper()
			t.Fatal(err)
		}
		return c
	}

	switch network {
	case "udp":
		if supportsIPv4() {
			return listenPacket("udp4", "127.0.0.1:0")
		}
		if supportsIPv6() {
			return listenPacket("udp6", "[::1]:0")
		}
	case "udp4":
		if supportsIPv4() {
			return listenPacket("udp4", "127.0.0.1:0")
		}
	case "udp6":
		if supportsIPv6() {
			return listenPacket("udp6", "[::1]:0")
		}
	case "unixgram":
		return listenPacket(network, testUnixAddr(t))
	}

	// Unknown network, or the required address family is unavailable.
	t.Helper()
	t.Fatalf("%s is not supported", network)
	return nil
}

// newDualStackPacketListener opens a pair of UDP sockets, one on the IPv4
// loopback and one on the IPv6 loopback, in that order. The port obtained by
// a successful listen is requested for the following one. It makes up to 64
// attempts and returns an error if no attempt yields both sockets.
func newDualStackPacketListener() (cs []*UDPConn, err error) {
	endpoints := []struct {
		network string
		UDPAddr
	}{
		{"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
		{"udp6", UDPAddr{IP: IPv6loopback}},
	}

	for attempt := 0; attempt < 64; attempt++ {
		var (
			port   int
			opened []*UDPConn
		)
		for _, ep := range endpoints {
			// Port is 0 for the first listen, letting the system choose.
			addr := ep.UDPAddr
			addr.Port = port
			conn, lerr := ListenUDP(ep.network, &addr)
			if lerr != nil {
				continue
			}
			port = conn.LocalAddr().(*UDPAddr).Port
			opened = append(opened, conn)
		}
		if len(opened) == len(endpoints) {
			return opened, nil
		}
		// Incomplete pair: release what was opened and try again.
		for _, conn := range opened {
			conn.Close()
		}
	}
	return nil, errors.New("no dualstack port available")
}

// localPacketServer couples a PacketConn with the bookkeeping needed to run a
// test handler against it and shut it down afterwards.
//
// pcmu is held by teardown while it closes and clears the embedded PacketConn.
// done is closed by the goroutine started in buildup once the handler returns.
type localPacketServer struct {
	pcmu sync.RWMutex
	PacketConn
	done chan bool // signal that indicates server stopped
}

// buildup runs handler on a new goroutine, passing it the server and its
// packet connection, and closes ls.done after handler returns. It always
// returns nil.
func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
	serve := func() {
		handler(ls, ls.PacketConn)
		close(ls.done)
	}
	go serve()
	return nil
}

// teardown shuts the server down while holding pcmu: it closes the packet
// connection, waits for the handler goroutine to signal completion on ls.done,
// clears ls.PacketConn, and removes the socket file for "unixgram"
// connections. It does nothing if the connection is already cleared, and it
// always returns nil.
func (ls *localPacketServer) teardown() error {
	ls.pcmu.Lock()
	if pc := ls.PacketConn; pc != nil {
		// Capture the address before closing; it is needed for file cleanup.
		network := pc.LocalAddr().Network()
		address := pc.LocalAddr().String()
		pc.Close()
		<-ls.done
		ls.PacketConn = nil
		if network == "unixgram" {
			os.Remove(address)
		}
	}
	ls.pcmu.Unlock()
	return nil
}

// newLocalPacketServer creates a localPacketServer around a fresh packet
// connection for network, obtained from newLocalPacketListener, with an open
// done channel.
func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
	t.Helper()
	conn := newLocalPacketListener(t, network)
	srv := &localPacketServer{
		PacketConn: conn,
		done:       make(chan bool),
	}
	return srv
}

// packetListener wraps a PacketConn so that a localPacketServer can be
// created from it with newLocalServer.
type packetListener struct {
	PacketConn
}

// newLocalServer wraps pl's packet connection in a new localPacketServer that
// has its own, freshly made done channel.
func (pl *packetListener) newLocalServer() *localPacketServer {
	srv := new(localPacketServer)
	srv.PacketConn = pl.PacketConn
	srv.done = make(chan bool)
	return srv
}

// packetTransponder reads one packet of up to 256 bytes from c and writes the
// same bytes back to the sender. Problems are reported on ch, which is closed
// when packetTransponder returns.
//
// When ReadFrom reports no peer address, the destination is resolved from the
// packet payload itself according to c's local network ("udp" or "unixgram").
// For a read or write failure, the result of the matching parse helper is
// sent first when it is non-nil, followed by the original error.
func packetTransponder(c PacketConn, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	buf := make([]byte, 256)
	n, peer, err := c.ReadFrom(buf)
	if err != nil {
		if perr := parseReadError(err); perr != nil {
			ch <- perr
		}
		ch <- err
		return
	}
	payload := buf[:n]

	if peer == nil { // for connected-mode sockets
		if network := c.LocalAddr().Network(); network == "udp" {
			peer, err = ResolveUDPAddr("udp", string(payload))
		} else if network == "unixgram" {
			peer, err = ResolveUnixAddr("unixgram", string(payload))
		}
		if err != nil {
			ch <- err
			return
		}
	}

	_, writeErr := c.WriteTo(payload, peer)
	if writeErr == nil {
		return
	}
	if perr := parseWriteError(writeErr); perr != nil {
		ch <- perr
	}
	ch <- writeErr
}

// packetTransceiver sends wb to dst over c and then reads one reply packet of
// up to len(wb) bytes. Problems are reported on ch, which is closed when
// packetTransceiver returns.
//
// For a write or read failure, the result of the matching parse helper is
// sent first when it is non-nil, followed by the original error. A short
// write or short read is reported but does not stop the exchange.
func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
	defer close(ch)

	c.SetDeadline(time.Now().Add(someTimeout))
	c.SetReadDeadline(time.Now().Add(someTimeout))
	c.SetWriteDeadline(time.Now().Add(someTimeout))

	want := len(wb)

	wrote, writeErr := c.WriteTo(wb, dst)
	if writeErr != nil {
		if perr := parseWriteError(writeErr); perr != nil {
			ch <- perr
		}
		ch <- writeErr
		return
	}
	if wrote != want {
		ch <- fmt.Errorf("wrote %d; want %d", wrote, want)
	}

	reply := make([]byte, want)
	got, _, readErr := c.ReadFrom(reply)
	if readErr != nil {
		if perr := parseReadError(readErr); perr != nil {
			ch <- perr
		}
		ch <- readErr
		return
	}
	if got != want {
		ch <- fmt.Errorf("read %d; want %d", got, want)
	}
}