From b3194236aac4d4a577812fbd3f1ec66c2e5e1a60 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Tue, 14 Apr 2026 18:25:24 -0500 Subject: [PATCH] udp_linux: wrap socket operations with syscall.RawConn for clean teardown (#1654) remove runtime.LockOSThread() because it makes things worse now remove the "custom" Write() method from tun_linux.go, the stdlib path via os.File performs better We should change our guidance around number of routines, ~2 per thread (that you wish to use for Nebula) seems to be about right now --- interface.go | 5 - overlay/tun_linux.go | 23 ---- udp/udp_linux.go | 322 ++++++++++++++++++++++--------------------- 3 files changed, 162 insertions(+), 188 deletions(-) diff --git a/interface.go b/interface.go index 61b1f228..61f8c9b7 100644 --- a/interface.go +++ b/interface.go @@ -7,7 +7,6 @@ import ( "io" "net/netip" "os" - "runtime" "sync/atomic" "time" @@ -263,8 +262,6 @@ func (f *Interface) run() { } func (f *Interface) listenOut(i int) { - runtime.LockOSThread() - var li udp.Conn if i > 0 { li = f.writers[i] @@ -285,8 +282,6 @@ func (f *Interface) listenOut(i int) { } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { - runtime.LockOSThread() - packet := make([]byte, mtu) out := make([]byte, mtu) fwPacket := &firewall.Packet{} diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 7e4aa418..9d779a4b 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -261,29 +261,6 @@ func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { return r } -func (t *tun) Write(b []byte) (int, error) { - var nn int - maximum := len(b) - - for { - n, err := unix.Write(t.fd, b[nn:maximum]) - if n > 0 { - nn += n - } - if nn == len(b) { - return nn, err - } - - if err != nil { - return nn, err - } - - if n == 0 { - return nn, io.ErrUnexpectedEOF - } - } -} - func (t *tun) deviceBytes() (o [16]byte) { for i, c := range t.Device { o[i] = byte(c) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index e7759329..b1490a1c 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -4,6 +4,7 @@ package udp import ( + "context" "encoding/binary" "fmt" "net" @@ -18,58 +19,58 @@ import ( ) type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + udpConn *net.UDPConn + rawConn syscall.RawConn + isV4 bool + l *logrus.Logger + batch int } -func maybeIPV4(ip net.IP) (net.IP, bool) { - ip4 := ip.To4() - if ip4 != nil { - return ip4, true +func setReusePort(network, address string, c syscall.RawConn) error { + var opErr error + err := c.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + //CloseOnExec already set by the runtime + }) + if err != nil { + return err } - return ip, false + return opErr } func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch int) (Conn, error) { - af := unix.AF_INET6 - if ip.Is4() { - af = unix.AF_INET + listen := netip.AddrPortFrom(ip, uint16(port)) + lc := net.ListenConfig{} + if multi { + lc.Control = setReusePort } - syscall.ForkLock.RLock() - fd, err := unix.Socket(af, unix.SOCK_DGRAM, unix.IPPROTO_UDP) - if err == nil { - unix.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - + //this context is only used during the bind operation, you can't cancel it to kill the socket + pc, err := lc.ListenPacket(context.Background(), "udp", listen.String()) if err != nil { - unix.Close(fd) return nil, fmt.Errorf("unable to open socket: %s", err) } - - if multi { - if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { - return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) - } + udpConn := pc.(*net.UDPConn) + rawConn, err := udpConn.SyscallConn() + if err != nil { + _ = udpConn.Close() + return nil, err + } + //gotta find out if we got an AF_INET6 socket or not: + out := &StdConn{ + udpConn: udpConn, + rawConn: rawConn, + l: l, + batch: batch, } - var sa unix.Sockaddr - if ip.Is4() { - sa4 := &unix.SockaddrInet4{Port: port} - sa4.Addr = ip.As4() - sa = sa4 - } else { - sa6 := &unix.SockaddrInet6{Port: port} - sa6.Addr = ip.As16() - sa = sa6 - } - if err = unix.Bind(fd, sa); err != nil { - return nil, fmt.Errorf("unable to bind to socket: %s", err) + af, err := out.getSockOptInt(unix.SO_DOMAIN) + if err != nil { + _ = out.Close() + return nil, err } + out.isV4 = af == unix.AF_INET - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return out, nil } func (u *StdConn) SupportsMultipleReaders() bool { @@ -80,63 +81,137 @@ func (u *StdConn) Rebind() error { return nil } +func (u *StdConn) getSockOptInt(opt int) (int, error) { + if u.rawConn == nil { + return 0, fmt.Errorf("no UDP connection") + } + var out int + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + out, opErr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) + }) + if err != nil { + return 0, err + } + return out, opErr +} + +func (u *StdConn) setSockOptInt(opt int, n int) error { + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, opt, n) + }) + if err != nil { + return err + } + return opErr +} + func (u *StdConn) SetRecvBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, n) + return u.setSockOptInt(unix.SO_RCVBUFFORCE, n) } func (u *StdConn) SetSendBuffer(n int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, n) + return u.setSockOptInt(unix.SO_SNDBUFFORCE, n) } func (u *StdConn) SetSoMark(mark int) error { - return unix.SetsockoptInt(u.sysFd, unix.SOL_SOCKET, unix.SO_MARK, mark) + return u.setSockOptInt(unix.SO_MARK, mark) } func (u *StdConn) GetRecvBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_RCVBUF) + return u.getSockOptInt(unix.SO_RCVBUF) } func (u *StdConn) GetSendBuffer() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_SNDBUF) + return u.getSockOptInt(unix.SO_SNDBUF) } func (u *StdConn) GetSoMark() (int, error) { - return unix.GetsockoptInt(int(u.sysFd), unix.SOL_SOCKET, unix.SO_MARK) + return u.getSockOptInt(unix.SO_MARK) } func (u *StdConn) LocalAddr() (netip.AddrPort, error) { - sa, err := unix.Getsockname(u.sysFd) - if err != nil { - return netip.AddrPort{}, err - } + a := u.udpConn.LocalAddr() - switch sa := sa.(type) { - case *unix.SockaddrInet4: - return netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), uint16(sa.Port)), nil - - case *unix.SockaddrInet6: - return netip.AddrPortFrom(netip.AddrFrom16(sa.Addr), uint16(sa.Port)), nil + switch v := a.(type) { + case *net.UDPAddr: + addr, ok := netip.AddrFromSlice(v.IP) + if !ok { + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned invalid IP address: %s", v.IP) + } + return netip.AddrPortFrom(addr, uint16(v.Port)), nil default: - return netip.AddrPort{}, fmt.Errorf("unsupported sock type: %T", sa) + return netip.AddrPort{}, fmt.Errorf("LocalAddr returned: %#v", a) } } -func (u *StdConn) ListenOut(r EncReader) { +func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { + var errno syscall.Errno + n, _, errno := unix.Syscall6( + unix.SYS_RECVMMSG, + fd, + uintptr(unsafe.Pointer(&msgs[0])), + uintptr(len(msgs)), + unix.MSG_WAITFORONE, + 0, + 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + // No data available, block for I/O and try again. + return int(n), false, nil + } + if errno != 0 { + return int(n), true, &net.OpError{Op: "recvmmsg", Err: errno} + } + return int(n), true, nil +} + +func (u *StdConn) listenOutSingle(r EncReader) { + var err error + var n int + var from netip.AddrPort + buffer := make([]byte, MTU) + + for { + n, from, err = u.udpConn.ReadFromUDPAddrPort(buffer) + if err != nil { + u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + return + } + from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) + r(from, buffer[:n]) + } +} + +func (u *StdConn) listenOutBatch(r EncReader) { var ip netip.Addr + var n int + var operr error msgs, buffers, names := u.PrepareRawMessages(u.batch) - read := u.ReadMulti - if u.batch == 1 { - read = u.ReadSingle + + //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read + //defining it outside the loop so it gets re-used + reader := func(fd uintptr) (done bool) { + n, done, operr = recvmmsg(fd, msgs) + return done } for { - n, err := read(msgs) + err := u.rawConn.Read(reader) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") return } + if operr != nil { + u.l.WithError(operr).Debug("operr: udp socket is closed, exiting read loop") + return + } for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic @@ -150,106 +225,20 @@ func (u *StdConn) ListenOut(r EncReader) { } } -func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil - } -} - -func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} - } - - return int(n), nil +func (u *StdConn) ListenOut(r EncReader) { + if u.batch == 1 { + //save some ram by not calling PrepareRawMessages for fields we won't use + //we could also make this path more common by calling recvmmsg with msgs[:1], + //but that's still the recvmmsg syscall, which would be a change + u.listenOutSingle(r) + } else { + u.listenOutBatch(r) } } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) - } - return u.writeTo6(b, ip) -} - -func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { - var rsa unix.RawSockaddrInet6 - rsa.Family = unix.AF_INET6 - rsa.Addr = ip.Addr().As16() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet6), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } -} - -func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { - if !ip.Addr().Is4() { - return ErrInvalidIPv6RemoteForSocket - } - - var rsa unix.RawSockaddrInet4 - rsa.Family = unix.AF_INET - rsa.Addr = ip.Addr().As4() - binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&rsa.Port))[:], ip.Port()) - - for { - _, _, err := unix.Syscall6( - unix.SYS_SENDTO, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&b[0])), - uintptr(len(b)), - uintptr(0), - uintptr(unsafe.Pointer(&rsa)), - uintptr(unix.SizeofSockaddrInet4), - ) - - if err != 0 { - return &net.OpError{Op: "sendto", Err: err} - } - - return nil - } + _, err := u.udpConn.WriteToUDPAddrPort(b, ip) + return err } func (u *StdConn) ReloadConfig(c *config.C) { @@ -302,15 +291,28 @@ func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { var vallen uint32 = 4 * unix.SK_MEMINFO_VARS - _, _, err := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(u.sysFd), uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) - if err != 0 { + + if u.rawConn == nil { + return fmt.Errorf("no UDP connection") + } + var opErr error + err := u.rawConn.Control(func(fd uintptr) { + _, _, syserr := unix.Syscall6(unix.SYS_GETSOCKOPT, fd, uintptr(unix.SOL_SOCKET), uintptr(unix.SO_MEMINFO), uintptr(unsafe.Pointer(meminfo)), uintptr(unsafe.Pointer(&vallen)), 0) + if syserr != 0 { + opErr = syserr + } + }) + if err != nil { return err } - return nil + return opErr } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if u.udpConn != nil { + return u.udpConn.Close() + } + return nil } func NewUDPStatsEmitter(udpConns []Conn) func() {