From 28d2b47164e73fb4555fbf0e7ce44f39443ebd07 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Mon, 13 Apr 2026 14:24:30 -0500 Subject: [PATCH] udp_linux: wrap socket operations with syscall.RawConn for clean teardown --- udp/udp_linux.go | 304 ++++++++++++++++++++++------------------------- 1 file changed, 142 insertions(+), 162 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..d8027708 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,6 +4,7 @@ package udp import ( + "context" "encoding/binary" "fmt" "net" @@ -18,58 +19,50 @@ import ( ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + udpConn *net.UDPConn + rawConn syscall.RawConn + isV4 bool + l *logrus.Logger + batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true +func setReusePort(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + //CloseOnExec already set by the runtime + }) + if err != nil { + return err } - return ip, false + return opErr } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET + listen := netip.AddrPortFrom(ip, uint16(port)) + lc := net.ListenConfig{} + if multi { + lc.Control = setReusePort } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - + //this context is only used during the bind operation, you can't cancel it to kill the socket + pc, err := lc.ListenPacket(context.Background(), "udp", listen.String()) if err != nil { - unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } - - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + udpConn := pc.(*net.UDPConn) + rawConn, err := udpConn.SyscallConn() + if err != nil { + _ = udpConn.Close() + return nil, err } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa = sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) - } - - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return &StdConn{ + udpConn: udpConn, + rawConn: rawConn, + isV4: ip.Is4(), + l: l, + batch: batch, + }, err } func (u *StdConn) SupportsMultipleReaders() bool { @@ -80,63 +73,126 @@ func (u *StdConn) Rebind() error { return nil } +func (u *StdConn) getSockOptInt(opt int) (int, error) { + if u.rawConn == nil { + return 0, fmt.Errorf("no UDP connection") + } + var out int + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + out, opErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) + }) + if err != nil { + return 0, err + } + return out, opErr +} + +func (u *StdConn) setSockOptInt(opt int, n int) error { + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, opt, n) + }) + if err != nil { + return err + } + return opErr +} + func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + return u.setSockOptInt(unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + return u.setSockOptInt(unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + return u.setSockOptInt(unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + return u.getSockOptInt(unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) + return u.getSockOptInt(unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + return u.getSockOptInt(unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err + addr := u.udpConn.LocalAddr() + return netip.ParseAddrPort(addr.String()) +} + +func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { + var errno syscall.Errno + n, _, errno := unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + // No data available, block for I/O and try again. + return int(n), false, nil } + if errno != 0 { + return int(n), true, &net.OpError{Op: "recvmmsg", Err: errno} + } + return int(n), true, nil +} - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil +func (u *StdConn) listenOutSingle(r EncReader) { + var err error + var n int + var from netip.AddrPort + buffer := make([]byte, MTU) - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil - - default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + for { + n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) + r(from, buffer[:n]) } } -func (u *StdConn) ListenOut(r EncReader) { +func (u *StdConn) listenOutBatch(r EncReader) { var ip netip.Addr + var n int + var operr error msgs, buffers, names := u.PrepareRawMessages(u.batch) - read := u.ReadMulti - if u.batch == 1 { - read = u.ReadSingle + + //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read + //defining it outside the loop so it gets re-used + reader := func(fd uintptr) (done bool) { + n, done, operr = recvmmsg(fd, msgs) + return done } for { - n, err := read(msgs) + err := u.rawConn.Read(reader) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } + if operr != nil { + u.l.WithError(err).Debug("operr: udp socket is closed, exiting read loop") + return + } for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic @@ -150,106 +206,20 @@ func (u *StdConn) ListenOut(r EncReader) { } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil - } -} - -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} - } - - return int(n), nil +func (u *StdConn) ListenOut(r EncReader) { + if u.batch == 1 { + //save some ram by not calling PrepareRawMessages for fields we won't use + //we could also make this path more common by calling recvmmsg with msgs[:1], + //but that's still the recvmmsg syscall, which would be a change + u.listenOutSingle(r) + } else { + u.listenOutBatch(r) } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return ErrInvalidIPv6RemoteForSocket - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.udpConn.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -302,15 +272,25 @@ func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { + + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + _, _, opErr = unix.Syscall6(unix.SYS_GETSOCKOPT, fd, uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + }) + if err != nil { return err } - return nil + return opErr } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if u.udpConn != nil { + return u.udpConn.Close() + } + return nil } func NewUDPStatsEmitter(udpConns []Conn) func() {