diff --git a/cert/pem.go b/cert/pem.go index a5aabdc..c04e63b 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -3,6 +3,7 @@ package cert import ( "encoding/pem" "fmt" + "time" "golang.org/x/crypto/ed25519" ) @@ -189,3 +190,69 @@ func UnmarshalSigningPrivateKeyFromPEM(b []byte) ([]byte, []byte, Curve, error) } return k.Bytes, r, curve, nil } + +// Backward compatibility functions for older API +func MarshalX25519PublicKey(b []byte) []byte { + return MarshalPublicKeyToPEM(Curve_CURVE25519, b) +} + +func MarshalX25519PrivateKey(b []byte) []byte { + return MarshalPrivateKeyToPEM(Curve_CURVE25519, b) +} + +func MarshalPublicKey(curve Curve, b []byte) []byte { + return MarshalPublicKeyToPEM(curve, b) +} + +func MarshalPrivateKey(curve Curve, b []byte) []byte { + return MarshalPrivateKeyToPEM(curve, b) +} + +// NebulaCertificate is a compatibility wrapper for the old API +type NebulaCertificate struct { + Details NebulaCertificateDetails + Signature []byte + cert Certificate +} + +// NebulaCertificateDetails is a compatibility wrapper for certificate details +type NebulaCertificateDetails struct { + Name string + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + IsCA bool + Issuer []byte + Curve Curve +} + +// UnmarshalNebulaCertificateFromPEM provides backward compatibility with the old API +func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { + c, rest, err := UnmarshalCertificateFromPEM(b) + if err != nil { + return nil, rest, err + } + + // Convert to old format + nc := &NebulaCertificate{ + Details: NebulaCertificateDetails{ + Name: c.Name(), + NotBefore: c.NotBefore(), + NotAfter: c.NotAfter(), + PublicKey: c.PublicKey(), + IsCA: c.IsCA(), + Curve: c.Curve(), + }, + Signature: c.Signature(), + cert: c, + } + + // Handle issuer + if c.Issuer() != "" { + // Convert hex string fingerprint back to bytes (this is an approximation) + // The old API used raw bytes, new API uses hex string + nc.Details.Issuer = []byte(c.Issuer()) + } + + return nc, rest, nil +} diff --git a/interface.go b/interface.go index 082906d..495bb9e 100644 --- a/interface.go +++ b/interface.go @@ -271,7 +271,10 @@ func (f *Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) { + if release != nil { + defer release() + } f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } diff --git a/udp/config.go b/udp/config.go new file mode 100644 index 0000000..98233ca --- /dev/null +++ b/udp/config.go @@ -0,0 +1,16 @@ +package udp + +import "sync/atomic" + +var disableUDPCsum atomic.Bool + +// SetDisableUDPCsum controls whether IPv4 UDP sockets opt out of kernel +// checksum calculation via SO_NO_CHECK. Only applicable on platforms that +// support the option (Linux). IPv6 always keeps the checksum enabled. 
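// Annotation (not part of the patch): a minimal sketch of how an IPv4 socket
// setup path could consume the flag above, assuming a raw Linux fd and that
// golang.org/x/sys/unix exposes SO_NO_CHECK on this platform. The helper name
// applyChecksumPreference is hypothetical; the patch's real wiring is elsewhere.
func applyChecksumPreference(fd int) error {
	if !udpChecksumDisabled() {
		return nil
	}
	// SO_NO_CHECK asks the kernel to skip UDP checksum generation on this
	// IPv4 socket; IPv6 sockets must keep checksums, so never apply it there.
	return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK, 1)
}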
+func SetDisableUDPCsum(disable bool) { + disableUDPCsum.Store(disable) +} + +func udpChecksumDisabled() bool { + return disableUDPCsum.Load() +} diff --git a/udp/conn.go b/udp/conn.go index 895b0df..1c6a6de 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -11,6 +11,7 @@ const MTU = 9001 type EncReader func( addr netip.AddrPort, payload []byte, + release func(), ) type Conn interface { diff --git a/udp/io_uring_linux.go b/udp/io_uring_linux.go new file mode 100644 index 0000000..c1580df --- /dev/null +++ b/udp/io_uring_linux.go @@ -0,0 +1,1483 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package udp + +import ( + "errors" + "fmt" + "net" + "runtime" + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +const ( + ioringOpSendmsg = 9 + ioringOpRecvmsg = 10 + ioringEnterGetevents = 1 << 0 + ioringSetupClamp = 1 << 4 + ioringSetupCoopTaskrun = 1 << 8 // Kernel 5.19+: reduce thread creation + ioringSetupSingleIssuer = 1 << 12 // Kernel 6.0+: single submitter optimization + ioringRegisterIowqMaxWorkers = 19 // Register opcode to limit workers + ioringOffSqRing = 0 + ioringOffCqRing = 0x8000000 + ioringOffSqes = 0x10000000 + defaultIoUringEntries = 256 + ioUringSqeSize = 64 // struct io_uring_sqe size defined by kernel ABI +) + +type ioSqringOffsets struct { + Head uint32 + Tail uint32 + RingMask uint32 + RingEntries uint32 + Flags uint32 + Dropped uint32 + Array uint32 + Resv1 uint32 + Resv2 uint64 +} + +type ioCqringOffsets struct { + Head uint32 + Tail uint32 + RingMask uint32 + RingEntries uint32 + Overflow uint32 + Cqes uint32 + Resv [2]uint32 +} + +type ioUringParams struct { + SqEntries uint32 + CqEntries uint32 + Flags uint32 + SqThreadCPU uint32 + SqThreadIdle uint32 + Features uint32 + WqFd uint32 + Resv [3]uint32 + SqOff ioSqringOffsets + CqOff ioCqringOffsets +} + +type ioUringSqe struct { + Opcode uint8 + Flags uint8 + Ioprio uint16 + Fd int32 + Off uint64 + Addr uint64 + Len uint32 + MsgFlags uint32 + UserData uint64 + BufIndex uint16 + Personality uint16 + SpliceFdIn int32 + SpliceOffIn uint64 + Addr2 uint64 +} + +type ioUringCqe struct { + UserData uint64 + Res int32 + Flags uint32 + // No explicit padding needed - Go will align uint64 to 8 bytes, + // and int32/uint32 will be naturally aligned. Total size should be 16 bytes. 
+ // Kernel structure: __u64 user_data; __s32 res; __u32 flags; +} + +func init() { + if sz := unsafe.Sizeof(ioUringSqe{}); sz != ioUringSqeSize { + panic(fmt.Sprintf("io_uring SQE size mismatch: expected %d, got %d", ioUringSqeSize, sz)) + } + if sz := unsafe.Sizeof(ioUringCqe{}); sz != 16 { + panic(fmt.Sprintf("io_uring CQE size mismatch: expected %d, got %d", 16, sz)) + } +} + +// pendingSend tracks all heap-allocated structures for a single io_uring submission +// to ensure they remain valid until the kernel completes the operation +type pendingSend struct { + msgCopy *unix.Msghdr + iovCopy *unix.Iovec + sockaddrCopy []byte + controlCopy []byte + payloadRef unsafe.Pointer + userData uint64 +} + +type pendingRecv struct { + msgCopy *unix.Msghdr + iovCopy *unix.Iovec + nameBuf []byte + controlBuf []byte + payloadBuf []byte + callerMsg *unix.Msghdr + userData uint64 +} + +type ioUringBatchResult struct { + res int32 + flags uint32 + err error +} + +type ioUringBatchEntry struct { + fd int + msg *unix.Msghdr + msgFlags uint32 + payloadLen uint32 + userData uint64 + result *ioUringBatchResult +} + +type ioUringState struct { + fd int + sqRing []byte + cqRing []byte + sqesMap []byte + sqes []ioUringSqe + cqCqes []ioUringCqe + + sqHead *uint32 + sqTail *uint32 + sqRingMask *uint32 + sqRingEntries *uint32 + sqArray []uint32 + + cqHead *uint32 + cqTail *uint32 + cqRingMask *uint32 + cqRingEntries *uint32 + + mu sync.Mutex + userData uint64 + pendingSends map[uint64]*pendingSend + + sqEntryCount uint32 + cqEntryCount uint32 + + pendingReceives map[uint64]*pendingRecv + completedCqes map[uint64]*ioUringCqe +} + +// recvBuffer represents a single receive operation with its associated buffers +type recvBuffer struct { + payloadBuf []byte // Buffer for packet data + nameBuf []byte // Buffer for source address + controlBuf []byte // Buffer for control messages + msghdr *unix.Msghdr // Message header for recvmsg + iovec *unix.Iovec // IO vector pointing to payloadBuf + userData uint64 // User data for tracking this operation + inFlight atomic.Bool // Whether this buffer has a pending io_uring operation +} + +// ioUringRecvState manages a dedicated io_uring for receiving packets +// It maintains a pool of receive buffers and continuously keeps receives queued +type ioUringRecvState struct { + fd int + sqRing []byte + cqRing []byte + sqesMap []byte + sqes []ioUringSqe + cqCqes []ioUringCqe + + sqHead *uint32 + sqTail *uint32 + sqRingMask *uint32 + sqRingEntries *uint32 + sqArray []uint32 + + cqHead *uint32 + cqTail *uint32 + cqRingMask *uint32 + cqRingEntries *uint32 + + mu sync.Mutex + userData uint64 + bufferPool []*recvBuffer // Pool of all receive buffers + bufferMap map[uint64]*recvBuffer // Map userData -> buffer + + sqEntryCount uint32 + cqEntryCount uint32 + + sockFd int // Socket file descriptor to receive from + closed atomic.Bool +} + +func alignUint32(v, alignment uint32) uint32 { + if alignment == 0 { + return v + } + mod := v % alignment + if mod == 0 { + return v + } + return v + alignment - mod +} + +func newIoUringState(entries uint32) (*ioUringState, error) { + const minEntries = 8 + + if entries == 0 { + entries = defaultIoUringEntries + } + if entries < minEntries { + entries = minEntries + } + + tries := entries + var params ioUringParams + + // Try flag combinations in order (5.19+ -> baseline) + // Note: SINGLE_ISSUER causes EEXIST errors, so it's excluded + flagSets := []uint32{ + ioringSetupClamp | ioringSetupCoopTaskrun, // Kernel 5.19+: reduce thread creation + 
ioringSetupClamp, // All kernels + } + flagSetIdx := 0 + + for { + params = ioUringParams{Flags: flagSets[flagSetIdx]} + fd, _, errno := unix.Syscall(unix.SYS_IO_URING_SETUP, uintptr(tries), uintptr(unsafe.Pointer(¶ms)), 0) + if errno != 0 { + // If EINVAL, try next flag set (kernel doesn't support these flags) + if errno == unix.EINVAL && flagSetIdx < len(flagSets)-1 { + flagSetIdx++ + continue + } + if errno == unix.ENOMEM && tries > minEntries { + tries /= 2 + if tries < minEntries { + tries = minEntries + } + continue + } + return nil, errno + } + + ring := &ioUringState{ + fd: int(fd), + sqEntryCount: params.SqEntries, + cqEntryCount: params.CqEntries, + userData: 1, + pendingSends: make(map[uint64]*pendingSend), + pendingReceives: make(map[uint64]*pendingRecv), + completedCqes: make(map[uint64]*ioUringCqe), + } + + if err := ring.mapRings(¶ms); err != nil { + ring.Close() + if errors.Is(err, unix.ENOMEM) && tries > minEntries { + tries /= 2 + if tries < minEntries { + tries = minEntries + } + continue + } + return nil, err + } + + // Limit kernel worker threads to prevent thousands being spawned + // [0] = bounded workers, [1] = unbounded workers + maxWorkers := [2]uint32{4, 4} // Limit to 4 workers of each type + _, _, errno = unix.Syscall6( + unix.SYS_IO_URING_REGISTER, + uintptr(fd), + uintptr(ioringRegisterIowqMaxWorkers), + uintptr(unsafe.Pointer(&maxWorkers[0])), + 2, // array length + 0, 0, + ) + // Ignore errors - older kernels don't support this + + return ring, nil + } +} + +func (r *ioUringState) mapRings(params *ioUringParams) error { + pageSize := uint32(unix.Getpagesize()) + + sqRingSize := alignUint32(params.SqOff.Array+params.SqEntries*4, pageSize) + cqRingSize := alignUint32(params.CqOff.Cqes+params.CqEntries*uint32(unsafe.Sizeof(ioUringCqe{})), pageSize) + sqesSize := alignUint32(params.SqEntries*ioUringSqeSize, pageSize) + + sqRing, err := unix.Mmap(r.fd, ioringOffSqRing, int(sqRingSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + return err + } + r.sqRing = sqRing + + cqRing, err := unix.Mmap(r.fd, ioringOffCqRing, int(cqRingSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + unix.Munmap(r.sqRing) + r.sqRing = nil + return err + } + r.cqRing = cqRing + + sqesMap, err := unix.Mmap(r.fd, ioringOffSqes, int(sqesSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED) + if err != nil { + unix.Munmap(r.cqRing) + unix.Munmap(r.sqRing) + r.cqRing = nil + r.sqRing = nil + return err + } + r.sqesMap = sqesMap + + sqBase := unsafe.Pointer(&sqRing[0]) + r.sqHead = (*uint32)(unsafe.Pointer(uintptr(sqBase) + uintptr(params.SqOff.Head))) + r.sqTail = (*uint32)(unsafe.Pointer(uintptr(sqBase) + uintptr(params.SqOff.Tail))) + r.sqRingMask = (*uint32)(unsafe.Pointer(uintptr(sqBase) + uintptr(params.SqOff.RingMask))) + r.sqRingEntries = (*uint32)(unsafe.Pointer(uintptr(sqBase) + uintptr(params.SqOff.RingEntries))) + arrayPtr := unsafe.Pointer(uintptr(sqBase) + uintptr(params.SqOff.Array)) + r.sqArray = unsafe.Slice((*uint32)(arrayPtr), int(params.SqEntries)) + + sqesBase := unsafe.Pointer(&sqesMap[0]) + r.sqes = unsafe.Slice((*ioUringSqe)(sqesBase), int(params.SqEntries)) + + cqBase := unsafe.Pointer(&cqRing[0]) + r.cqHead = (*uint32)(unsafe.Pointer(uintptr(cqBase) + uintptr(params.CqOff.Head))) + r.cqTail = (*uint32)(unsafe.Pointer(uintptr(cqBase) + uintptr(params.CqOff.Tail))) + r.cqRingMask = (*uint32)(unsafe.Pointer(uintptr(cqBase) + uintptr(params.CqOff.RingMask))) + r.cqRingEntries = (*uint32)(unsafe.Pointer(uintptr(cqBase) + 
uintptr(params.CqOff.RingEntries))) + cqesPtr := unsafe.Pointer(uintptr(cqBase) + uintptr(params.CqOff.Cqes)) + + // CRITICAL: Ensure CQE array pointer is properly aligned + // The kernel's CQE structure is 16 bytes, and the array must be aligned + // Verify alignment and log if misaligned + cqeSize := uintptr(unsafe.Sizeof(ioUringCqe{})) + if cqeSize != 16 { + return fmt.Errorf("io_uring CQE size mismatch: expected 16, got %d", cqeSize) + } + cqesOffset := uintptr(cqesPtr) % 8 + if cqesOffset != 0 { + logrus.WithFields(logrus.Fields{ + "cqes_ptr": fmt.Sprintf("%p", cqesPtr), + "cqes_offset": cqesOffset, + "cq_base": fmt.Sprintf("%p", cqBase), + "cq_off_cqes": params.CqOff.Cqes, + }).Warn("io_uring CQE array may be misaligned") + } + + r.cqCqes = unsafe.Slice((*ioUringCqe)(cqesPtr), int(params.CqEntries)) + + return nil +} + +func (r *ioUringState) getSqeLocked() (*ioUringSqe, error) { + iterations := 0 + for { + head := atomic.LoadUint32(r.sqHead) + tail := atomic.LoadUint32(r.sqTail) + entries := atomic.LoadUint32(r.sqRingEntries) + used := tail - head + + if tail-head < entries { + mask := atomic.LoadUint32(r.sqRingMask) + idx := tail & mask + sqe := &r.sqes[idx] + *sqe = ioUringSqe{} + r.sqArray[idx] = idx + atomic.StoreUint32(r.sqTail, tail+1) + if iterations > 0 { + logrus.WithFields(logrus.Fields{ + "iterations": iterations, + "used": used, + "entries": entries, + }).Debug("getSqeLocked got slot after waiting") + } + return sqe, nil + } + + logrus.WithFields(logrus.Fields{ + "head": head, + "tail": tail, + "entries": entries, + "used": used, + }).Warn("getSqeLocked: io_uring ring is FULL, waiting for completions") + + if err := r.submitAndWaitLocked(0, 1); err != nil { + return nil, err + } + iterations++ + } +} + +func (r *ioUringState) submitAndWaitLocked(submit, wait uint32) error { + var flags uintptr + if wait > 0 { + flags = ioringEnterGetevents + } + + for { + _, _, errno := unix.Syscall6(unix.SYS_IO_URING_ENTER, uintptr(r.fd), uintptr(submit), uintptr(wait), flags, 0, 0) + if errno == 0 { + return nil + } + if errno == unix.EINTR { + continue + } + return errno + } +} + +func (r *ioUringState) enqueueSendmsgLocked(fd int, msg *unix.Msghdr, msgFlags uint32, payloadLen uint32) (uint64, error) { + sqe, err := r.getSqeLocked() + if err != nil { + return 0, err + } + + userData := r.userData + r.userData++ + + msgCopy := new(unix.Msghdr) + *msgCopy = *msg + + var iovCopy *unix.Iovec + var payloadRef unsafe.Pointer + if msg.Iov != nil { + iovCopy = new(unix.Iovec) + *iovCopy = *msg.Iov + msgCopy.Iov = iovCopy + if iovCopy.Base != nil { + payloadRef = unsafe.Pointer(iovCopy.Base) + } + } + + var sockaddrCopy []byte + if msg.Name != nil && msg.Namelen > 0 { + sockaddrCopy = make([]byte, msg.Namelen) + copy(sockaddrCopy, (*[256]byte)(unsafe.Pointer(msg.Name))[:msg.Namelen]) + msgCopy.Name = &sockaddrCopy[0] + } + + var controlCopy []byte + if msg.Control != nil && msg.Controllen > 0 { + controlCopy = make([]byte, msg.Controllen) + copy(controlCopy, (*[256]byte)(unsafe.Pointer(msg.Control))[:msg.Controllen]) + msgCopy.Control = &controlCopy[0] + } + + pending := &pendingSend{ + msgCopy: msgCopy, + iovCopy: iovCopy, + sockaddrCopy: sockaddrCopy, + controlCopy: controlCopy, + payloadRef: payloadRef, + userData: userData, + } + r.pendingSends[userData] = pending + + sqe.Opcode = ioringOpSendmsg + sqe.Fd = int32(fd) + sqe.Addr = uint64(uintptr(unsafe.Pointer(msgCopy))) + sqe.Len = 0 + sqe.MsgFlags = msgFlags + sqe.Flags = 0 + + userDataPtr := (*uint64)(unsafe.Pointer(&sqe.UserData)) + 
atomic.StoreUint64(userDataPtr, userData) + _ = atomic.LoadUint64(userDataPtr) + + runtime.KeepAlive(msgCopy) + runtime.KeepAlive(sqe) + if payloadRef != nil { + runtime.KeepAlive(payloadRef) + } + _ = atomic.LoadUint32(r.sqTail) + atomic.StoreUint32(r.sqTail, atomic.LoadUint32(r.sqTail)) + + return userData, nil +} + +func (r *ioUringState) abortPendingSendLocked(userData uint64) { + if pending, ok := r.pendingSends[userData]; ok { + delete(r.pendingSends, userData) + delete(r.completedCqes, userData) + if pending != nil { + runtime.KeepAlive(pending.msgCopy) + runtime.KeepAlive(pending.iovCopy) + runtime.KeepAlive(pending.sockaddrCopy) + runtime.KeepAlive(pending.controlCopy) + if pending.payloadRef != nil { + runtime.KeepAlive(pending.payloadRef) + } + } + } +} + +func (r *ioUringState) abortPendingRecvLocked(userData uint64) { + if pending, ok := r.pendingReceives[userData]; ok { + delete(r.pendingReceives, userData) + delete(r.completedCqes, userData) + if pending != nil { + runtime.KeepAlive(pending.msgCopy) + runtime.KeepAlive(pending.iovCopy) + runtime.KeepAlive(pending.payloadBuf) + runtime.KeepAlive(pending.nameBuf) + runtime.KeepAlive(pending.controlBuf) + } + } +} + +func (r *ioUringState) completeSendLocked(userData uint64) (int32, uint32, error) { + cqe, err := r.waitForCqeLocked(userData) + if err != nil { + r.abortPendingSendLocked(userData) + return 0, 0, err + } + + var pending *pendingSend + if p, ok := r.pendingSends[userData]; ok { + pending = p + delete(r.pendingSends, userData) + } + + if pending != nil { + runtime.KeepAlive(pending.msgCopy) + runtime.KeepAlive(pending.iovCopy) + runtime.KeepAlive(pending.sockaddrCopy) + runtime.KeepAlive(pending.controlCopy) + if pending.payloadRef != nil { + runtime.KeepAlive(pending.payloadRef) + } + } + + return cqe.Res, cqe.Flags, nil +} + +func (r *ioUringState) enqueueRecvmsgLocked(fd int, msg *unix.Msghdr, msgFlags uint32) (uint64, error) { + if msg == nil { + return 0, syscall.EINVAL + } + + var iovCount int + if msg.Iov != nil { + iovCount = int(msg.Iovlen) + if iovCount <= 0 { + return 0, syscall.EINVAL + } + if iovCount > 1 { + return 0, syscall.ENOTSUP + } + } + + sqe, err := r.getSqeLocked() + if err != nil { + return 0, err + } + + userData := r.userData + r.userData++ + + msgCopy := new(unix.Msghdr) + *msgCopy = *msg + + var iovCopy *unix.Iovec + var payloadBuf []byte + if msg.Iov != nil { + iovCopy = new(unix.Iovec) + *iovCopy = *msg.Iov + msgCopy.Iov = iovCopy + setMsghdrIovlen(msgCopy, 1) + if iovCopy.Base != nil { + payloadLen := int(iovCopy.Len) + if payloadLen < 0 { + return 0, syscall.EINVAL + } + if payloadLen > 0 { + payloadBuf = unsafe.Slice((*byte)(iovCopy.Base), payloadLen) + } + } + } + + var nameBuf []byte + if msgCopy.Name != nil && msgCopy.Namelen > 0 { + nameLen := int(msgCopy.Namelen) + if nameLen < 0 { + return 0, syscall.EINVAL + } + nameBuf = unsafe.Slice(msgCopy.Name, nameLen) + } + + var controlBuf []byte + if msgCopy.Control != nil && msgCopy.Controllen > 0 { + ctrlLen := int(msgCopy.Controllen) + if ctrlLen < 0 { + return 0, syscall.EINVAL + } + if ctrlLen > 0 { + controlBuf = unsafe.Slice((*byte)(msgCopy.Control), ctrlLen) + } + } + + pending := &pendingRecv{ + msgCopy: msgCopy, + iovCopy: iovCopy, + nameBuf: nameBuf, + controlBuf: controlBuf, + payloadBuf: payloadBuf, + callerMsg: msg, + userData: userData, + } + r.pendingReceives[userData] = pending + + sqe.Opcode = ioringOpRecvmsg + sqe.Fd = int32(fd) + sqe.Addr = uint64(uintptr(unsafe.Pointer(msgCopy))) + sqe.Len = 0 + sqe.MsgFlags = 
msgFlags + sqe.Flags = 0 + + userDataPtr := (*uint64)(unsafe.Pointer(&sqe.UserData)) + atomic.StoreUint64(userDataPtr, userData) + _ = atomic.LoadUint64(userDataPtr) + + runtime.KeepAlive(msgCopy) + runtime.KeepAlive(iovCopy) + runtime.KeepAlive(payloadBuf) + runtime.KeepAlive(nameBuf) + runtime.KeepAlive(controlBuf) + + return userData, nil +} + +func (r *ioUringState) completeRecvLocked(userData uint64) (int32, uint32, error) { + cqe, err := r.waitForCqeLocked(userData) + if err != nil { + r.abortPendingRecvLocked(userData) + return 0, 0, err + } + + var pending *pendingRecv + if p, ok := r.pendingReceives[userData]; ok { + pending = p + delete(r.pendingReceives, userData) + } + + if pending != nil { + if pending.callerMsg != nil && pending.msgCopy != nil { + pending.callerMsg.Namelen = pending.msgCopy.Namelen + pending.callerMsg.Controllen = pending.msgCopy.Controllen + pending.callerMsg.Flags = pending.msgCopy.Flags + } + runtime.KeepAlive(pending.msgCopy) + runtime.KeepAlive(pending.iovCopy) + runtime.KeepAlive(pending.payloadBuf) + runtime.KeepAlive(pending.nameBuf) + runtime.KeepAlive(pending.controlBuf) + } + + return cqe.Res, cqe.Flags, nil +} + +func (r *ioUringState) SendmsgBatch(entries []ioUringBatchEntry) error { + if len(entries) == 0 { + return nil + } + r.mu.Lock() + defer r.mu.Unlock() + + startTail := atomic.LoadUint32(r.sqTail) + prepared := 0 + for i := range entries { + entry := &entries[i] + userData, err := r.enqueueSendmsgLocked(entry.fd, entry.msg, entry.msgFlags, entry.payloadLen) + if err != nil { + for j := 0; j < prepared; j++ { + r.abortPendingSendLocked(entries[j].userData) + } + return err + } + entry.userData = userData + prepared++ + } + + submit := atomic.LoadUint32(r.sqTail) - startTail + if submit == 0 { + return nil + } + + if err := r.submitAndWaitLocked(submit, submit); err != nil { + for i := 0; i < prepared; i++ { + r.abortPendingSendLocked(entries[i].userData) + } + return err + } + + for i := range entries { + entry := &entries[i] + res, flags, err := r.completeSendLocked(entry.userData) + if entry.result != nil { + entry.result.res = res + entry.result.flags = flags + entry.result.err = err + } + } + + return nil +} + +func (r *ioUringState) popCqeLocked() (*ioUringCqe, error) { + for { + // According to io_uring ABI specification: + // 1. Load tail with acquire semantics (ensures we see kernel's update) + // 2. Load head (our consumer index) + // 3. If head != tail, CQE available at index (head & mask) + // 4. Read CQE (must happen before updating head) + // 5. Update head with release semantics (marks CQE as consumed) + + // CRITICAL: According to io_uring ABI, the correct order is: + // 1. Load tail with acquire semantics (ensures we see kernel's tail update) + // 2. Load head (our consumer index) + // 3. If head != tail, read CQE at index (head & mask) + // 4. Update head with release semantics (marks CQE as consumed) + // The acquire/release pair ensures we see the kernel's CQE writes + + // Load tail with acquire semantics - this ensures we see all kernel writes + // including the CQE data + tail := atomic.LoadUint32(r.cqTail) + + // Load head (our consumer index) + head := atomic.LoadUint32(r.cqHead) + + if head != tail { + // CQE available - calculate index using mask + mask := atomic.LoadUint32(r.cqRingMask) + idx := head & mask + + // Get pointer to CQE entry - this points into mmapped kernel memory + cqe := &r.cqCqes[idx] + + // CRITICAL: The kernel writes the CQE with release semantics when it + // updates tail. 
Since we loaded tail with acquire semantics above, + // we should see the CQE correctly. However, we need to ensure we're + // reading the fields in the correct order and with proper barriers. + + // Read UserData field - use atomic load to ensure proper ordering + // The kernel's write to user_data happens before updating tail (release) + userDataPtr := (*uint64)(unsafe.Pointer(&cqe.UserData)) + userData := atomic.LoadUint64(userDataPtr) + + // Memory barrier: ensure UserData read completes before reading other fields + // This creates a proper acquire barrier + _ = atomic.LoadUint32(r.cqTail) + + // Read other fields - these should be visible after the barrier + res := cqe.Res + flags := cqe.Flags + + // NOW update head with release semantics + // This marks the CQE as consumed and must happen AFTER all reads + atomic.StoreUint32(r.cqHead, head+1) + + // Return a copy to ensure consistency - the original CQE in mmapped + // memory might be overwritten by the kernel for the next submission + return &ioUringCqe{ + UserData: userData, + Res: res, + Flags: flags, + }, nil + } + + // No CQE available - wait for kernel to add one + if err := r.submitAndWaitLocked(0, 1); err != nil { + return nil, err + } + } +} + +// waitForCqeLocked waits for a CQE matching the expected userData. +// It drains any CQEs that don't match (from previous submissions that completed +// out of order) until it finds the one we're waiting for. +func (r *ioUringState) waitForCqeLocked(expectedUserData uint64) (*ioUringCqe, error) { + if cqe, ok := r.completedCqes[expectedUserData]; ok { + delete(r.completedCqes, expectedUserData) + return cqe, nil + } + + const maxIterations = 1000 + for iterations := 0; ; iterations++ { + if iterations >= maxIterations { + logrus.WithFields(logrus.Fields{ + "expected_userdata": expectedUserData, + "pending_sends": len(r.pendingSends), + "pending_recvs": len(r.pendingReceives), + "completed_cache": len(r.completedCqes), + }).Error("io_uring waitForCqeLocked exceeded max iterations - possible bug") + return nil, syscall.EIO + } + + cqe, err := r.popCqeLocked() + if err != nil { + return nil, err + } + userData := cqe.UserData + + logrus.WithFields(logrus.Fields{ + "cqe_userdata": userData, + "cqe_userdata_hex": fmt.Sprintf("0x%x", userData), + "cqe_res": cqe.Res, + "cqe_flags": cqe.Flags, + "expected_userdata": expectedUserData, + }).Debug("io_uring CQE received") + + if userData == expectedUserData { + return cqe, nil + } + + if _, exists := r.completedCqes[userData]; exists { + logrus.WithFields(logrus.Fields{ + "cqe_userdata": userData, + }).Warn("io_uring received duplicate CQE for userData; overwriting previous entry") + } + + r.completedCqes[userData] = cqe + + if _, sendPending := r.pendingSends[userData]; !sendPending { + if _, recvPending := r.pendingReceives[userData]; !recvPending { + logrus.WithFields(logrus.Fields{ + "cqe_userdata": userData, + "cqe_res": cqe.Res, + "cqe_flags": cqe.Flags, + }).Warn("io_uring received CQE for unknown userData; stored for later but no pending op found") + } + } + } +} + +func (r *ioUringState) Sendmsg(fd int, msg *unix.Msghdr, msgFlags uint32, payloadLen uint32) (int, error) { + if r == nil { + return 0, &net.OpError{Op: "sendmsg", Err: syscall.EINVAL} + } + + r.mu.Lock() + defer r.mu.Unlock() + + userData, err := r.enqueueSendmsgLocked(fd, msg, msgFlags, payloadLen) + if err != nil { + return 0, &net.OpError{Op: "sendmsg", Err: err} + } + + if err := r.submitAndWaitLocked(1, 1); err != nil { + r.abortPendingSendLocked(userData) + return 0, 
&net.OpError{Op: "sendmsg", Err: err} + } + + res, cqeFlags, err := r.completeSendLocked(userData) + if err != nil { + return 0, &net.OpError{Op: "sendmsg", Err: err} + } + + if res < 0 { + errno := syscall.Errno(-res) + return 0, &net.OpError{Op: "sendmsg", Err: errno} + } + if res == 0 && payloadLen > 0 { + logrus.WithFields(logrus.Fields{ + "payload_len": payloadLen, + "msg_namelen": msg.Namelen, + "msg_flags": msgFlags, + "cqe_flags": cqeFlags, + "cqe_userdata": userData, + }).Warn("io_uring sendmsg returned zero bytes") + } + + return int(res), nil +} + +func (r *ioUringState) Recvmsg(fd int, msg *unix.Msghdr, msgFlags uint32) (int, uint32, error) { + + if r == nil { + logrus.Error("io_uring Recvmsg: r is nil") + return 0, 0, &net.OpError{Op: "recvmsg", Err: syscall.EINVAL} + } + + if msg == nil { + logrus.Error("io_uring Recvmsg: msg is nil") + return 0, 0, &net.OpError{Op: "recvmsg", Err: syscall.EINVAL} + } + + r.mu.Lock() + defer r.mu.Unlock() + + userData, err := r.enqueueRecvmsgLocked(fd, msg, msgFlags) + if err != nil { + return 0, 0, &net.OpError{Op: "recvmsg", Err: err} + } + + if err := r.submitAndWaitLocked(1, 1); err != nil { + r.abortPendingRecvLocked(userData) + return 0, 0, &net.OpError{Op: "recvmsg", Err: err} + } + + res, cqeFlags, err := r.completeRecvLocked(userData) + if err != nil { + logrus.WithFields(logrus.Fields{ + "userData": userData, + "error": err, + }).Error("io_uring completeRecvLocked failed") + return 0, 0, &net.OpError{Op: "recvmsg", Err: err} + } + + logrus.WithFields(logrus.Fields{ + "userData": userData, + "res": res, + "cqeFlags": cqeFlags, + "bytesRecv": res, + }).Debug("io_uring recvmsg completed") + + if res < 0 { + errno := syscall.Errno(-res) + logrus.WithFields(logrus.Fields{ + "userData": userData, + "res": res, + "errno": errno, + }).Error("io_uring recvmsg negative result") + return 0, cqeFlags, &net.OpError{Op: "recvmsg", Err: errno} + } + + return int(res), cqeFlags, nil +} + +func (r *ioUringState) Close() error { + if r == nil { + return nil + } + + r.mu.Lock() + // Clean up any remaining pending sends + for _, pending := range r.pendingSends { + runtime.KeepAlive(pending) + } + r.pendingSends = nil + for _, pending := range r.pendingReceives { + runtime.KeepAlive(pending) + } + r.pendingReceives = nil + r.completedCqes = nil + r.mu.Unlock() + + var err error + if r.sqRing != nil { + if e := unix.Munmap(r.sqRing); e != nil && err == nil { + err = e + } + r.sqRing = nil + } + if r.cqRing != nil { + if e := unix.Munmap(r.cqRing); e != nil && err == nil { + err = e + } + r.cqRing = nil + } + if r.sqesMap != nil { + if e := unix.Munmap(r.sqesMap); e != nil && err == nil { + err = e + } + r.sqesMap = nil + } + if r.fd >= 0 { + if e := unix.Close(r.fd); e != nil && err == nil { + err = e + } + r.fd = -1 + } + return err +} + +// RecvPacket represents a received packet with its metadata +type RecvPacket struct { + Data []byte + N int + From *unix.RawSockaddrInet6 + Flags uint32 + Control []byte + Controllen int + RecycleFunc func() +} + +var recvPacketDataPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 65536) // Max UDP packet size + return &b + }, +} + +var recvControlDataPool = sync.Pool{ + New: func() interface{} { + b := make([]byte, 256) // Max control message size + return &b + }, +} + +// newIoUringRecvState creates a dedicated io_uring for receiving packets +// poolSize determines how many receive operations to keep queued +func newIoUringRecvState(sockFd int, entries uint32, poolSize int, bufferSize int) 
(*ioUringRecvState, error) { + const minEntries = 8 + + if poolSize < 1 { + poolSize = 64 // Default pool size + } + if poolSize > 2048 { + poolSize = 2048 // Cap pool size + } + + if entries == 0 { + entries = uint32(poolSize) + } + if entries < uint32(poolSize) { + entries = uint32(poolSize) + } + if entries < minEntries { + entries = minEntries + } + + tries := entries + var params ioUringParams + + // Try flag combinations in order (5.19+ -> baseline) + // Note: SINGLE_ISSUER causes EEXIST errors, so it's excluded + flagSets := []uint32{ + ioringSetupClamp | ioringSetupCoopTaskrun, // Kernel 5.19+: reduce thread creation + ioringSetupClamp, // All kernels + } + flagSetIdx := 0 + + for { + params = ioUringParams{Flags: flagSets[flagSetIdx]} + fd, _, errno := unix.Syscall(unix.SYS_IO_URING_SETUP, uintptr(tries), uintptr(unsafe.Pointer(¶ms)), 0) + if errno != 0 { + // If EINVAL, try next flag set (kernel doesn't support these flags) + if errno == unix.EINVAL && flagSetIdx < len(flagSets)-1 { + flagSetIdx++ + continue + } + if errno == unix.ENOMEM && tries > minEntries { + tries /= 2 + if tries < minEntries { + tries = minEntries + } + continue + } + return nil, errno + } + + ring := &ioUringRecvState{ + fd: int(fd), + sqEntryCount: params.SqEntries, + cqEntryCount: params.CqEntries, + userData: 1, + bufferMap: make(map[uint64]*recvBuffer), + sockFd: sockFd, + } + + if err := ring.mapRings(¶ms); err != nil { + ring.Close() + if errors.Is(err, unix.ENOMEM) && tries > minEntries { + tries /= 2 + if tries < minEntries { + tries = minEntries + } + continue + } + return nil, err + } + + // Allocate buffer pool + ring.bufferPool = make([]*recvBuffer, poolSize) + for i := 0; i < poolSize; i++ { + buf := &recvBuffer{ + payloadBuf: make([]byte, bufferSize), + nameBuf: make([]byte, unix.SizeofSockaddrInet6), + controlBuf: make([]byte, 256), + msghdr: &unix.Msghdr{}, + iovec: &unix.Iovec{}, + userData: ring.userData, + } + ring.userData++ + + // Initialize iovec to point to payload buffer + buf.iovec.Base = &buf.payloadBuf[0] + buf.iovec.SetLen(len(buf.payloadBuf)) + + // Initialize msghdr + buf.msghdr.Name = &buf.nameBuf[0] + buf.msghdr.Namelen = uint32(len(buf.nameBuf)) + buf.msghdr.Iov = buf.iovec + buf.msghdr.Iovlen = 1 + buf.msghdr.Control = &buf.controlBuf[0] + buf.msghdr.Controllen = uint64(len(buf.controlBuf)) + + ring.bufferPool[i] = buf + ring.bufferMap[buf.userData] = buf + } + + logrus.WithFields(logrus.Fields{ + "poolSize": poolSize, + "entries": ring.sqEntryCount, + "bufferSize": bufferSize, + }).Info("io_uring receive ring created") + + // Limit kernel worker threads to prevent thousands being spawned + // [0] = bounded workers, [1] = unbounded workers + maxWorkers := [2]uint32{4, 4} // Limit to 4 workers of each type + _, _, errno = unix.Syscall6( + unix.SYS_IO_URING_REGISTER, + uintptr(fd), + uintptr(ioringRegisterIowqMaxWorkers), + uintptr(unsafe.Pointer(&maxWorkers[0])), + 2, // array length + 0, 0, + ) + // Ignore errors - older kernels don't support this + + return ring, nil +} +} + +func (r *ioUringRecvState) mapRings(params *ioUringParams) error { + pageSize := uint32(unix.Getpagesize()) + + sqRingSize := alignUint32(params.SqOff.Array+params.SqEntries*4, pageSize) + cqRingSize := alignUint32(params.CqOff.Cqes+params.CqEntries*16, pageSize) + + if params.Features&(1<<0) != 0 { // IORING_FEAT_SINGLE_MMAP + if sqRingSize > cqRingSize { + cqRingSize = sqRingSize + } else { + sqRingSize = cqRingSize + } + } + + sqRingPtr, err := unix.Mmap(r.fd, int64(ioringOffSqRing), 
int(sqRingSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + return err + } + r.sqRing = sqRingPtr + + if params.Features&(1<<0) != 0 { + r.cqRing = sqRingPtr + } else { + cqRingPtr, err := unix.Mmap(r.fd, int64(ioringOffCqRing), int(cqRingSize), unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + return err + } + r.cqRing = cqRingPtr + } + + sqesSize := int(params.SqEntries) * ioUringSqeSize + sqesPtr, err := unix.Mmap(r.fd, int64(ioringOffSqes), sqesSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_POPULATE) + if err != nil { + return err + } + r.sqesMap = sqesPtr + + // Set up SQ pointers + r.sqHead = (*uint32)(unsafe.Pointer(&r.sqRing[params.SqOff.Head])) + r.sqTail = (*uint32)(unsafe.Pointer(&r.sqRing[params.SqOff.Tail])) + r.sqRingMask = (*uint32)(unsafe.Pointer(&r.sqRing[params.SqOff.RingMask])) + r.sqRingEntries = (*uint32)(unsafe.Pointer(&r.sqRing[params.SqOff.RingEntries])) + + // Set up SQ array + arrayBase := unsafe.Pointer(&r.sqRing[params.SqOff.Array]) + r.sqArray = unsafe.Slice((*uint32)(arrayBase), params.SqEntries) + + // Set up SQE slice + r.sqes = unsafe.Slice((*ioUringSqe)(unsafe.Pointer(&sqesPtr[0])), params.SqEntries) + + // Set up CQ pointers + r.cqHead = (*uint32)(unsafe.Pointer(&r.cqRing[params.CqOff.Head])) + r.cqTail = (*uint32)(unsafe.Pointer(&r.cqRing[params.CqOff.Tail])) + r.cqRingMask = (*uint32)(unsafe.Pointer(&r.cqRing[params.CqOff.RingMask])) + r.cqRingEntries = (*uint32)(unsafe.Pointer(&r.cqRing[params.CqOff.RingEntries])) + + cqesBase := unsafe.Pointer(&r.cqRing[params.CqOff.Cqes]) + r.cqCqes = unsafe.Slice((*ioUringCqe)(cqesBase), params.CqEntries) + + return nil +} + +// submitRecvLocked submits a single receive operation. Must be called with mutex held. 
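// Annotation (not part of the patch): an illustrative consumer loop for the
// receive ring, built only on fillRecvQueue, receivePackets and
// RecvPacket.RecycleFunc as defined in this file. drainRecvRing and handle are
// hypothetical names; a real caller would also add shutdown handling.
func drainRecvRing(ring *ioUringRecvState, handle func([]byte)) error {
	// Prime the ring so receives are always queued before we start waiting.
	if err := ring.fillRecvQueue(); err != nil {
		return err
	}
	for {
		// wait=true blocks until at least one completion is available.
		pkts, err := ring.receivePackets(true)
		if err != nil {
			return err
		}
		for i := range pkts {
			handle(pkts[i].Data[:pkts[i].N])
			if pkts[i].RecycleFunc != nil {
				// Return the pooled payload/control buffers.
				pkts[i].RecycleFunc()
			}
		}
	}
}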
+func (r *ioUringRecvState) submitRecvLocked(buf *recvBuffer) error { + if buf.inFlight.Load() { + return fmt.Errorf("buffer already in flight") + } + + // Reset buffer state for reuse + buf.msghdr.Namelen = uint32(len(buf.nameBuf)) + buf.msghdr.Controllen = uint64(len(buf.controlBuf)) + buf.msghdr.Flags = 0 + buf.iovec.SetLen(len(buf.payloadBuf)) + + // Get next SQE + tail := atomic.LoadUint32(r.sqTail) + head := atomic.LoadUint32(r.sqHead) + mask := *r.sqRingMask + + if tail-head >= *r.sqRingEntries { + return fmt.Errorf("submission queue full") + } + + idx := tail & mask + sqe := &r.sqes[idx] + + // Set up SQE for IORING_OP_RECVMSG + *sqe = ioUringSqe{} + sqe.Opcode = ioringOpRecvmsg + sqe.Fd = int32(r.sockFd) + sqe.Addr = uint64(uintptr(unsafe.Pointer(buf.msghdr))) + sqe.Len = 1 + sqe.UserData = buf.userData + + r.sqArray[idx] = uint32(idx) + atomic.StoreUint32(r.sqTail, tail+1) + + buf.inFlight.Store(true) + + return nil +} + +// submitAndWaitLocked submits pending SQEs and optionally waits for completions +func (r *ioUringRecvState) submitAndWaitLocked(submit, wait uint32) error { + var flags uintptr + if wait > 0 { + flags = ioringEnterGetevents + } + + for { + ret, _, errno := unix.Syscall6(unix.SYS_IO_URING_ENTER, uintptr(r.fd), uintptr(submit), uintptr(wait), flags, 0, 0) + if errno == 0 { + if wait > 0 && ret > 0 { + logrus.WithFields(logrus.Fields{ + "completed": ret, + "submitted": submit, + }).Debug("io_uring recv: operations completed") + } + return nil + } + if errno == unix.EINTR { + continue + } + return errno + } +} + +// fillRecvQueue fills the submission queue with as many receives as possible +func (r *ioUringRecvState) fillRecvQueue() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed.Load() { + return fmt.Errorf("ring closed") + } + + submitted := uint32(0) + for _, buf := range r.bufferPool { + if !buf.inFlight.Load() { + if err := r.submitRecvLocked(buf); err != nil { + if submitted > 0 { + break // Queue full, submit what we have + } + return err + } + submitted++ + } + } + + if submitted > 0 { + return r.submitAndWaitLocked(submitted, 0) + } + + return nil +} + +// receivePackets processes all completed receives and returns packets +// Returns a slice of completed packets +func (r *ioUringRecvState) receivePackets(wait bool) ([]RecvPacket, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed.Load() { + return nil, fmt.Errorf("ring closed") + } + + // First submit any pending (to ensure we always have receives queued) + submitted := uint32(0) + for _, buf := range r.bufferPool { + if !buf.inFlight.Load() { + if err := r.submitRecvLocked(buf); err != nil { + break // Queue might be full + } + submitted++ + } + } + + waitCount := uint32(0) + if wait { + waitCount = 1 + } + + if submitted > 0 || wait { + if err := r.submitAndWaitLocked(submitted, waitCount); err != nil { + return nil, err + } + } + + // Process completed CQEs + var packets []RecvPacket + head := atomic.LoadUint32(r.cqHead) + tail := atomic.LoadUint32(r.cqTail) + mask := *r.cqRingMask + + completions := uint32(0) + errors := 0 + eagains := 0 + + for head != tail { + idx := head & mask + cqe := &r.cqCqes[idx] + + userData := cqe.UserData + res := cqe.Res + flags := cqe.Flags + + head++ + atomic.StoreUint32(r.cqHead, head) + completions++ + + buf, ok := r.bufferMap[userData] + if !ok { + logrus.WithField("userData", userData).Warn("io_uring recv: unknown userData in completion") + continue + } + + buf.inFlight.Store(false) + + if res < 0 { + errno := syscall.Errno(-res) + // EAGAIN is 
expected for non-blocking - just resubmit + if errno == unix.EAGAIN { + eagains++ + } else { + errors++ + logrus.WithFields(logrus.Fields{ + "userData": userData, + "errno": errno, + }).Debug("io_uring recv error") + } + continue + } + + if res == 0 { + // Connection closed or no data + continue + } + + // Successfully received packet + n := int(res) + + // Copy address + var from unix.RawSockaddrInet6 + if buf.msghdr.Namelen > 0 && buf.msghdr.Namelen <= uint32(len(buf.nameBuf)) { + copy((*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&from)))[:], buf.nameBuf[:buf.msghdr.Namelen]) + } + + // Get buffer from pool and copy data + dataBufPtr := recvPacketDataPool.Get().(*[]byte) + dataBuf := *dataBufPtr + if cap(dataBuf) < n { + // Buffer too small, allocate new one + dataBuf = make([]byte, n) + } else { + dataBuf = dataBuf[:n] + } + copy(dataBuf, buf.payloadBuf[:n]) + + // Copy control messages if present + var controlBuf []byte + var controlBufPtr *[]byte + controllen := int(buf.msghdr.Controllen) + if controllen > 0 && controllen <= len(buf.controlBuf) { + controlBufPtr = recvControlDataPool.Get().(*[]byte) + controlBuf = (*controlBufPtr)[:controllen] + copy(controlBuf, buf.controlBuf[:controllen]) + } + + packets = append(packets, RecvPacket{ + Data: dataBuf, + N: n, + From: &from, + Flags: flags, + Control: controlBuf, + Controllen: controllen, + RecycleFunc: func() { + // Return buffers to pool + recvPacketDataPool.Put(dataBufPtr) + if controlBufPtr != nil { + recvControlDataPool.Put(controlBufPtr) + } + }, + }) + } + + return packets, nil +} + +// Close shuts down the receive ring +func (r *ioUringRecvState) Close() error { + if r == nil { + return nil + } + + r.closed.Store(true) + + r.mu.Lock() + defer r.mu.Unlock() + + // Clean up buffers + for _, buf := range r.bufferPool { + buf.inFlight.Store(false) + } + r.bufferPool = nil + r.bufferMap = nil + + var err error + if r.sqesMap != nil { + if e := unix.Munmap(r.sqesMap); e != nil && err == nil { + err = e + } + r.sqesMap = nil + } + if r.sqRing != nil { + if e := unix.Munmap(r.sqRing); e != nil && err == nil { + err = e + } + r.sqRing = nil + } + if r.cqRing != nil && len(r.cqRing) > 0 { + // Only unmap if it's a separate mapping + if len(r.cqRing) != len(r.sqRing) || uintptr(unsafe.Pointer(&r.cqRing[0])) != uintptr(unsafe.Pointer(&r.sqRing[0])) { + if e := unix.Munmap(r.cqRing); e != nil && err == nil { + err = e + } + } + r.cqRing = nil + } + if r.fd >= 0 { + if e := unix.Close(r.fd); e != nil && err == nil { + err = e + } + r.fd = -1 + } + return err +} diff --git a/udp/msghdr_helper_linux_32.go b/udp/msghdr_helper_linux_32.go new file mode 100644 index 0000000..624bc0e --- /dev/null +++ b/udp/msghdr_helper_linux_32.go @@ -0,0 +1,25 @@ +//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing +// +build linux +// +build 386 amd64p32 arm mips mipsle +// +build !android +// +build !e2e_testing + +package udp + +import "golang.org/x/sys/unix" + +func controllen(n int) uint32 { + return uint32(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint32(unix.CmsgLen(n)) +} + +func setIovecLen(v *unix.Iovec, n int) { + v.Len = uint32(n) +} + +func setMsghdrIovlen(m *unix.Msghdr, n int) { + m.Iovlen = uint32(n) +} diff --git a/udp/msghdr_helper_linux_64.go b/udp/msghdr_helper_linux_64.go new file mode 100644 index 0000000..9a4c71b --- /dev/null +++ b/udp/msghdr_helper_linux_64.go @@ -0,0 +1,25 @@ +//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || 
riscv64 || loong64) && !android && !e2e_testing +// +build linux +// +build amd64 arm64 ppc64 ppc64le mips64 mips64le s390x riscv64 loong64 +// +build !android +// +build !e2e_testing + +package udp + +import "golang.org/x/sys/unix" + +func controllen(n int) uint64 { + return uint64(n) +} + +func setCmsgLen(h *unix.Cmsghdr, n int) { + h.Len = uint64(unix.CmsgLen(n)) +} + +func setIovecLen(v *unix.Iovec, n int) { + v.Len = uint64(n) +} + +func setMsghdrIovlen(m *unix.Msghdr, n int) { + m.Iovlen = uint64(n) +} diff --git a/udp/sendmmsg_linux_32.go b/udp/sendmmsg_linux_32.go new file mode 100644 index 0000000..fb2afec --- /dev/null +++ b/udp/sendmmsg_linux_32.go @@ -0,0 +1,25 @@ +//go:build linux && (386 || amd64p32 || arm || mips || mipsle) && !android && !e2e_testing + +package udp + +import ( + "unsafe" + + "golang.org/x/sys/unix" +) + +type linuxMmsgHdr struct { + Hdr unix.Msghdr + Len uint32 +} + +func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) { + if len(hdrs) == 0 { + return 0, nil + } + n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0) + if errno != 0 { + return int(n), errno + } + return int(n), nil +} diff --git a/udp/sendmmsg_linux_64.go b/udp/sendmmsg_linux_64.go new file mode 100644 index 0000000..2fa9920 --- /dev/null +++ b/udp/sendmmsg_linux_64.go @@ -0,0 +1,26 @@ +//go:build linux && (amd64 || arm64 || ppc64 || ppc64le || mips64 || mips64le || s390x || riscv64 || loong64) && !android && !e2e_testing + +package udp + +import ( + "unsafe" + + "golang.org/x/sys/unix" +) + +type linuxMmsgHdr struct { + Hdr unix.Msghdr + Len uint32 + _ uint32 +} + +func sendmmsg(fd int, hdrs []linuxMmsgHdr, flags int) (int, error) { + if len(hdrs) == 0 { + return 0, nil + } + n, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&hdrs[0])), uintptr(len(hdrs)), uintptr(flags), 0, 0) + if errno != 0 { + return int(n), errno + } + return int(n), nil +} diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c0c6233..edadfc2 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -180,7 +180,7 @@ func (u *StdConn) ListenOut(r EncReader) { u.l.WithError(err).Error("unexpected udp socket receive error") } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil) } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index cb21e57..75fa661 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -82,6 +82,6 @@ func (u *GenericConn) ListenOut(r EncReader) { return } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], nil) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64..7dd04f7 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -5,23 +5,1267 @@ package udp import ( "encoding/binary" + "errors" "fmt" "net" "net/netip" + "runtime" + "sync" + "sync/atomic" "syscall" + "time" "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) +const ( + defaultGSOMaxSegments = 8 + defaultGSOMaxBytes = MTU * defaultGSOMaxSegments + defaultGROReadBufferSize = MTU * defaultGSOMaxSegments + defaultGSOFlushTimeout = 150 * time.Microsecond + linuxMaxGSOBatchBytes = 0xFFFF // Linux UDP GSO still limits the datagram payload to 64 KiB + maxSendmmsgBatch = 32 +) + +var ( + // Global mutex to 
serialize io_uring initialization across all sockets + ioUringInitMu sync.Mutex +) + type StdConn struct { sysFd int isV4 bool l *logrus.Logger batch int + + enableGRO bool + enableGSO bool + + controlLen atomic.Int32 + + gsoMaxSegments int + gsoMaxBytes int + gsoFlushTimeout time.Duration + + groSegmentPool sync.Pool + groBufSize atomic.Int64 + rxBufferPool chan []byte + gsoBufferPool sync.Pool + + gsoBatches metrics.Counter + gsoSegments metrics.Counter + gsoSingles metrics.Counter + groBatches metrics.Counter + groSegments metrics.Counter + gsoFallbacks metrics.Counter + gsoFallbackMu sync.Mutex + gsoFallbackReasons map[string]*atomic.Int64 + gsoBatchTick atomic.Int64 + gsoBatchSegmentsTick atomic.Int64 + gsoSingleTick atomic.Int64 + groBatchTick atomic.Int64 + groSegmentsTick atomic.Int64 + + ioState atomic.Pointer[ioUringState] + ioRecvState atomic.Pointer[ioUringRecvState] + ioActive atomic.Bool + ioRecvActive atomic.Bool + ioAttempted atomic.Bool + ioClosing atomic.Bool + ioUringHoldoff atomic.Int64 + ioUringMaxBatch atomic.Int64 + + sendShards []*sendShard + shardCounter atomic.Uint32 +} + +type sendTask struct { + buf []byte + addr netip.AddrPort + segSize int + segments int + owned bool +} + +type batchSendItem struct { + task *sendTask + addr netip.AddrPort + payload []byte + control []byte + msgFlags uint32 + resultBytes int + err error +} + +const sendShardQueueDepth = 128 +const ( + ioUringDefaultMaxBatch = 32 + ioUringMinMaxBatch = 1 + ioUringMaxMaxBatch = 4096 + ioUringDefaultHoldoff = 25 * time.Microsecond + ioUringMinHoldoff = 0 + ioUringMaxHoldoff = 500 * time.Millisecond + ioUringHoldoffSpinThreshold = 50 * time.Microsecond +) + +var ioUringSendmsgBatch = func(state *ioUringState, entries []ioUringBatchEntry) error { + return state.SendmsgBatch(entries) +} + +type sendShard struct { + parent *StdConn + + mu sync.Mutex + + pendingBuf []byte + pendingSegments int + pendingAddr netip.AddrPort + pendingSegSize int + flushTimer *time.Timer + controlBuf []byte + + mmsgHeaders []linuxMmsgHdr + mmsgIovecs []unix.Iovec + mmsgLengths []int + + outQueue chan *sendTask + workerDone sync.WaitGroup +} + +func clampIoUringBatchSize(requested int, ringEntries uint32) int { + if requested < ioUringMinMaxBatch { + requested = ioUringDefaultMaxBatch + } + if requested < ioUringMinMaxBatch { + requested = ioUringMinMaxBatch + } + if requested > ioUringMaxMaxBatch { + requested = ioUringMaxMaxBatch + } + if ringEntries > 0 && requested > int(ringEntries) { + requested = int(ringEntries) + } + if requested < ioUringMinMaxBatch { + requested = ioUringMinMaxBatch + } + return requested +} + +func (s *sendShard) currentHoldoff() time.Duration { + if s.parent == nil { + return 0 + } + holdoff := s.parent.ioUringHoldoff.Load() + if holdoff < 0 { + holdoff = 0 + } + if holdoff <= 0 { + return 0 + } + return time.Duration(holdoff) +} + +func (s *sendShard) currentMaxBatch() int { + if s == nil || s.parent == nil { + return ioUringDefaultMaxBatch + } + maxBatch := s.parent.ioUringMaxBatch.Load() + if maxBatch <= 0 { + return ioUringDefaultMaxBatch + } + if maxBatch > ioUringMaxMaxBatch { + maxBatch = ioUringMaxMaxBatch + } + return int(maxBatch) +} + +func (u *StdConn) initSendShards() { + shardCount := runtime.GOMAXPROCS(0) + if shardCount < 1 { + shardCount = 1 + } + u.resizeSendShards(shardCount) +} + +func toIPv4Mapped(v4 [4]byte) [16]byte { + return [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, v4[0], v4[1], v4[2], v4[3]} +} + +func (u *StdConn) populateSockaddrInet6(sa6 
*unix.RawSockaddrInet6, addr netip.Addr) { + sa6.Family = unix.AF_INET6 + if addr.Is4() { + // Convert IPv4 to IPv4-mapped IPv6 format for dual-stack socket + sa6.Addr = toIPv4Mapped(addr.As4()) + } else { + sa6.Addr = addr.As16() + } + sa6.Scope_id = 0 +} + +func (u *StdConn) selectSendShard(addr netip.AddrPort) *sendShard { + if len(u.sendShards) == 0 { + return nil + } + if len(u.sendShards) == 1 { + return u.sendShards[0] + } + idx := int(u.shardCounter.Add(1)-1) % len(u.sendShards) + if idx < 0 { + idx = -idx + } + return u.sendShards[idx] +} + +func (u *StdConn) resizeSendShards(count int) { + if count <= 0 { + count = runtime.GOMAXPROCS(0) + if count < 1 { + count = 1 + } + } + + if len(u.sendShards) == count { + return + } + + // Give existing shard workers time to fully initialize before stopping + // This prevents a race where we try to stop shards before they're ready + if len(u.sendShards) > 0 { + time.Sleep(time.Millisecond) + } + + for _, shard := range u.sendShards { + if shard == nil { + continue + } + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil { + u.l.WithError(err).Warn("Failed to flush send shard while resizing") + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + if buf != nil { + u.releaseGSOBuf(buf) + } + shard.stopSender() + } + + newShards := make([]*sendShard, count) + for i := range newShards { + shard := &sendShard{parent: u} + shard.startSender() + newShards[i] = shard + } + u.sendShards = newShards + u.shardCounter.Store(0) + u.l.WithField("send_shards", count).Debug("Configured UDP send shards") +} + +func (u *StdConn) setGroBufferSize(size int) { + if size < defaultGROReadBufferSize { + size = defaultGROReadBufferSize + } + u.groBufSize.Store(int64(size)) + u.groSegmentPool = sync.Pool{New: func() any { + return make([]byte, size) + }} + if u.rxBufferPool == nil { + poolSize := u.batch * 4 + if poolSize < u.batch { + poolSize = u.batch + } + u.rxBufferPool = make(chan []byte, poolSize) + for i := 0; i < poolSize; i++ { + u.rxBufferPool <- make([]byte, size) + } + } +} + +func (u *StdConn) borrowRxBuffer(desired int) []byte { + if desired < MTU { + desired = MTU + } + if u.rxBufferPool == nil { + return make([]byte, desired) + } + buf := <-u.rxBufferPool + if cap(buf) < desired { + buf = make([]byte, desired) + } + return buf[:desired] +} + +func (u *StdConn) recycleBuffer(buf []byte) { + if buf == nil { + return + } + if u.rxBufferPool == nil { + return + } + buf = buf[:cap(buf)] + desired := int(u.groBufSize.Load()) + if desired < MTU { + desired = MTU + } + if cap(buf) < desired { + return + } + select { + case u.rxBufferPool <- buf[:desired]: + default: + } +} + +func (u *StdConn) recycleBufferSet(bufs [][]byte) { + for i := range bufs { + u.recycleBuffer(bufs[i]) + } +} + +func isSocketCloseError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, unix.EPIPE) || errors.Is(err, unix.ENOTCONN) || errors.Is(err, unix.EINVAL) || errors.Is(err, unix.EBADF) { + return true + } + var opErr *net.OpError + if errors.As(err, &opErr) { + if errno, ok := opErr.Err.(syscall.Errno); ok { + switch errno { + case unix.EPIPE, unix.ENOTCONN, unix.EINVAL, unix.EBADF: + return true + } + } + } + return false +} + +func (u *StdConn) recordGSOFallback(reason string) { + if u == nil { + return + } + if reason == "" { + reason = "unknown" + } + if u.gsoFallbacks != nil { + u.gsoFallbacks.Inc(1) + } + u.gsoFallbackMu.Lock() + 
counter, ok := u.gsoFallbackReasons[reason] + if !ok { + counter = &atomic.Int64{} + u.gsoFallbackReasons[reason] = counter + } + counter.Add(1) + u.gsoFallbackMu.Unlock() +} + +func (u *StdConn) recordGSOSingle(count int) { + if u == nil || count <= 0 { + return + } + if u.gsoSingles != nil { + u.gsoSingles.Inc(int64(count)) + } + u.gsoSingleTick.Add(int64(count)) +} + +func (u *StdConn) snapshotGSOFallbacks() map[string]int64 { + u.gsoFallbackMu.Lock() + defer u.gsoFallbackMu.Unlock() + if len(u.gsoFallbackReasons) == 0 { + return nil + } + out := make(map[string]int64, len(u.gsoFallbackReasons)) + for reason, counter := range u.gsoFallbackReasons { + if counter == nil { + continue + } + count := counter.Swap(0) + if count != 0 { + out[reason] = count + } + } + return out +} + +func (u *StdConn) logGSOTick() { + u.gsoBatchTick.Store(0) + u.gsoBatchSegmentsTick.Store(0) + u.gsoSingleTick.Store(0) + u.groBatchTick.Store(0) + u.groSegmentsTick.Store(0) + u.snapshotGSOFallbacks() +} + +func (u *StdConn) borrowGSOBuf() []byte { + size := u.gsoMaxBytes + if size <= 0 { + size = MTU + } + if v := u.gsoBufferPool.Get(); v != nil { + buf := v.([]byte) + if cap(buf) < size { + u.gsoBufferPool.Put(buf[:0]) + return make([]byte, 0, size) + } + return buf[:0] + } + return make([]byte, 0, size) +} + +func (u *StdConn) borrowIOBuf(size int) []byte { + if size <= 0 { + size = MTU + } + if v := u.gsoBufferPool.Get(); v != nil { + buf := v.([]byte) + if cap(buf) < size { + u.gsoBufferPool.Put(buf[:0]) + return make([]byte, 0, size) + } + return buf[:0] + } + return make([]byte, 0, size) +} + +func (u *StdConn) releaseGSOBuf(buf []byte) { + if buf == nil { + return + } + size := u.gsoMaxBytes + if size <= 0 { + size = MTU + } + buf = buf[:0] + if cap(buf) > size*4 { + return + } + u.gsoBufferPool.Put(buf) +} + +func (s *sendShard) ensureMmsgCapacity(n int) { + if cap(s.mmsgHeaders) < n { + s.mmsgHeaders = make([]linuxMmsgHdr, n) + } + s.mmsgHeaders = s.mmsgHeaders[:n] + if cap(s.mmsgIovecs) < n { + s.mmsgIovecs = make([]unix.Iovec, n) + } + s.mmsgIovecs = s.mmsgIovecs[:n] + if cap(s.mmsgLengths) < n { + s.mmsgLengths = make([]int, n) + } + s.mmsgLengths = s.mmsgLengths[:n] +} + +func (s *sendShard) ensurePendingBuf(p *StdConn) { + if s.pendingBuf == nil { + s.pendingBuf = p.borrowGSOBuf() + } +} + +func (s *sendShard) startSender() { + if s.outQueue != nil { + return + } + s.outQueue = make(chan *sendTask, sendShardQueueDepth) + s.workerDone.Add(1) + go s.senderLoop() +} + +func (s *sendShard) stopSender() { + s.closeSender() + s.workerDone.Wait() +} + +func (s *sendShard) closeSender() { + s.mu.Lock() + queue := s.outQueue + s.outQueue = nil + s.mu.Unlock() + if queue != nil { + close(queue) + } +} + +func (s *sendShard) submitTask(task *sendTask) error { + if task == nil { + return nil + } + if len(task.buf) == 0 { + if task.owned && task.buf != nil && s.parent != nil { + s.parent.releaseGSOBuf(task.buf) + } + return nil + } + + if s.parent != nil && s.parent.ioClosing.Load() { + if task.owned && task.buf != nil { + s.parent.releaseGSOBuf(task.buf) + } + return &net.OpError{Op: "sendmsg", Err: net.ErrClosed} + } + + queue := s.outQueue + if queue != nil { + sent := false + func() { + defer func() { + if r := recover(); r != nil { + sent = false + } + }() + select { + case queue <- task: + sent = true + default: + } + }() + if sent { + return nil + } + } + + return s.processTask(task) +} + +func (s *sendShard) senderLoop() { + defer s.workerDone.Done() + initialCap := s.currentMaxBatch() + if initialCap 
<= 0 { + initialCap = ioUringDefaultMaxBatch + } + batch := make([]*sendTask, 0, initialCap) + var holdoffTimer *time.Timer + var holdoffCh <-chan time.Time + + stopTimer := func() { + if holdoffTimer == nil { + return + } + if !holdoffTimer.Stop() { + select { + case <-holdoffTimer.C: + default: + } + } + holdoffTimer = nil + holdoffCh = nil + } + + resetTimer := func() { + holdoff := s.currentHoldoff() + if holdoff <= 0 { + return + } + if holdoffTimer == nil { + holdoffTimer = time.NewTimer(holdoff) + holdoffCh = holdoffTimer.C + return + } + if !holdoffTimer.Stop() { + select { + case <-holdoffTimer.C: + default: + } + } + holdoffTimer.Reset(holdoff) + holdoffCh = holdoffTimer.C + } + + flush := func() { + if len(batch) == 0 { + return + } + stopTimer() + if err := s.processTasksBatch(batch); err != nil && s.parent != nil { + s.parent.l.WithError(err).Debug("io_uring batch send encountered error") + } + for i := range batch { + batch[i] = nil + } + batch = batch[:0] + } + + for { + if len(batch) == 0 { + if s.parent != nil && s.parent.ioClosing.Load() { + flush() + stopTimer() + return + } + task, ok := <-s.outQueue + if !ok { + flush() + stopTimer() + return + } + if task == nil { + continue + } + batch = append(batch, task) + maxBatch := s.currentMaxBatch() + holdoff := s.currentHoldoff() + if len(batch) >= maxBatch || holdoff <= 0 { + flush() + continue + } + if holdoff <= ioUringHoldoffSpinThreshold { + deadline := time.Now().Add(holdoff) + for { + if len(batch) >= maxBatch { + break + } + remaining := time.Until(deadline) + if remaining <= 0 { + break + } + select { + case next, ok := <-s.outQueue: + if !ok { + flush() + return + } + if next == nil { + continue + } + if s.parent != nil && s.parent.ioClosing.Load() { + flush() + return + } + batch = append(batch, next) + default: + if remaining > 5*time.Microsecond { + runtime.Gosched() + } + } + } + flush() + continue + } + resetTimer() + continue + } + + select { + case task, ok := <-s.outQueue: + if !ok { + flush() + stopTimer() + return + } + if task == nil { + continue + } + if s.parent != nil && s.parent.ioClosing.Load() { + flush() + stopTimer() + return + } + batch = append(batch, task) + if len(batch) >= s.currentMaxBatch() { + flush() + } else if s.currentHoldoff() > 0 { + resetTimer() + } + case <-holdoffCh: + stopTimer() + flush() + } + } +} + +func (s *sendShard) processTask(task *sendTask) error { + return s.processTasksBatch([]*sendTask{task}) +} + +func (s *sendShard) processTasksBatch(tasks []*sendTask) error { + if len(tasks) == 0 { + return nil + } + p := s.parent + state := p.ioState.Load() + var firstErr error + if state != nil { + if err := s.processTasksBatchIOUring(state, tasks); err != nil { + firstErr = err + } + } else { + for _, task := range tasks { + if err := s.processTaskFallback(task); err != nil && firstErr == nil { + firstErr = err + } + } + } + for _, task := range tasks { + if task == nil { + continue + } + if task.owned && task.buf != nil { + p.releaseGSOBuf(task.buf) + } + task.buf = nil + } + return firstErr +} + +func (s *sendShard) processTasksBatchIOUring(state *ioUringState, tasks []*sendTask) error { + capEstimate := 0 + maxSeg := 1 + if s.parent != nil && s.parent.ioUringMaxBatch.Load() > 0 { + maxSeg = int(s.parent.ioUringMaxBatch.Load()) + } + for _, task := range tasks { + if task == nil || len(task.buf) == 0 { + continue + } + if task.segSize > 0 && task.segSize < len(task.buf) { + capEstimate += (len(task.buf) + task.segSize - 1) / task.segSize + } else { + capEstimate++ + } + } + 
if capEstimate <= 0 { + capEstimate = len(tasks) + } + if capEstimate > maxSeg { + capEstimate = maxSeg + } + items := make([]*batchSendItem, 0, capEstimate) + for _, task := range tasks { + if task == nil || len(task.buf) == 0 { + continue + } + useGSO := s.parent.enableGSO && task.segments > 1 + if useGSO { + control := make([]byte, unix.CmsgSpace(2)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + setCmsgLen(hdr, 2) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + dataOff := unix.CmsgLen(0) + binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(task.segSize)) + items = append(items, &batchSendItem{ + task: task, + addr: task.addr, + payload: task.buf, + control: control, + msgFlags: 0, + }) + continue + } + + segSize := task.segSize + if segSize <= 0 || segSize >= len(task.buf) { + items = append(items, &batchSendItem{ + task: task, + addr: task.addr, + payload: task.buf, + }) + continue + } + + for offset := 0; offset < len(task.buf); offset += segSize { + end := offset + segSize + if end > len(task.buf) { + end = len(task.buf) + } + segment := task.buf[offset:end] + items = append(items, &batchSendItem{ + task: task, + addr: task.addr, + payload: segment, + }) + } + } + + if len(items) == 0 { + return nil + } + + if err := s.parent.sendMsgIOUringBatch(state, items); err != nil { + return err + } + + var firstErr error + for _, item := range items { + if item.err != nil && firstErr == nil { + firstErr = item.err + } + } + if firstErr != nil { + return firstErr + } + + for _, task := range tasks { + if task == nil { + continue + } + if s.parent.enableGSO && task.segments > 1 { + s.recordGSOMetrics(task) + } else { + s.parent.recordGSOSingle(task.segments) + } + } + + return nil +} + +func (s *sendShard) processTaskFallback(task *sendTask) error { + if task == nil || len(task.buf) == 0 { + return nil + } + p := s.parent + useGSO := p.enableGSO && task.segments > 1 + s.mu.Lock() + defer s.mu.Unlock() + if useGSO { + if err := s.sendSegmentedLocked(task.buf, task.addr, task.segSize); err != nil { + return err + } + s.recordGSOMetrics(task) + return nil + } + if err := s.sendSequentialLocked(task.buf, task.addr, task.segSize); err != nil { + return err + } + p.recordGSOSingle(task.segments) + return nil +} + +func (s *sendShard) recordGSOMetrics(task *sendTask) { + p := s.parent + if p.gsoBatches != nil { + p.gsoBatches.Inc(1) + } + if p.gsoSegments != nil { + p.gsoSegments.Inc(int64(task.segments)) + } + p.gsoBatchTick.Add(1) + p.gsoBatchSegmentsTick.Add(int64(task.segments)) + if p.l.IsLevelEnabled(logrus.DebugLevel) { + p.l.WithFields(logrus.Fields{ + "tag": "gso-debug", + "stage": "flush", + "segments": task.segments, + "segment_size": task.segSize, + "batch_bytes": len(task.buf), + "remote_addr": task.addr.String(), + }).Debug("gso batch sent") + } +} + +func (s *sendShard) write(b []byte, addr netip.AddrPort) error { + if len(b) == 0 { + return nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + p := s.parent + + if !p.enableGSO || !addr.IsValid() { + p.recordGSOSingle(1) + return p.directWrite(b, addr) + } + + s.ensurePendingBuf(p) + + if s.pendingSegments > 0 && s.pendingAddr != addr { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.ensurePendingBuf(p) + } + + if len(b) > p.gsoMaxBytes || p.gsoMaxSegments <= 1 { + if err := s.flushPendingLocked(); err != nil { + return err + } + p.recordGSOSingle(1) + return p.directWrite(b, addr) + } + + if s.pendingSegments == 0 { + s.pendingAddr = addr + s.pendingSegSize = len(b) + } else if 
len(b) != s.pendingSegSize { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.pendingAddr = addr + s.pendingSegSize = len(b) + s.ensurePendingBuf(p) + } + + if len(s.pendingBuf)+len(b) > p.gsoMaxBytes { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.pendingAddr = addr + s.pendingSegSize = len(b) + s.ensurePendingBuf(p) + } + + s.pendingBuf = append(s.pendingBuf, b...) + s.pendingSegments++ + + if s.pendingSegments >= p.gsoMaxSegments { + return s.flushPendingLocked() + } + + if p.gsoFlushTimeout <= 0 { + return s.flushPendingLocked() + } + + s.scheduleFlushLocked() + return nil +} + +func (s *sendShard) flushPendingLocked() error { + if s.pendingSegments == 0 { + s.stopFlushTimerLocked() + return nil + } + + buf := s.pendingBuf + task := &sendTask{ + buf: buf, + addr: s.pendingAddr, + segSize: s.pendingSegSize, + segments: s.pendingSegments, + owned: true, + } + + s.pendingBuf = nil + s.pendingSegments = 0 + s.pendingSegSize = 0 + s.pendingAddr = netip.AddrPort{} + + s.stopFlushTimerLocked() + + s.mu.Unlock() + err := s.submitTask(task) + s.mu.Lock() + return err +} + +func (s *sendShard) enqueueImmediate(payload []byte, addr netip.AddrPort) error { + if len(payload) == 0 { + return nil + } + if !addr.IsValid() { + return &net.OpError{Op: "sendmsg", Err: unix.EINVAL} + } + if s.parent != nil && s.parent.ioClosing.Load() { + return &net.OpError{Op: "sendmsg", Err: net.ErrClosed} + } + + buf := s.parent.borrowIOBuf(len(payload)) + buf = append(buf[:0], payload...) + + task := &sendTask{ + buf: buf, + addr: addr, + segSize: len(payload), + segments: 1, + owned: true, + } + if err := s.submitTask(task); err != nil { + return err + } + return nil +} + +func (s *sendShard) sendSegmentedIOUring(state *ioUringState, buf []byte, addr netip.AddrPort, segSize int) error { + if state == nil || len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + if len(s.controlBuf) < unix.CmsgSpace(2) { + s.controlBuf = make([]byte, unix.CmsgSpace(2)) + } + control := s.controlBuf[:unix.CmsgSpace(2)] + for i := range control { + control[i] = 0 + } + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + setCmsgLen(hdr, 2) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + dataOff := unix.CmsgLen(0) + binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize)) + + n, err := s.parent.sendMsgIOUring(state, addr, buf, control, 0) + if err != nil { + return err + } + if n != len(buf) { + return &net.OpError{Op: "sendmsg", Err: unix.EIO} + } + return nil +} + +func (s *sendShard) sendSequentialIOUring(state *ioUringState, buf []byte, addr netip.AddrPort, segSize int) error { + if state == nil || len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + if segSize >= len(buf) { + n, err := s.parent.sendMsgIOUring(state, addr, buf, nil, 0) + if err != nil { + return err + } + if n != len(buf) { + return &net.OpError{Op: "sendmsg", Err: unix.EIO} + } + return nil + } + + total := len(buf) + offset := 0 + for offset < total { + end := offset + segSize + if end > total { + end = total + } + segment := buf[offset:end] + n, err := s.parent.sendMsgIOUring(state, addr, segment, nil, 0) + if err != nil { + return err + } + if n != len(segment) { + return &net.OpError{Op: "sendmsg", Err: unix.EIO} + } + offset = end + } + return nil +} + +func (s *sendShard) sendSegmentedLocked(buf []byte, addr netip.AddrPort, segSize int) error { + if len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + + if 
len(s.controlBuf) < unix.CmsgSpace(2) { + s.controlBuf = make([]byte, unix.CmsgSpace(2)) + } + control := s.controlBuf[:unix.CmsgSpace(2)] + for i := range control { + control[i] = 0 + } + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + setCmsgLen(hdr, 2) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + + dataOff := unix.CmsgLen(0) + binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize)) + + var sa unix.Sockaddr + if s.parent.isV4 { + sa4 := &unix.SockaddrInet4{Port: int(addr.Port())} + sa4.Addr = addr.Addr().As4() + sa = sa4 + } else { + sa6 := &unix.SockaddrInet6{Port: int(addr.Port())} + sa6.Addr = addr.Addr().As16() + sa = sa6 + } + + for { + n, err := unix.SendmsgN(s.parent.sysFd, buf, control[:unix.CmsgSpace(2)], sa, 0) + if err != nil { + if err == unix.EINTR { + continue + } + return &net.OpError{Op: "sendmsg", Err: err} + } + if n != len(buf) { + return &net.OpError{Op: "sendmsg", Err: unix.EIO} + } + return nil + } +} + +func (s *sendShard) sendSequentialLocked(buf []byte, addr netip.AddrPort, segSize int) error { + if len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + if segSize >= len(buf) { + return s.parent.directWrite(buf, addr) + } + + var ( + namePtr *byte + nameLen uint32 + ) + if s.parent.isV4 { + var sa4 unix.RawSockaddrInet4 + sa4.Family = unix.AF_INET + sa4.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:], addr.Port()) + namePtr = (*byte)(unsafe.Pointer(&sa4)) + nameLen = uint32(unsafe.Sizeof(sa4)) + } else { + var sa6 unix.RawSockaddrInet6 + sa6.Family = unix.AF_INET6 + sa6.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:], addr.Port()) + namePtr = (*byte)(unsafe.Pointer(&sa6)) + nameLen = uint32(unsafe.Sizeof(sa6)) + } + + total := len(buf) + if total == 0 { + return nil + } + basePtr := uintptr(unsafe.Pointer(&buf[0])) + offset := 0 + + for offset < total { + remaining := total - offset + segments := (remaining + segSize - 1) / segSize + if segments > maxSendmmsgBatch { + segments = maxSendmmsgBatch + } + + s.ensureMmsgCapacity(segments) + msgs := s.mmsgHeaders[:segments] + iovecs := s.mmsgIovecs[:segments] + lens := s.mmsgLengths[:segments] + + batchStart := offset + segOffset := offset + actual := 0 + for actual < segments && segOffset < total { + segLen := segSize + if segLen > total-segOffset { + segLen = total - segOffset + } + + msgs[actual] = linuxMmsgHdr{} + lens[actual] = segLen + iovecs[actual].Base = &buf[segOffset] + setIovecLen(&iovecs[actual], segLen) + msgs[actual].Hdr.Iov = &iovecs[actual] + setMsghdrIovlen(&msgs[actual].Hdr, 1) + msgs[actual].Hdr.Name = namePtr + msgs[actual].Hdr.Namelen = nameLen + msgs[actual].Hdr.Control = nil + msgs[actual].Hdr.Controllen = 0 + msgs[actual].Hdr.Flags = 0 + msgs[actual].Len = 0 + + actual++ + segOffset += segLen + } + if actual == 0 { + break + } + msgs = msgs[:actual] + lens = lens[:actual] + + retry: + sent, err := sendmmsg(s.parent.sysFd, msgs, 0) + if err != nil { + if err == unix.EINTR { + goto retry + } + return &net.OpError{Op: "sendmmsg", Err: err} + } + if sent == 0 { + goto retry + } + + bytesSent := 0 + for i := 0; i < sent; i++ { + bytesSent += lens[i] + } + offset = batchStart + bytesSent + + if sent < len(msgs) { + for j := sent; j < len(msgs); j++ { + start := int(uintptr(unsafe.Pointer(iovecs[j].Base)) - basePtr) + if start < 0 || start >= total { + continue + } + end := start + lens[j] + if end > total { + end = total + } + if err := 
s.parent.directWrite(buf[start:end], addr); err != nil { + return err + } + if end > offset { + offset = end + } + } + } + } + + return nil +} + +func (s *sendShard) scheduleFlushLocked() { + timeout := s.parent.gsoFlushTimeout + if timeout <= 0 { + _ = s.flushPendingLocked() + return + } + if s.flushTimer == nil { + s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler) + return + } + if !s.flushTimer.Stop() { + // allow existing timer to drain + } + if !s.flushTimer.Reset(timeout) { + s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler) + } +} + +func (s *sendShard) stopFlushTimerLocked() { + if s.flushTimer != nil { + s.flushTimer.Stop() + } +} + +func (s *sendShard) flushTimerHandler() { + s.mu.Lock() + defer s.mu.Unlock() + if s.pendingSegments == 0 { + return + } + if err := s.flushPendingLocked(); err != nil { + if !isSocketCloseError(err) { + s.parent.l.WithError(err).Warn("Failed to flush GSO batch") + } + } } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -49,6 +1293,14 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to open socket: %s", err) } + if af == unix.AF_INET6 { + if err := unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 0); err != nil { + l.WithError(err).Warn("Failed to clear IPV6_V6ONLY on IPv6 UDP socket") + } else if v6only, err := unix.GetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY); err == nil { + l.WithField("v6only", v6only).Debug("Configured IPv6 UDP socket V6ONLY state") + } + } + if multi { if err = unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil { return nil, fmt.Errorf("unable to set SO_REUSEPORT: %s", err) @@ -69,7 +1321,35 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + if ip.Is4() && udpChecksumDisabled() { + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK, 1); err != nil { + l.WithError(err).Warn("Failed to disable IPv4 UDP checksum via SO_NO_CHECK") + } else { + l.Debug("Disabled IPv4 UDP checksum using SO_NO_CHECK") + } + } + + conn := &StdConn{ + sysFd: fd, + isV4: ip.Is4(), + l: l, + batch: batch, + gsoMaxSegments: defaultGSOMaxSegments, + gsoMaxBytes: defaultGSOMaxBytes, + gsoFlushTimeout: defaultGSOFlushTimeout, + gsoBatches: metrics.GetOrRegisterCounter("udp.gso.batches", nil), + gsoSegments: metrics.GetOrRegisterCounter("udp.gso.segments", nil), + gsoSingles: metrics.GetOrRegisterCounter("udp.gso.singles", nil), + groBatches: metrics.GetOrRegisterCounter("udp.gro.batches", nil), + groSegments: metrics.GetOrRegisterCounter("udp.gro.segments", nil), + gsoFallbacks: metrics.GetOrRegisterCounter("udp.gso.fallbacks", nil), + gsoFallbackReasons: make(map[string]*atomic.Int64), + } + conn.ioUringHoldoff.Store(int64(ioUringDefaultHoldoff)) + conn.ioUringMaxBatch.Store(int64(ioUringDefaultMaxBatch)) + conn.setGroBufferSize(defaultGROReadBufferSize) + conn.initSendShards() + return conn, err } func (u *StdConn) Rebind() error { @@ -121,77 +1401,361 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - msgs, buffers, names := u.PrepareRawMessages(u.batch) + // Check if io_uring receive ring is available + recvRing := u.ioRecvState.Load() + useIoUringRecv := recvRing != nil && u.ioRecvActive.Load() + + u.l.WithFields(logrus.Fields{ + "batch": u.batch, + "io_uring_send": 
u.ioState.Load() != nil, + "io_uring_recv": useIoUringRecv, + }).Info("ListenOut starting") + + if useIoUringRecv { + // Use dedicated io_uring receive ring + u.l.Info("ListenOut: using io_uring receive path") + + // Pre-fill the receive queue now that we're ready to receive + if err := recvRing.fillRecvQueue(); err != nil { + u.l.WithError(err).Error("Failed to fill receive queue") + return + } + + for { + // Receive packets from io_uring (wait=true blocks until at least one packet arrives) + packets, err := recvRing.receivePackets(true) + if err != nil { + u.l.WithError(err).Error("io_uring receive failed") + return + } + + if len(packets) > 0 && u.l.IsLevelEnabled(logrus.DebugLevel) { + totalBytes := 0 + groPackets := 0 + groSegments := 0 + for i := range packets { + totalBytes += packets[i].N + if packets[i].Controllen > 0 { + if _, segCount := u.parseGROSegmentFromControl(packets[i].Control, packets[i].Controllen); segCount > 1 { + groPackets++ + groSegments += segCount + } + } + } + fields := logrus.Fields{ + "entry_count": len(packets), + "payload_bytes": totalBytes, + } + if groPackets > 0 { + fields["gro_packets"] = groPackets + fields["gro_segments"] = groSegments + } + u.l.WithFields(fields).Debug("io_uring recv batch") + } + + for _, pkt := range packets { + // Extract address from RawSockaddrInet6 + if pkt.From.Family != unix.AF_INET6 { + u.l.WithField("family", pkt.From.Family).Warn("Received packet with unexpected address family") + continue + } + + ip, _ = netip.AddrFromSlice(pkt.From.Addr[:]) + addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&pkt.From.Port))[:])) + payload := pkt.Data[:pkt.N] + release := pkt.RecycleFunc + released := false + releaseOnce := func() { + if !released { + released = true + release() + } + } + + // Check for GRO segments + handled := false + if pkt.Controllen > 0 && len(pkt.Control) > 0 { + if segSize, segCount := u.parseGROSegmentFromControl(pkt.Control, pkt.Controllen); segSize > 0 && segSize < pkt.N { + if segCount > 1 && u.l.IsLevelEnabled(logrus.DebugLevel) { + u.l.WithFields(logrus.Fields{ + "segments": segCount, + "segment_size": segSize, + "batch_bytes": pkt.N, + "remote_addr": addr.String(), + }).Debug("gro batch received") + } + if u.emitSegments(r, addr, payload, segSize, segCount, releaseOnce) { + handled = true + } else if segCount > 1 { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "io_uring_recv", + "reason": "emit_failed", + "payload_len": pkt.N, + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug fallback to single packet") + } + } + } + + if !handled { + r(addr, payload, releaseOnce) + } + } + } + } + + // Fallback path: use standard recvmsg + u.l.Info("ListenOut: using standard recvmsg path") + msgs, buffers, names, controls := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle } + u.l.WithFields(logrus.Fields{ + "using_ReadSingle": u.batch == 1, + "using_ReadMulti": u.batch != 1, + }).Info("ListenOut read function selected") + for { - n, err := read(msgs) - if err != nil { - u.l.WithError(err).Debug("udp socket is closed, exiting read loop") - return + desiredGroSize := int(u.groBufSize.Load()) + if desiredGroSize < MTU { + desiredGroSize = MTU + } + if len(buffers) == 0 || cap(buffers[0]) < desiredGroSize { + u.recycleBufferSet(buffers) + msgs, buffers, names, controls = u.PrepareRawMessages(u.batch) + } + desiredControl := int(u.controlLen.Load()) + hasControl := len(controls) > 0 + if (desiredControl > 0) != 
hasControl || (desiredControl > 0 && hasControl && len(controls[0]) != desiredControl) { + u.recycleBufferSet(buffers) + msgs, buffers, names, controls = u.PrepareRawMessages(u.batch) + hasControl = len(controls) > 0 } + if hasControl { + for i := range msgs { + if len(controls) <= i || len(controls[i]) == 0 { + continue + } + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } + } + + u.l.Debug("ListenOut: about to call read(msgs)") + n, err := read(msgs) + if err != nil { + u.l.WithError(err).Error("ListenOut: read(msgs) failed, exiting read loop") + u.recycleBufferSet(buffers) + return + } + u.l.WithField("packets_read", n).Debug("ListenOut: read(msgs) returned") + for i := 0; i < n; i++ { + payloadLen := int(msgs[i].Len) + if payloadLen == 0 { + continue + } + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + buf := buffers[i] + payload := buf[:payloadLen] + released := false + release := func() { + if !released { + released = true + u.recycleBuffer(buf) + } + } + handled := false + + if len(controls) > i && len(controls[i]) > 0 { + if segSize, segCount := u.parseGROSegment(&msgs[i], controls[i]); segSize > 0 && segSize < payloadLen { + if segCount > 1 && u.l.IsLevelEnabled(logrus.DebugLevel) { + u.l.WithFields(logrus.Fields{ + "segments": segCount, + "segment_size": segSize, + "batch_bytes": payloadLen, + "remote_addr": addr.String(), + }).Debug("gro batch received") + } + if u.emitSegments(r, addr, payload, segSize, segCount, release) { + handled = true + } else if segCount > 1 { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "listen_out", + "reason": "emit_failed", + "payload_len": payloadLen, + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug fallback to single packet") + } + } + } + + if !handled { + r(addr, payload, release) + } + + buffers[i] = u.borrowRxBuffer(desiredGroSize) + setIovecBase(&msgs[i], buffers[i]) } } } +func isEAgain(err error) bool { + if err == nil { + return false + } + var opErr *net.OpError + if errors.As(err, &opErr) { + if errno, ok := opErr.Err.(syscall.Errno); ok { + return errno == unix.EAGAIN || errno == unix.EWOULDBLOCK + } + } + if errno, ok := err.(syscall.Errno); ok { + return errno == unix.EAGAIN || errno == unix.EWOULDBLOCK + } + return false +} func (u *StdConn) ReadSingle(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&(msgs[0].Hdr))), - 0, - 0, - 0, - 0, - ) - - if err != 0 { - return 0, &net.OpError{Op: "recvmsg", Err: err} - } - - msgs[0].Len = uint32(n) - return 1, nil + if len(msgs) == 0 { + return 0, nil } + + u.l.Debug("ReadSingle called") + + state := u.ioState.Load() + if state == nil { + u.l.Error("ReadSingle: io_uring not initialized") + return 0, &net.OpError{Op: "recvmsg", Err: errors.New("io_uring not initialized")} + } + + u.l.Debug("ReadSingle: converting rawMessage to unix.Msghdr") + hdr, iov, err := rawMessageToUnixMsghdr(&msgs[0]) + if err != nil { + u.l.WithError(err).Error("ReadSingle: rawMessageToUnixMsghdr failed") + return 0, &net.OpError{Op: "recvmsg", Err: err} + } + + u.l.WithFields(logrus.Fields{ + "bufLen": iov.Len, + "nameLen": hdr.Namelen, + 
"ctrlLen": hdr.Controllen, + }).Debug("ReadSingle: calling state.Recvmsg") + + n, _, recvErr := state.Recvmsg(u.sysFd, &hdr, 0) + if recvErr != nil { + u.l.WithError(recvErr).Error("ReadSingle: state.Recvmsg failed") + return 0, recvErr + } + + u.l.WithFields(logrus.Fields{ + "bytesRead": n, + }).Debug("ReadSingle: successfully received") + + updateRawMessageFromUnixMsghdr(&msgs[0], &hdr, n) + runtime.KeepAlive(iov) + runtime.KeepAlive(hdr) + return 1, nil } func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { - for { - n, _, err := unix.Syscall6( - unix.SYS_RECVMMSG, - uintptr(u.sysFd), - uintptr(unsafe.Pointer(&msgs[0])), - uintptr(len(msgs)), - unix.MSG_WAITFORONE, - 0, - 0, - ) + if len(msgs) == 0 { + return 0, nil + } - if err != 0 { - return 0, &net.OpError{Op: "recvmmsg", Err: err} + u.l.WithField("batch_size", len(msgs)).Debug("ReadMulti called") + + state := u.ioState.Load() + if state == nil { + u.l.Error("ReadMulti: io_uring not initialized") + return 0, &net.OpError{Op: "recvmsg", Err: errors.New("io_uring not initialized")} + } + + count := 0 + for i := range msgs { + hdr, iov, err := rawMessageToUnixMsghdr(&msgs[i]) + if err != nil { + u.l.WithError(err).WithField("index", i).Error("ReadMulti: rawMessageToUnixMsghdr failed") + if count > 0 { + return count, nil + } + return 0, &net.OpError{Op: "recvmsg", Err: err} } - return int(n), nil + flags := uint32(0) + if i > 0 { + flags = unix.MSG_DONTWAIT + } + + u.l.WithFields(logrus.Fields{ + "index": i, + "flags": flags, + "bufLen": iov.Len, + }).Debug("ReadMulti: calling state.Recvmsg") + + n, _, recvErr := state.Recvmsg(u.sysFd, &hdr, flags) + if recvErr != nil { + u.l.WithError(recvErr).WithFields(logrus.Fields{ + "index": i, + "count": count, + }).Debug("ReadMulti: state.Recvmsg error") + if isEAgain(recvErr) && count > 0 { + u.l.WithField("count", count).Debug("ReadMulti: EAGAIN with existing packets, returning") + return count, nil + } + if count > 0 { + return count, recvErr + } + return 0, recvErr + } + + u.l.WithFields(logrus.Fields{ + "index": i, + "bytesRead": n, + }).Debug("ReadMulti: packet received") + + updateRawMessageFromUnixMsghdr(&msgs[i], &hdr, n) + runtime.KeepAlive(iov) + runtime.KeepAlive(hdr) + count++ } + + u.l.WithField("total_count", count).Debug("ReadMulti: completed") + return count, nil } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { - if u.isV4 { - return u.writeTo4(b, ip) + if len(b) == 0 { + return nil } - return u.writeTo6(b, ip) + if u.ioClosing.Load() { + return &net.OpError{Op: "sendmsg", Err: net.ErrClosed} + } + if u.enableGSO { + return u.writeToGSO(b, ip) + } + if u.ioState.Load() != nil { + if shard := u.selectSendShard(ip); shard != nil { + if err := shard.enqueueImmediate(b, ip); err != nil { + return err + } + return nil + } + } + u.recordGSOSingle(1) + return u.directWrite(b, ip) } func (u *StdConn) writeTo6(b []byte, ip netip.AddrPort) error { @@ -248,6 +1812,900 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeToGSO(b []byte, addr netip.AddrPort) error { + if len(b) == 0 { + return nil + } + shard := u.selectSendShard(addr) + if shard == nil { + u.recordGSOSingle(1) + return u.directWrite(b, addr) + } + return shard.write(b, addr) +} + +func (u *StdConn) sendMsgIOUring(state *ioUringState, addr netip.AddrPort, payload []byte, control []byte, msgFlags uint32) (int, error) { + if state == nil { + return 0, &net.OpError{Op: "sendmsg", Err: syscall.EINVAL} + } + if len(payload) == 0 { + return 0, nil + } + if 
!addr.IsValid() { + return 0, &net.OpError{Op: "sendmsg", Err: unix.EINVAL} + } + if !u.ioAttempted.Load() { + u.ioAttempted.Store(true) + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "len": len(payload), + "ctrl": control != nil, + }).Debug("io_uring send attempt") + } + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "len": len(payload), + "ctrl": control != nil, + }).Debug("io_uring sendMsgIOUring invoked") + + var iov unix.Iovec + iov.Base = &payload[0] + setIovecLen(&iov, len(payload)) + + var msg unix.Msghdr + msg.Iov = &iov + setMsghdrIovlen(&msg, 1) + + if len(control) > 0 { + msg.Control = &control[0] + msg.Controllen = controllen(len(control)) + } + + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "payload_len": len(payload), + "ctrl_len": len(control), + "msg_iovlen": msg.Iovlen, + "msg_controllen": msg.Controllen, + }).Debug("io_uring prepared msghdr") + + var ( + n int + err error + ) + + if u.isV4 { + if !addr.Addr().Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + var sa4 unix.RawSockaddrInet4 + sa4.Family = unix.AF_INET + sa4.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:], addr.Port()) + msg.Name = (*byte)(unsafe.Pointer(&sa4)) + msg.Namelen = uint32(unsafe.Sizeof(sa4)) + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "sa_family": sa4.Family, + "sa_port": sa4.Port, + "msg_namelen": msg.Namelen, + }).Debug("io_uring sendmsg sockaddr v4") + n, err = state.Sendmsg(u.sysFd, &msg, msgFlags, uint32(len(payload))) + runtime.KeepAlive(sa4) + } else { + // For IPv6 sockets, always use RawSockaddrInet6, even for IPv4 addresses + // (convert IPv4 to IPv4-mapped IPv6 format) + var sa6 unix.RawSockaddrInet6 + u.populateSockaddrInet6(&sa6, addr.Addr()) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:], addr.Port()) + msg.Name = (*byte)(unsafe.Pointer(&sa6)) + msg.Namelen = uint32(unsafe.Sizeof(sa6)) + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "sa_family": sa6.Family, + "sa_port": sa6.Port, + "scope_id": sa6.Scope_id, + "msg_namelen": msg.Namelen, + "is_v4": addr.Addr().Is4(), + }).Debug("io_uring sendmsg sockaddr v6") + n, err = state.Sendmsg(u.sysFd, &msg, msgFlags, uint32(len(payload))) + runtime.KeepAlive(sa6) + } + + if err == nil && n == len(payload) { + u.noteIoUringSuccess() + } + runtime.KeepAlive(payload) + runtime.KeepAlive(control) + u.logIoUringResult(addr, len(payload), n, err) + if err == nil && n == 0 && len(payload) > 0 { + syncWritten, syncErr := u.sendMsgSync(addr, payload, control, int(msgFlags)) + if syncErr == nil && syncWritten == len(payload) { + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "expected": len(payload), + "sync_written": syncWritten, + }).Warn("io_uring returned short write; used synchronous sendmsg fallback") + u.noteIoUringSuccess() + u.logIoUringResult(addr, len(payload), syncWritten, syncErr) + return syncWritten, nil + } + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "expected": len(payload), + "sync_written": syncWritten, + "sync_err": syncErr, + }).Warn("sync sendmsg result after io_uring short write") + } + return n, err +} + +func (u *StdConn) sendMsgIOUringBatch(state *ioUringState, items []*batchSendItem) error { + if u.ioClosing.Load() { + for _, item := range items { + if item != nil { + item.err = &net.OpError{Op: "sendmsg", Err: net.ErrClosed} + } + } + return &net.OpError{Op: "sendmsg", Err: net.ErrClosed} + } + if state == nil { + return &net.OpError{Op: "sendmsg", Err: 
syscall.EINVAL} + } + if len(items) == 0 { + return nil + } + + results := make([]ioUringBatchResult, len(items)) + payloads := make([][]byte, len(items)) + controls := make([][]byte, len(items)) + entries := make([]ioUringBatchEntry, len(items)) + msgs := make([]unix.Msghdr, len(items)) + iovecs := make([]unix.Iovec, len(items)) + var sa4 []unix.RawSockaddrInet4 + var sa6 []unix.RawSockaddrInet6 + if u.isV4 { + sa4 = make([]unix.RawSockaddrInet4, len(items)) + } else { + sa6 = make([]unix.RawSockaddrInet6, len(items)) + } + + entryIdx := 0 + totalPayload := 0 + skipped := 0 + for i, item := range items { + if item == nil || len(item.payload) == 0 { + item.resultBytes = 0 + item.err = nil + skipped++ + continue + } + + addr := item.addr + if !addr.IsValid() { + item.err = &net.OpError{Op: "sendmsg", Err: unix.EINVAL} + skipped++ + continue + } + if u.isV4 && !addr.Addr().Is4() { + item.err = ErrInvalidIPv6RemoteForSocket + skipped++ + continue + } + + payload := item.payload + payloads[i] = payload + totalPayload += len(payload) + + iov := &iovecs[entryIdx] + iov.Base = &payload[0] + setIovecLen(iov, len(payload)) + + msg := &msgs[entryIdx] + msg.Iov = iov + setMsghdrIovlen(msg, 1) + + if len(item.control) > 0 { + controls[i] = item.control + msg.Control = &item.control[0] + msg.Controllen = controllen(len(item.control)) + } + + if u.isV4 { + sa := &sa4[entryIdx] + sa.Family = unix.AF_INET + sa.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + msg.Name = (*byte)(unsafe.Pointer(sa)) + msg.Namelen = uint32(unsafe.Sizeof(*sa)) + } else { + sa := &sa6[entryIdx] + sa.Family = unix.AF_INET6 + u.populateSockaddrInet6(sa, addr.Addr()) + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa.Port))[:], addr.Port()) + msg.Name = (*byte)(unsafe.Pointer(sa)) + msg.Namelen = uint32(unsafe.Sizeof(*sa)) + } + + entries[entryIdx] = ioUringBatchEntry{ + fd: u.sysFd, + msg: msg, + msgFlags: item.msgFlags, + payloadLen: uint32(len(payload)), + result: &results[i], + } + entryIdx++ + } + + if entryIdx == 0 { + for _, payload := range payloads { + runtime.KeepAlive(payload) + } + for _, control := range controls { + runtime.KeepAlive(control) + } + var firstErr error + for _, item := range items { + if item != nil && item.err != nil { + firstErr = item.err + break + } + } + return firstErr + } + + if err := ioUringSendmsgBatch(state, entries[:entryIdx]); err != nil { + for _, payload := range payloads { + runtime.KeepAlive(payload) + } + for _, control := range controls { + runtime.KeepAlive(control) + } + if len(sa4) > 0 { + runtime.KeepAlive(sa4[:entryIdx]) + } + if len(sa6) > 0 { + runtime.KeepAlive(sa6[:entryIdx]) + } + return err + } + + if u.l.IsLevelEnabled(logrus.DebugLevel) { + u.l.WithFields(logrus.Fields{ + "entry_count": entryIdx, + "skipped_items": skipped, + "payload_bytes": totalPayload, + }).Debug("io_uring batch submitted") + } + + var firstErr error + for i, item := range items { + if item == nil || len(item.payload) == 0 { + continue + } + if item.err != nil { + if firstErr == nil { + firstErr = item.err + } + continue + } + + res := results[i] + if res.err != nil { + item.err = res.err + } else if res.res < 0 { + item.err = syscall.Errno(-res.res) + } else if int(res.res) != len(item.payload) { + item.err = fmt.Errorf("io_uring short write: wrote %d expected %d", res.res, len(item.payload)) + } else { + item.err = nil + item.resultBytes = int(res.res) + } + + u.logIoUringResult(item.addr, len(item.payload), int(res.res), 
item.err) + if item.err != nil && firstErr == nil { + firstErr = item.err + } + } + + for _, payload := range payloads { + runtime.KeepAlive(payload) + } + for _, control := range controls { + runtime.KeepAlive(control) + } + if len(sa4) > 0 { + runtime.KeepAlive(sa4[:entryIdx]) + } + if len(sa6) > 0 { + runtime.KeepAlive(sa6[:entryIdx]) + } + + if firstErr == nil { + u.noteIoUringSuccess() + } + + return firstErr +} + +func (u *StdConn) sendMsgSync(addr netip.AddrPort, payload []byte, control []byte, msgFlags int) (int, error) { + if len(payload) == 0 { + return 0, nil + } + if u.isV4 { + if !addr.Addr().Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + sa := &unix.SockaddrInet4{Port: int(addr.Port())} + sa.Addr = addr.Addr().As4() + return unix.SendmsgN(u.sysFd, payload, control, sa, msgFlags) + } + sa := &unix.SockaddrInet6{Port: int(addr.Port())} + if addr.Addr().Is4() { + sa.Addr = toIPv4Mapped(addr.Addr().As4()) + } else { + sa.Addr = addr.Addr().As16() + } + if zone := addr.Addr().Zone(); zone != "" { + if iface, err := net.InterfaceByName(zone); err == nil { + sa.ZoneId = uint32(iface.Index) + } else { + u.l.WithFields(logrus.Fields{ + "addr": addr.Addr().String(), + "zone": zone, + }).WithError(err).Debug("io_uring failed to resolve IPv6 zone") + } + } + return unix.SendmsgN(u.sysFd, payload, control, sa, msgFlags) +} + +func (u *StdConn) directWrite(b []byte, addr netip.AddrPort) error { + if len(b) == 0 { + return nil + } + if !addr.IsValid() { + return &net.OpError{Op: "sendmsg", Err: unix.EINVAL} + } + state := u.ioState.Load() + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "len": len(b), + "state_nil": state == nil, + "socket_v4": u.isV4, + "remote_is_v4": addr.Addr().Is4(), + "remote_is_v6": addr.Addr().Is6(), + }).Debug("io_uring directWrite invoked") + if state == nil { + return errors.New("io_uring state unavailable") + } + n, err := u.sendMsgIOUring(state, addr, b, nil, 0) + if err != nil { + return err + } + if n != len(b) { + return fmt.Errorf("io_uring short write: wrote %d expected %d", n, len(b)) + } + return nil +} + +func (u *StdConn) noteIoUringSuccess() { + if u == nil { + return + } + if u.ioActive.Load() { + return + } + if u.ioActive.CompareAndSwap(false, true) { + u.l.Debug("io_uring send path active") + } +} + +func (u *StdConn) logIoUringResult(addr netip.AddrPort, expected, written int, err error) { + if u == nil { + return + } + u.l.WithFields(logrus.Fields{ + "addr": addr.String(), + "expected": expected, + "written": written, + "err": err, + "socket_v4": u.isV4, + "remote_is_v4": addr.Addr().Is4(), + "remote_is_v6": addr.Addr().Is6(), + }).Debug("io_uring send result") +} + +func (u *StdConn) emitSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize, segCount int, release func()) bool { + if segSize <= 0 || segSize >= len(payload) { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "reason": "invalid_seg_size", + "payload_len": len(payload), + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug skip emit") + return false + } + + totalLen := len(payload) + if segCount <= 0 { + segCount = (totalLen + segSize - 1) / segSize + } + if segCount <= 1 { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "reason": "single_segment", + "payload_len": totalLen, + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug skip emit") + return false + } + + defer func() { + if release != nil { + release() + } + }() + + actualSegments := 0 + start := 0 + 
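+ // Copy each segment out of the shared GRO payload into a pooled buffer so its lifetime is independent of the batch buffer; the per-segment release callbacks return these copies to groSegmentPool, while the original receive buffer is freed by the deferred release() above.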
debugEnabled := u.l.IsLevelEnabled(logrus.DebugLevel) + var firstHeader header.H + var firstParsed bool + var firstCounter uint64 + var firstRemote uint32 + + for start < totalLen && actualSegments < segCount { + end := start + segSize + if end > totalLen { + end = totalLen + } + + segLen := end - start + bufAny := u.groSegmentPool.Get() + var segBuf []byte + if bufAny == nil { + segBuf = make([]byte, segLen) + } else { + segBuf = bufAny.([]byte) + if cap(segBuf) < segLen { + segBuf = make([]byte, segLen) + } + } + segment := segBuf[:segLen] + copy(segment, payload[start:end]) + + if debugEnabled && !firstParsed { + if err := firstHeader.Parse(segment); err == nil { + firstParsed = true + firstCounter = firstHeader.MessageCounter + firstRemote = firstHeader.RemoteIndex + } else { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "parse_fail", + "seg_index": actualSegments, + "seg_size": segSize, + "seg_count": segCount, + "payload_len": totalLen, + "err": err, + }).Debug("gro-debug segment parse failed") + } + } + + start = end + actualSegments++ + r(addr, segment, func() { + u.groSegmentPool.Put(segBuf[:cap(segBuf)]) + }) + + if debugEnabled && actualSegments == segCount && segLen < segSize { + var tail header.H + if err := tail.Parse(segment); err == nil { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "tail_segment", + "segment_len": segLen, + "remote_index": tail.RemoteIndex, + "message_counter": tail.MessageCounter, + }).Debug("gro-debug tail segment metadata") + } + } + + } + + if u.groBatches != nil { + u.groBatches.Inc(1) + } + if u.groSegments != nil { + u.groSegments.Inc(int64(actualSegments)) + } + u.groBatchTick.Add(1) + u.groSegmentsTick.Add(int64(actualSegments)) + + if debugEnabled && actualSegments > 0 { + lastLen := segSize + if tail := totalLen % segSize; tail != 0 { + lastLen = tail + } + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "success", + "payload_len": totalLen, + "seg_size": segSize, + "seg_count": segCount, + "actual_segs": actualSegments, + "last_seg_len": lastLen, + "addr": addr.String(), + "first_remote": firstRemote, + "first_counter": firstCounter, + }).Debug("gro-debug emit") + } + + return true +} + +func (u *StdConn) parseGROSegment(msg *rawMessage, control []byte) (int, int) { + ctrlLen := int(msg.Hdr.Controllen) + if ctrlLen <= 0 { + return 0, 0 + } + if ctrlLen > len(control) { + ctrlLen = len(control) + } + return u.parseGROSegmentFromControl(control, ctrlLen) +} + +func (u *StdConn) parseGROSegmentFromControl(control []byte, ctrlLen int) (int, int) { + if ctrlLen <= 0 { + return 0, 0 + } + if ctrlLen > len(control) { + ctrlLen = len(control) + } + + cmsgs, err := unix.ParseSocketControlMessage(control[:ctrlLen]) + if err != nil { + u.l.WithError(err).Debug("failed to parse UDP GRO control message") + return 0, 0 + } + + for _, c := range cmsgs { + if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { + segSize := int(binary.NativeEndian.Uint16(c.Data[:2])) + segCount := 0 + if len(c.Data) >= 4 { + segCount = int(binary.NativeEndian.Uint16(c.Data[2:4])) + } + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "parse", + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug control parsed") + return segSize, segCount + } + } + + return 0, 0 +} + +func (u *StdConn) configureIOUring(enable bool, c *config.C) { + if enable { + if u.ioState.Load() != nil { + return + } + + // Serialize 
io_uring initialization globally to avoid kernel resource races + ioUringInitMu.Lock() + defer ioUringInitMu.Unlock() + + var configured uint32 + requestedBatch := ioUringDefaultMaxBatch + if c != nil { + entries := c.GetInt("listen.io_uring_entries", 0) + if entries < 0 { + entries = 0 + } + configured = uint32(entries) + holdoff := c.GetDuration("listen.io_uring_batch_holdoff", -1) + if holdoff < 0 { + holdoffVal := c.GetInt("listen.io_uring_batch_holdoff", int(ioUringDefaultHoldoff/time.Microsecond)) + holdoff = time.Duration(holdoffVal) * time.Microsecond + } + if holdoff < ioUringMinHoldoff { + holdoff = ioUringMinHoldoff + } + if holdoff > ioUringMaxHoldoff { + holdoff = ioUringMaxHoldoff + } + u.ioUringHoldoff.Store(int64(holdoff)) + requestedBatch = clampIoUringBatchSize(c.GetInt("listen.io_uring_max_batch", ioUringDefaultMaxBatch), 0) + } else { + u.ioUringHoldoff.Store(int64(ioUringDefaultHoldoff)) + requestedBatch = ioUringDefaultMaxBatch + } + if !u.enableGSO { + if len(u.sendShards) != 1 { + u.resizeSendShards(1) + } + } + u.ioUringMaxBatch.Store(int64(requestedBatch)) + ring, err := newIoUringState(configured) + if err != nil { + u.l.WithError(err).Warn("Failed to enable io_uring; falling back to sendmmsg path") + return + } + u.ioState.Store(ring) + finalBatch := clampIoUringBatchSize(requestedBatch, ring.sqEntryCount) + u.ioUringMaxBatch.Store(int64(finalBatch)) + fields := logrus.Fields{ + "entries": ring.sqEntryCount, + "max_batch": finalBatch, + } + if finalBatch != requestedBatch { + fields["requested_batch"] = requestedBatch + } + u.l.WithFields(fields).Debug("io_uring ioState pointer initialized") + desired := configured + if desired == 0 { + desired = defaultIoUringEntries + } + if ring.sqEntryCount < desired { + fields["requested_entries"] = desired + u.l.WithFields(fields).Warn("UDP io_uring send path enabled with reduced queue depth (ENOMEM)") + } else { + u.l.WithFields(fields).Debug("UDP io_uring send path enabled") + } + + // Initialize dedicated receive ring with retry logic + recvPoolSize := 128 // Number of receive operations to keep queued + recvBufferSize := defaultGROReadBufferSize + if recvBufferSize < MTU { + recvBufferSize = MTU + } + + var recvRing *ioUringRecvState + maxRetries := 10 + retryDelay := 10 * time.Millisecond + + for attempt := 0; attempt < maxRetries; attempt++ { + var err error + recvRing, err = newIoUringRecvState(u.sysFd, configured, recvPoolSize, recvBufferSize) + if err == nil { + break + } + + if attempt < maxRetries-1 { + u.l.WithFields(logrus.Fields{ + "attempt": attempt + 1, + "error": err, + "delay": retryDelay, + }).Warn("Failed to create io_uring receive ring, retrying") + time.Sleep(retryDelay) + retryDelay *= 2 // Exponential backoff + } else { + u.l.WithError(err).Error("Failed to create io_uring receive ring after retries; will use standard recvmsg") + } + } + + if recvRing != nil { + u.ioRecvState.Store(recvRing) + u.ioRecvActive.Store(true) + u.l.WithFields(logrus.Fields{ + "entries": recvRing.sqEntryCount, + "poolSize": recvPoolSize, + "bufferSize": recvBufferSize, + }).Info("UDP io_uring receive path enabled") + // Note: receive queue will be filled on first receivePackets() call + } + + return + } + + if c != nil { + if u.ioState.Load() != nil { + u.l.Warn("Runtime disabling of io_uring is not supported; keeping existing ring active until shutdown") + } + holdoff := c.GetDuration("listen.io_uring_batch_holdoff", -1) + if holdoff < 0 { + holdoffVal := c.GetInt("listen.io_uring_batch_holdoff", 
int(ioUringDefaultHoldoff/time.Microsecond)) + holdoff = time.Duration(holdoffVal) * time.Microsecond + } + if holdoff < ioUringMinHoldoff { + holdoff = ioUringMinHoldoff + } + if holdoff > ioUringMaxHoldoff { + holdoff = ioUringMaxHoldoff + } + u.ioUringHoldoff.Store(int64(holdoff)) + requestedBatch := clampIoUringBatchSize(c.GetInt("listen.io_uring_max_batch", ioUringDefaultMaxBatch), 0) + if ring := u.ioState.Load(); ring != nil { + requestedBatch = clampIoUringBatchSize(requestedBatch, ring.sqEntryCount) + } + u.ioUringMaxBatch.Store(int64(requestedBatch)) + if !u.enableGSO { + // io_uring uses a single shared ring with a global mutex, + // so multiple shards cause severe lock contention. + // Force 1 shard for optimal io_uring batching performance. + if ring := u.ioState.Load(); ring != nil { + if len(u.sendShards) != 1 { + u.resizeSendShards(1) + } + } else { + // No io_uring, allow config override + shards := c.GetInt("listen.send_shards", 0) + if shards <= 0 { + shards = 1 + } + if len(u.sendShards) != shards { + u.resizeSendShards(shards) + } + } + } + } +} + +func (u *StdConn) disableIOUring(reason error) { + if ring := u.ioState.Swap(nil); ring != nil { + if err := ring.Close(); err != nil { + u.l.WithError(err).Warn("Failed to close io_uring state during disable") + } + if reason != nil { + u.l.WithError(reason).Warn("Disabling io_uring send/receive path; falling back to sendmmsg/recvmmsg") + } else { + u.l.Warn("Disabling io_uring send/receive path; falling back to sendmmsg/recvmmsg") + } + } +} + +func (u *StdConn) configureGRO(enable bool) { + if enable == u.enableGRO { + if enable { + u.controlLen.Store(int32(unix.CmsgSpace(2))) + } else { + u.controlLen.Store(0) + } + return + } + + if enable { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil { + u.l.WithError(err).Warn("Failed to enable UDP GRO") + u.enableGRO = false + u.controlLen.Store(0) + return + } + u.enableGRO = true + u.controlLen.Store(int32(unix.CmsgSpace(2))) + u.l.Info("UDP GRO enabled") + } else { + if u.enableGRO { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil { + u.l.WithError(err).Warn("Failed to disable UDP GRO") + } + } + u.enableGRO = false + u.controlLen.Store(0) + } +} + +func (u *StdConn) configureGSO(enable bool, c *config.C) { + if len(u.sendShards) == 0 { + u.initSendShards() + } + desiredShards := 0 + if c != nil { + desiredShards = c.GetInt("listen.send_shards", 0) + } + + // io_uring requires 1 shard due to shared ring mutex contention + if u.ioState.Load() != nil { + if desiredShards > 1 { + u.l.WithField("requested_shards", desiredShards).Warn("listen.send_shards ignored because io_uring is enabled; forcing 1 send shard") + } + desiredShards = 1 + } else if !enable { + if c != nil && desiredShards > 1 { + u.l.WithField("requested_shards", desiredShards).Warn("listen.send_shards ignored because UDP GSO is disabled; forcing 1 send shard") + } + desiredShards = 1 + } + + // Only resize if actually changing shard count + if len(u.sendShards) != desiredShards { + u.resizeSendShards(desiredShards) + } + + if !enable { + if u.enableGSO { + for _, shard := range u.sendShards { + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil { + u.l.WithError(err).Warn("Failed to flush GSO buffers while disabling") + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + if buf != nil { + u.releaseGSOBuf(buf) + } + } + 
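+ // Pending batches were flushed and their buffers returned above, so the GSO flag can be dropped and the receive buffers shrunk back to the default GRO size.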
u.enableGSO = false + u.l.Info("UDP GSO disabled") + } + u.setGroBufferSize(defaultGROReadBufferSize) + return + } + + maxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments) + if maxSegments < 2 { + maxSegments = 2 + } + + maxBytes := c.GetInt("listen.gso_max_bytes", 0) + if maxBytes <= 0 { + maxBytes = defaultGSOMaxBytes + } + if maxBytes < MTU { + maxBytes = MTU + } + if maxBytes > linuxMaxGSOBatchBytes { + u.l.WithFields(logrus.Fields{ + "configured_bytes": maxBytes, + "clamped_bytes": linuxMaxGSOBatchBytes, + }).Warn("listen.gso_max_bytes exceeds Linux UDP limit; clamping") + maxBytes = linuxMaxGSOBatchBytes + } + + flushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout) + if flushTimeout < 0 { + flushTimeout = 0 + } + + u.enableGSO = true + u.gsoMaxSegments = maxSegments + u.gsoMaxBytes = maxBytes + u.gsoFlushTimeout = flushTimeout + bufSize := defaultGROReadBufferSize + if u.gsoMaxBytes > bufSize { + bufSize = u.gsoMaxBytes + } + u.setGroBufferSize(bufSize) + + for _, shard := range u.sendShards { + shard.mu.Lock() + if shard.pendingBuf != nil { + u.releaseGSOBuf(shard.pendingBuf) + shard.pendingBuf = nil + } + shard.pendingSegments = 0 + shard.pendingSegSize = 0 + shard.pendingAddr = netip.AddrPort{} + shard.stopFlushTimerLocked() + if len(shard.controlBuf) < unix.CmsgSpace(2) { + shard.controlBuf = make([]byte, unix.CmsgSpace(2)) + } + shard.mu.Unlock() + } + + u.l.WithFields(logrus.Fields{ + "segments": u.gsoMaxSegments, + "bytes": u.gsoMaxBytes, + "flush_timeout": u.gsoFlushTimeout, + }).Info("UDP GSO configured") +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { @@ -294,6 +2752,10 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.so_mark") } } + + u.configureIOUring(c.GetBool("listen.use_io_uring", false), c) + u.configureGRO(c.GetBool("listen.enable_gro", false)) + u.configureGSO(c.GetBool("listen.enable_gso", false), c) } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { @@ -306,7 +2768,54 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + if !u.ioClosing.CompareAndSwap(false, true) { + return nil + } + // Attempt to unblock any outstanding sendmsg/sendmmsg calls so the shard + // workers can drain promptly during shutdown. Ignoring errors here is fine + // because some platforms/kernels may not support shutdown on UDP sockets. 
+ if err := unix.Shutdown(u.sysFd, unix.SHUT_RDWR); err != nil && err != unix.ENOTCONN && err != unix.EINVAL && err != unix.EBADF { + u.l.WithError(err).Debug("Failed to shutdown UDP socket for close") + } + + var flushErr error + for _, shard := range u.sendShards { + if shard == nil { + continue + } + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil && flushErr == nil { + flushErr = err + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + if buf != nil { + u.releaseGSOBuf(buf) + } + shard.stopSender() + } + + closeErr := syscall.Close(u.sysFd) + if ring := u.ioState.Swap(nil); ring != nil { + if err := ring.Close(); err != nil && flushErr == nil { + flushErr = err + } + } + if recvRing := u.ioRecvState.Swap(nil); recvRing != nil { + u.ioRecvActive.Store(false) + if err := recvRing.Close(); err != nil && flushErr == nil { + flushErr = err + } + } + if flushErr != nil { + return flushErr + } + return closeErr } func NewUDPStatsEmitter(udpConns []Conn) func() { @@ -330,6 +2839,13 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { } } + var stdConns []*StdConn + for _, conn := range udpConns { + if sc, ok := conn.(*StdConn); ok { + stdConns = append(stdConns, sc) + } + } + return func() { for i, gauges := range udpGauges { if err := udpConns[i].(*StdConn).getMemInfo(&meminfo); err == nil { @@ -338,5 +2854,9 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { } } } + + for _, sc := range stdConns { + sc.logGSOTick() + } } } diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cd..3a2b80e 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -7,6 +7,9 @@ package udp import ( + "errors" + "fmt" + "golang.org/x/sys/unix" ) @@ -30,17 +33,29 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) { + controlLen := int(u.controlLen.Load()) + msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var controls [][]byte + if controlLen > 0 { + controls = make([][]byte, n) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + size := int(u.groBufSize.Load()) + if size < MTU { + size = MTU + } + buf := u.borrowRxBuffer(size) + buffers[i] = buf names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, + {Base: &buf[0], Len: uint32(len(buf))}, } msgs[i].Hdr.Iov = &vs[0] @@ -48,7 +63,71 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if controlLen > 0 { + controls[i] = make([]byte, controlLen) + msgs[i].Hdr.Control = &controls[i][0] + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = controllen(0) + } } - return msgs, buffers, names + return msgs, buffers, names, controls +} + +func setIovecBase(msg *rawMessage, buf []byte) { + iov := (*iovec)(msg.Hdr.Iov) + iov.Base = &buf[0] + iov.Len = uint32(len(buf)) +} + +func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) { + var hdr unix.Msghdr + var iov unix.Iovec + if msg == nil { + return hdr, iov, errors.New("nil rawMessage") + } + if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil { + return hdr, iov, errors.New("rawMessage missing payload buffer") + } + 
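+ // Rebuild the unix.Msghdr from scratch for each receive: Namelen, Controllen, and Flags are reset below so values the kernel wrote on a previous recvmsg cannot leak into (and truncate) the next call.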
payloadLen := int(msg.Hdr.Iov.Len) + if payloadLen < 0 { + return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen) + } + iov.Base = msg.Hdr.Iov.Base + iov.Len = uint32(payloadLen) + hdr.Iov = &iov + hdr.Iovlen = 1 + hdr.Name = msg.Hdr.Name + // CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time + if hdr.Name != nil { + hdr.Namelen = uint32(unix.SizeofSockaddrInet6) + } else { + hdr.Namelen = 0 + } + hdr.Control = msg.Hdr.Control + // CRITICAL: Use the allocated size, not what was previously returned + if hdr.Control != nil { + // Control buffer size is stored in Controllen from PrepareRawMessages + hdr.Controllen = msg.Hdr.Controllen + } else { + hdr.Controllen = 0 + } + hdr.Flags = 0 // Reset flags for new receive + return hdr, iov, nil +} + +func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) { + if msg == nil || hdr == nil { + return + } + msg.Hdr.Namelen = hdr.Namelen + msg.Hdr.Controllen = hdr.Controllen + msg.Hdr.Flags = hdr.Flags + if n < 0 { + n = 0 + } + msg.Len = uint32(n) } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..ae26eb4 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -7,6 +7,9 @@ package udp import ( + "errors" + "fmt" + "golang.org/x/sys/unix" ) @@ -33,25 +36,99 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) { + controlLen := int(u.controlLen.Load()) + msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var controls [][]byte + if controlLen > 0 { + controls = make([][]byte, n) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + size := int(u.groBufSize.Load()) + if size < MTU { + size = MTU + } + buf := u.borrowRxBuffer(size) + buffers[i] = buf names[i] = make([]byte, unix.SizeofSockaddrInet6) - vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, - } + vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}} msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint64(len(vs)) msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if controlLen > 0 { + controls[i] = make([]byte, controlLen) + msgs[i].Hdr.Control = &controls[i][0] + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = controllen(0) + } } - return msgs, buffers, names + return msgs, buffers, names, controls +} + +func setIovecBase(msg *rawMessage, buf []byte) { + iov := (*iovec)(msg.Hdr.Iov) + iov.Base = &buf[0] + iov.Len = uint64(len(buf)) +} + +func rawMessageToUnixMsghdr(msg *rawMessage) (unix.Msghdr, unix.Iovec, error) { + var hdr unix.Msghdr + var iov unix.Iovec + if msg == nil { + return hdr, iov, errors.New("nil rawMessage") + } + if msg.Hdr.Iov == nil || msg.Hdr.Iov.Base == nil { + return hdr, iov, errors.New("rawMessage missing payload buffer") + } + payloadLen := int(msg.Hdr.Iov.Len) + if payloadLen < 0 { + return hdr, iov, fmt.Errorf("invalid payload length: %d", payloadLen) + } + iov.Base = msg.Hdr.Iov.Base + iov.Len = uint64(payloadLen) + hdr.Iov = &iov + hdr.Iovlen = 1 + hdr.Name = msg.Hdr.Name + // CRITICAL: Always set to full buffer size for receive, not what kernel wrote last time + if hdr.Name != nil { + hdr.Namelen = uint32(unix.SizeofSockaddrInet6) + } else { + hdr.Namelen = 0 + } + hdr.Control = msg.Hdr.Control + // CRITICAL: Use the allocated size, not what was 
previously returned + if hdr.Control != nil { + // Control buffer size is stored in Controllen from PrepareRawMessages + hdr.Controllen = msg.Hdr.Controllen + } else { + hdr.Controllen = 0 + } + hdr.Flags = 0 // Reset flags for new receive + return hdr, iov, nil +} + +func updateRawMessageFromUnixMsghdr(msg *rawMessage, hdr *unix.Msghdr, n int) { + if msg == nil || hdr == nil { + return + } + msg.Hdr.Namelen = hdr.Namelen + msg.Hdr.Controllen = hdr.Controllen + msg.Hdr.Flags = hdr.Flags + if n < 0 { + n = 0 + } + msg.Len = uint32(n) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 886e024..3c665de 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -149,7 +149,7 @@ func (u *RIOConn) ListenOut(r EncReader) { continue } - r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], nil) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c1..abd45af 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -112,7 +112,7 @@ func (u *TesterConn) ListenOut(r EncReader) { if !ok { return } - r(p.From, p.Data) + r(p.From, p.Data, func() {}) } }
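Usage note: every ListenOut implementation above now hands the EncReader a release callback; it may be nil (the Windows RIO path passes nil) and otherwise must be called exactly once, after which the payload buffer may be recycled. A minimal caller sketch under those assumptions — runReader and handlePacket are hypothetical names; only Conn and EncReader come from this package:

    // Sketch: adapting a consumer to the new EncReader signature.
    func runReader(conn Conn, handlePacket func(netip.AddrPort, []byte)) {
    	conn.ListenOut(func(addr netip.AddrPort, payload []byte, release func()) {
    		if release != nil {
    			// payload points into a pooled buffer; release it exactly once,
    			// and only after we are done reading from it.
    			defer release()
    		}
    		// Copy before handing off, since payload is invalid after release().
    		handlePacket(addr, append([]byte(nil), payload...))
    	})
    }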