123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- package main
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "net"
- "net/netip"
- "os"
- "os/signal"
- "syscall"
- "time"
- "unsafe"
- )
- func main() {
- logLevel := new(slog.LevelVar)
- logHandler :=
- slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
- Level: logLevel,
- })
- slog.SetDefault(slog.New(logHandler))
- logLevel.Set(slog.LevelInfo)
- ctx, cancel := context.WithCancel(context.Background())
- slog.Info("Starting signal watchdog...")
- go func() {
- sig := make(chan os.Signal, 1)
- signal.Notify(sig, os.Interrupt, os.Kill)
- <-sig
- slog.Info("Caught signal; terminating...")
- cancel()
- }()
- slog.Info("Starting connection listener on port 54321")
- go listen(ctx, 54321)
- for {
- if ctx.Err() != nil {
- break
- }
- sleepNice(ctx, 100*time.Millisecond)
- }
- }
- func listen(ctx context.Context, port uint16) {
- tcpAddr := net.TCPAddrFromAddrPort(
- netip.AddrPortFrom(
- netip.IPv4Unspecified(),
- port,
- ),
- )
- ln, err := net.ListenTCP("tcp4", tcpAddr)
- if err != nil {
- panic(err)
- }
- for {
- if ctx.Err() != nil {
- break
- }
- slog.Info("Awaiting connection...")
- conn, err := ln.AcceptTCP()
- if err != nil {
- slog.Error("Unable to accept connection",
- "error", err,
- )
- }
- slog.Info("Received connection",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- )
- go handleConnection(ctx, conn)
- }
- }
- func handleConnection(
- ctx context.Context,
- conn *net.TCPConn,
- ) {
- defer func() {
- if er := conn.Close(); er != nil {
- slog.Warn("Error while closing connection",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- "error", er,
- )
- }
- }()
- if err := conn.SetLinger(0); err != nil {
- slog.Error("Could not set TCP linger time", "error", err)
- return
- }
- if err := conn.SetNoDelay(true); err != nil {
- slog.Error("Failed to enable NoDelay", "error", err)
- return
- }
- err := conn.SetKeepAliveConfig(net.KeepAliveConfig{
- Enable: true,
- Idle: 15 * time.Second,
- Interval: 15 * time.Second,
- Count: 5,
- })
- if err != nil {
- slog.Error("Cannot configure TCP keeplives", "error", err)
- return
- }
- /*sysConn, err := conn.SyscallConn()
- if err != nil {
- slog.Error("Could not obtain system-level connection", "error", err)
- return
- }*/
- timeout := 10 * time.Second
- //outputPeriod := 10 * time.Second
- //lastOutput := time.Now().Add(-outputPeriod)
- //tcpInfo := syscall.TCPInfo{}
- buf := make([]byte, 4)
- slog.Debug("Starting loop...",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- )
- for {
- if ctx.Err() != nil {
- break
- }
- /*slog.Debug("Retrieving TCP info",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- )
- err = retrieveTCPInfo(sysConn, &tcpInfo)
- if err != nil {
- slog.Error("Could not retrieve TCP information",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- "error", err,
- )
- }
- if tcpInfo.State == unix.BPF_TCP_CLOSE {
- slog.Info("Lost connection",
- "local", conn.LocalAddr(),
- "remote", conn.RemoteAddr(),
- )
- return
- }*/
- // Send and receive some data.
- if err := recv(conn, buf, timeout); err != nil {
- slog.Debug("While reading from connection",
- "remote", conn.RemoteAddr(),
- "error", err,
- )
- return
- }
- if err := send(conn, buf, timeout); err != nil {
- slog.Debug("While writing to connection",
- "remote", conn.RemoteAddr(),
- "error", err,
- )
- return
- }
- // Periodically print TCPInfo.
- /*select {
- case <-timer.C:
- fmt.Printf("RTT: %0.3f ms; StdDev: %0.3f ms\n",
- float32(tcpInfo.Rtt)/1000.0,
- float32(tcpInfo.Rttvar)/1000.0,
- )
- timer.Stop()
- timer.Reset(outputPeriod)
- default:
- }*/
- sleepNice(ctx, 10*time.Millisecond)
- }
- }
// recv performs a single Read from c into buf, bounded by timeout.
//
// Return values:
//   - nil when the read succeeds (the byte count is not reported);
//   - "set deadline: ..." when the read deadline cannot be installed;
//   - "read: ..." wrapping io.EOF when the peer closed the connection
//     (logged at Info level);
//   - the deadline error itself, unwrapped and unlogged, when the read
//     timed out;
//   - "read: ..." wrapping any other read error (logged at Warn level).
func recv(c net.Conn, buf []byte, timeout time.Duration) error {
	if dlErr := c.SetReadDeadline(time.Now().Add(timeout)); dlErr != nil {
		return fmt.Errorf("set deadline: %w", dlErr)
	}

	_, readErr := c.Read(buf)
	switch {
	case readErr == nil:
		return nil

	case errors.Is(readErr, io.EOF):
		slog.Info("Lost connection",
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
		)
		return fmt.Errorf("read: %w", readErr)

	case errors.Is(readErr, os.ErrDeadlineExceeded):
		// Timeouts are passed through untouched for the caller to handle.
		return readErr

	default:
		slog.Warn("While reading from connection",
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
			"error", readErr,
		)
		return fmt.Errorf("read: %w", readErr)
	}
}
// send performs a single Write of buf to c, bounded by timeout.
//
// Return values:
//   - nil when the write succeeds (the byte count is not reported);
//   - "set deadline: ..." when the write deadline cannot be installed;
//   - "write: ..." wrapping io.EOF, logged at Info level as a lost
//     connection;
//   - the deadline error itself, unwrapped and unlogged, when the write
//     timed out;
//   - "write: ..." wrapping any other write error (logged at Debug level).
func send(c net.Conn, buf []byte, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	if err := c.SetWriteDeadline(deadline); err != nil {
		return fmt.Errorf("set deadline: %w", err)
	}

	_, err := c.Write(buf)
	switch {
	case err == nil:
		return nil

	case errors.Is(err, io.EOF):
		slog.Info("Lost connection",
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
		)
		// This is the write path: label the failure accordingly rather than
		// as a read, so the error chain points at the right operation.
		return fmt.Errorf("write: %w", err)

	case errors.Is(err, os.ErrDeadlineExceeded):
		// Timeouts are passed through untouched for the caller to handle.
		return err

	default:
		slog.Debug("While writing to connection",
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
			"error", err,
		)
		return fmt.Errorf("write: %w", err)
	}
}
// sleepNice blocks for duration t, or until ctx is cancelled, whichever
// happens first. It returns immediately if ctx is already cancelled.
func sleepNice(ctx context.Context, t time.Duration) {
	if ctx.Err() != nil {
		return
	}

	timer := time.NewTimer(t)
	// Release the timer when we leave early because of cancellation.
	defer timer.Stop()

	// Waiting on both channels reacts to cancellation at once, without the
	// periodic wake-ups (and per-iteration timers) a polling loop needs.
	select {
	case <-timer.C:
	case <-ctx.Done():
	}
}
// retrieveTCPInfo fills *info with the kernel's TCP_INFO data for the socket
// behind sysConn by issuing getsockopt(SOL_TCP, TCP_INFO) on its descriptor.
//
// It returns the error from sysConn.Control if the descriptor cannot be
// accessed, or an error wrapping the syscall.Errno if getsockopt fails, so
// callers can match it with errors.Is (e.g. syscall.EBADF).
func retrieveTCPInfo(
	sysConn syscall.RawConn,
	info *syscall.TCPInfo,
) error {
	// getsockopt's optlen argument points at a socklen_t, which is 32 bits
	// wide. Passing the address of a uintptr only works by accident on
	// little-endian 64-bit machines, so use an explicit uint32.
	size := uint32(unsafe.Sizeof(*info))

	var errno syscall.Errno
	err := sysConn.Control(func(fd uintptr) {
		_, _, errno = syscall.Syscall6(
			syscall.SYS_GETSOCKOPT,
			fd,
			syscall.SOL_TCP,
			syscall.TCP_INFO,
			uintptr(unsafe.Pointer(info)),
			uintptr(unsafe.Pointer(&size)),
			0,
		)
	})
	if err != nil {
		return err
	}
	if errno != 0 {
		// Wrap rather than format the number so the errno stays inspectable.
		return fmt.Errorf("getsockopt TCP_INFO: %w", errno)
	}
	return nil
}
|