main.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. package main
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "net"
  9. "net/netip"
  10. "os"
  11. "os/signal"
  12. "syscall"
  13. "time"
  14. "unsafe"
  15. )
  16. func main() {
  17. logLevel := new(slog.LevelVar)
  18. logHandler :=
  19. slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  20. Level: logLevel,
  21. })
  22. slog.SetDefault(slog.New(logHandler))
  23. logLevel.Set(slog.LevelInfo)
  24. ctx, cancel := context.WithCancel(context.Background())
  25. slog.Info("Starting signal watchdog...")
  26. go func() {
  27. sig := make(chan os.Signal, 1)
  28. signal.Notify(sig, os.Interrupt, os.Kill)
  29. <-sig
  30. slog.Info("Caught signal; terminating...")
  31. cancel()
  32. }()
  33. slog.Info("Starting connection listener on port 54321")
  34. go listen(ctx, 54321)
  35. for {
  36. if ctx.Err() != nil {
  37. break
  38. }
  39. sleepNice(ctx, 100*time.Millisecond)
  40. }
  41. }
  42. func listen(ctx context.Context, port uint16) {
  43. tcpAddr := net.TCPAddrFromAddrPort(
  44. netip.AddrPortFrom(
  45. netip.IPv4Unspecified(),
  46. port,
  47. ),
  48. )
  49. ln, err := net.ListenTCP("tcp4", tcpAddr)
  50. if err != nil {
  51. panic(err)
  52. }
  53. for {
  54. if ctx.Err() != nil {
  55. break
  56. }
  57. slog.Info("Awaiting connection...")
  58. conn, err := ln.AcceptTCP()
  59. if err != nil {
  60. slog.Error("Unable to accept connection",
  61. "error", err,
  62. )
  63. }
  64. slog.Info("Received connection",
  65. "local", conn.LocalAddr(),
  66. "remote", conn.RemoteAddr(),
  67. )
  68. go handleConnection(ctx, conn)
  69. }
  70. }
  71. func handleConnection(
  72. ctx context.Context,
  73. conn *net.TCPConn,
  74. ) {
  75. defer func() {
  76. if er := conn.Close(); er != nil {
  77. slog.Warn("Error while closing connection",
  78. "local", conn.LocalAddr(),
  79. "remote", conn.RemoteAddr(),
  80. "error", er,
  81. )
  82. }
  83. }()
  84. if err := conn.SetLinger(0); err != nil {
  85. slog.Error("Could not set TCP linger time", "error", err)
  86. return
  87. }
  88. if err := conn.SetNoDelay(true); err != nil {
  89. slog.Error("Failed to enable NoDelay", "error", err)
  90. return
  91. }
  92. err := conn.SetKeepAliveConfig(net.KeepAliveConfig{
  93. Enable: true,
  94. Idle: 15 * time.Second,
  95. Interval: 15 * time.Second,
  96. Count: 5,
  97. })
  98. if err != nil {
  99. slog.Error("Cannot configure TCP keeplives", "error", err)
  100. return
  101. }
  102. /*sysConn, err := conn.SyscallConn()
  103. if err != nil {
  104. slog.Error("Could not obtain system-level connection", "error", err)
  105. return
  106. }*/
  107. timeout := 10 * time.Second
  108. //outputPeriod := 10 * time.Second
  109. //lastOutput := time.Now().Add(-outputPeriod)
  110. //tcpInfo := syscall.TCPInfo{}
  111. buf := make([]byte, 4)
  112. slog.Debug("Starting loop...",
  113. "local", conn.LocalAddr(),
  114. "remote", conn.RemoteAddr(),
  115. )
  116. for {
  117. if ctx.Err() != nil {
  118. break
  119. }
  120. /*slog.Debug("Retrieving TCP info",
  121. "local", conn.LocalAddr(),
  122. "remote", conn.RemoteAddr(),
  123. )
  124. err = retrieveTCPInfo(sysConn, &tcpInfo)
  125. if err != nil {
  126. slog.Error("Could not retrieve TCP information",
  127. "local", conn.LocalAddr(),
  128. "remote", conn.RemoteAddr(),
  129. "error", err,
  130. )
  131. }
  132. if tcpInfo.State == unix.BPF_TCP_CLOSE {
  133. slog.Info("Lost connection",
  134. "local", conn.LocalAddr(),
  135. "remote", conn.RemoteAddr(),
  136. )
  137. return
  138. }*/
  139. // Send and receive some data.
  140. if err := recv(conn, buf, timeout); err != nil {
  141. slog.Debug("While reading from connection",
  142. "remote", conn.RemoteAddr(),
  143. "error", err,
  144. )
  145. return
  146. }
  147. if err := send(conn, buf, timeout); err != nil {
  148. slog.Debug("While writing to connection",
  149. "remote", conn.RemoteAddr(),
  150. "error", err,
  151. )
  152. return
  153. }
  154. // Periodically print TCPInfo.
  155. /*select {
  156. case <-timer.C:
  157. fmt.Printf("RTT: %0.3f ms; StdDev: %0.3f ms\n",
  158. float32(tcpInfo.Rtt)/1000.0,
  159. float32(tcpInfo.Rttvar)/1000.0,
  160. )
  161. timer.Stop()
  162. timer.Reset(outputPeriod)
  163. default:
  164. }*/
  165. sleepNice(ctx, 10*time.Millisecond)
  166. }
  167. }
  168. func recv(c net.Conn, buf []byte, timeout time.Duration) error {
  169. deadline := time.Now().Add(timeout)
  170. if err := c.SetReadDeadline(deadline); err != nil {
  171. return fmt.Errorf("set deadline: %w", err)
  172. }
  173. _, err := c.Read(buf)
  174. if errors.Is(err, io.EOF) {
  175. slog.Info("Lost connection",
  176. "local", c.LocalAddr(),
  177. "remote", c.RemoteAddr(),
  178. )
  179. return fmt.Errorf("read: %w", err)
  180. }
  181. if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
  182. slog.Warn("While reading from connection",
  183. "local", c.LocalAddr(),
  184. "remote", c.RemoteAddr(),
  185. "error", err,
  186. )
  187. return fmt.Errorf("read: %w", err)
  188. }
  189. return err
  190. }
  191. func send(c net.Conn, buf []byte, timeout time.Duration) error {
  192. deadline := time.Now().Add(timeout)
  193. if err := c.SetWriteDeadline(deadline); err != nil {
  194. return fmt.Errorf("set deadline: %w", err)
  195. }
  196. _, err := c.Write(buf)
  197. if errors.Is(err, io.EOF) {
  198. slog.Info("Lost connection",
  199. "local", c.LocalAddr(),
  200. "remote", c.RemoteAddr(),
  201. )
  202. return fmt.Errorf("read: %w", err)
  203. }
  204. if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) {
  205. slog.Debug("While reading from connection",
  206. "local", c.LocalAddr(),
  207. "remote", c.RemoteAddr(),
  208. "error", err,
  209. )
  210. return fmt.Errorf("read: %w", err)
  211. }
  212. return err
  213. }
  214. func sleepNice(ctx context.Context, t time.Duration) {
  215. timer := time.NewTimer(t)
  216. for {
  217. if ctx.Err() != nil {
  218. return
  219. }
  220. select {
  221. case <-timer.C:
  222. return
  223. case <-time.After(100 * time.Millisecond):
  224. // Throttle CPU
  225. }
  226. }
  227. }
  228. func retrieveTCPInfo(
  229. sysConn syscall.RawConn,
  230. info *syscall.TCPInfo,
  231. ) error {
  232. size := unsafe.Sizeof(*info)
  233. var errno syscall.Errno
  234. err := sysConn.Control(func(fd uintptr) {
  235. _, _, errno = syscall.Syscall6(
  236. syscall.SYS_GETSOCKOPT,
  237. fd,
  238. syscall.SOL_TCP,
  239. syscall.TCP_INFO,
  240. uintptr(unsafe.Pointer(info)),
  241. uintptr(unsafe.Pointer(&size)),
  242. 0,
  243. )
  244. })
  245. if err != nil {
  246. return err
  247. }
  248. if errno != 0 {
  249. return fmt.Errorf("syscall errno %d", errno)
  250. }
  251. return nil
  252. }