// This program is a small TCP echo server. It listens on port 54321 on all
// IPv4 interfaces, configures TCP keepalives on every accepted connection,
// echoes back whatever each client sends, and shuts down on SIGINT or
// SIGTERM.
//
// It uses syscall.TCPInfo / SOL_TCP / TCP_INFO (defined by the syscall
// package on Linux) and net.TCPConn.SetKeepAliveConfig (Go 1.23+).
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"os/signal"
	"syscall"
	"time"
	"unsafe"
)

// main configures logging, starts the signal watchdog and the listener, and
// blocks until a termination signal cancels the root context.
func main() {
	logLevel := new(slog.LevelVar)
	logHandler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: logLevel,
	})
	slog.SetDefault(slog.New(logHandler))
	logLevel.Set(slog.LevelInfo)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	slog.Info("Starting signal watchdog...")
	go func() {
		sig := make(chan os.Signal, 1)
		// SIGKILL (os.Kill) cannot be caught, so registering it does nothing;
		// SIGTERM is the catchable termination request.
		signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
		<-sig
		slog.Info("Caught signal; terminating...")
		cancel()
	}()

	slog.Info("Starting connection listener on port 54321")
	go listen(ctx, 54321)

	// Block until the watchdog cancels ctx; no need to poll.
	<-ctx.Done()
}

// listen accepts IPv4 TCP connections on port (all interfaces) and serves
// each one in its own goroutine via handleConnection. It returns once ctx is
// cancelled. It panics if the listening socket cannot be opened.
func listen(ctx context.Context, port uint16) {
	tcpAddr := net.TCPAddrFromAddrPort(
		netip.AddrPortFrom(netip.IPv4Unspecified(), port),
	)
	ln, err := net.ListenTCP("tcp4", tcpAddr)
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	// AcceptTCP does not observe ctx; closing the listener on cancellation
	// is what unblocks it.
	stop := context.AfterFunc(ctx, func() { _ = ln.Close() })
	defer stop()

	for ctx.Err() == nil {
		slog.Info("Awaiting connection...")
		conn, err := ln.AcceptTCP()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			slog.Error("Unable to accept connection",
				"error", err,
			)
			// conn is nil here; back off briefly so a persistent error
			// (e.g. running out of file descriptors) does not spin.
			sleepNice(ctx, 100*time.Millisecond)
			continue
		}
		slog.Info("Received connection",
			"local", conn.LocalAddr(),
			"remote", conn.RemoteAddr(),
		)
		go handleConnection(ctx, conn)
	}
}

// handleConnection configures conn (zero linger, Nagle disabled, keepalives
// after 15s idle every 15s, 5 probes) and then echoes received bytes back
// until ctx is cancelled, the peer disconnects, or a read/write fails or
// exceeds its 10s deadline. conn is always closed on return.
func handleConnection(
	ctx context.Context,
	conn *net.TCPConn,
) {
	defer func() {
		if er := conn.Close(); er != nil {
			slog.Warn("Error while closing connection",
				"local", conn.LocalAddr(),
				"remote", conn.RemoteAddr(),
				"error", er,
			)
		}
	}()

	if err := conn.SetLinger(0); err != nil {
		slog.Error("Could not set TCP linger time", "error", err)
		return
	}
	if err := conn.SetNoDelay(true); err != nil {
		slog.Error("Failed to enable NoDelay", "error", err)
		return
	}
	if err := conn.SetKeepAliveConfig(net.KeepAliveConfig{
		Enable:   true,
		Idle:     15 * time.Second,
		Interval: 15 * time.Second,
		Count:    5,
	}); err != nil {
		slog.Error("Cannot configure TCP keepalives", "error", err)
		return
	}

	/*sysConn, err := conn.SyscallConn()
	if err != nil {
		slog.Error("Could not obtain system-level connection", "error", err)
		return
	}*/

	timeout := 10 * time.Second
	//outputPeriod := 10 * time.Second
	//lastOutput := time.Now().Add(-outputPeriod)
	//tcpInfo := syscall.TCPInfo{}
	buf := make([]byte, 4)

	slog.Debug("Starting loop...",
		"local", conn.LocalAddr(),
		"remote", conn.RemoteAddr(),
	)
	for ctx.Err() == nil {
		/*slog.Debug("Retrieving TCP info",
			"local", conn.LocalAddr(),
			"remote", conn.RemoteAddr(),
		)
		err = retrieveTCPInfo(sysConn, &tcpInfo)
		if err != nil {
			slog.Error("Could not retrieve TCP information",
				"local", conn.LocalAddr(),
				"remote", conn.RemoteAddr(),
				"error", err,
			)
		}
		if tcpInfo.State == unix.BPF_TCP_CLOSE {
			slog.Info("Lost connection",
				"local", conn.LocalAddr(),
				"remote", conn.RemoteAddr(),
			)
			return
		}*/

		// Send and receive some data. Echo back exactly the bytes that were
		// read; writing the whole buffer would resend stale bytes after a
		// short read.
		n, err := recv(conn, buf, timeout)
		if err != nil {
			slog.Debug("While reading from connection",
				"remote", conn.RemoteAddr(),
				"error", err,
			)
			return
		}
		if err := send(conn, buf[:n], timeout); err != nil {
			slog.Debug("While writing to connection",
				"remote", conn.RemoteAddr(),
				"error", err,
			)
			return
		}

		// Periodically print TCPInfo.
		// NOTE(review): this disabled block refers to a `timer` that is not
		// declared anywhere; it needs one before it can be re-enabled.
		/*select {
		case <-timer.C:
			fmt.Printf("RTT: %0.3f ms; StdDev: %0.3f ms\n",
				float32(tcpInfo.Rtt)/1000.0,
				float32(tcpInfo.Rttvar)/1000.0,
			)
			timer.Stop()
			timer.Reset(outputPeriod)
		default:
		}*/

		sleepNice(ctx, 10*time.Millisecond)
	}
}

// recv performs a single Read into buf with a deadline of timeout from now.
// It returns the number of bytes read. On failure the error is wrapped with
// "read:"; a peer disconnect is logged at Info, a deadline expiry is not
// logged, and any other error is logged at Warn.
func recv(c net.Conn, buf []byte, timeout time.Duration) (int, error) {
	if err := c.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return 0, fmt.Errorf("set deadline: %w", err)
	}
	n, err := c.Read(buf)
	if err != nil {
		logIOError(c, "While reading from connection", err)
		return n, fmt.Errorf("read: %w", err)
	}
	return n, nil
}

// send writes all of buf with a deadline of timeout from now. On failure the
// error is wrapped with "write:" and logged the same way recv logs errors.
func send(c net.Conn, buf []byte, timeout time.Duration) error {
	if err := c.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return fmt.Errorf("set deadline: %w", err)
	}
	if _, err := c.Write(buf); err != nil {
		logIOError(c, "While writing to connection", err)
		return fmt.Errorf("write: %w", err)
	}
	return nil
}

// isConnLost reports whether err means the peer has gone away: EOF on read,
// or a reset / broken pipe on read or write.
func isConnLost(err error) bool {
	return errors.Is(err, io.EOF) ||
		errors.Is(err, syscall.ECONNRESET) ||
		errors.Is(err, syscall.EPIPE)
}

// logIOError logs an I/O error on c at a level matching its cause.
func logIOError(c net.Conn, msg string, err error) {
	switch {
	case isConnLost(err):
		slog.Info("Lost connection",
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
		)
	case errors.Is(err, os.ErrDeadlineExceeded):
		// Timeouts are returned to the caller without logging here.
	default:
		slog.Warn(msg,
			"local", c.LocalAddr(),
			"remote", c.RemoteAddr(),
			"error", err,
		)
	}
}

// sleepNice blocks for duration t, returning early as soon as ctx is
// cancelled (immediately if it already is).
func sleepNice(ctx context.Context, t time.Duration) {
	if ctx.Err() != nil {
		return
	}
	timer := time.NewTimer(t)
	defer timer.Stop()
	select {
	case <-ctx.Done():
	case <-timer.C:
	}
}

// retrieveTCPInfo fills info with getsockopt(fd, SOL_TCP, TCP_INFO) for the
// socket behind sysConn.
func retrieveTCPInfo(
	sysConn syscall.RawConn,
	info *syscall.TCPInfo,
) error {
	// getsockopt's optlen is a socklen_t (32-bit on Linux); passing a pointer
	// to a uintptr only works by accident on little-endian machines.
	size := uint32(unsafe.Sizeof(*info))
	var errno syscall.Errno
	err := sysConn.Control(func(fd uintptr) {
		_, _, errno = syscall.Syscall6(
			syscall.SYS_GETSOCKOPT,
			fd,
			syscall.SOL_TCP,
			syscall.TCP_INFO,
			uintptr(unsafe.Pointer(info)),
			uintptr(unsafe.Pointer(&size)),
			0,
		)
	})
	if err != nil {
		return err
	}
	if errno != 0 {
		return fmt.Errorf("getsockopt TCP_INFO: %w", errno)
	}
	return nil
}