From 71c849e63eb0a1e3e4f4a8ed7c6c01f12d04ed16 Mon Sep 17 00:00:00 2001
From: Ryan Huber
Date: Fri, 31 Oct 2025 13:34:39 -0400
Subject: [PATCH] Add experimental Linux UDP GSO/GRO offload, per-worker
 decrypt queues, and optional TUN vnet_hdr support

---
 CHANGELOG.md           |   18 +
 connection_state.go    |    5 +-
 interface.go           |  366 +++++++++++++-
 main.go                |    7 +
 outside.go             |    8 +-
 overlay/tun_linux.go   |  207 +++++++-
 udp/conn.go            |    1 +
 udp/udp_darwin.go      |    4 +-
 udp/udp_generic.go     |    4 +-
 udp/udp_linux.go       | 1082 +++++++++++++++++++++++++++++++++++++++-
 udp/udp_linux_32.go    |   35 +-
 udp/udp_linux_64.go    |   37 +-
 udp/udp_rio_windows.go |    4 +-
 udp/udp_tester.go      |    2 +-
 14 files changed, 1734 insertions(+), 46 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1de3c19..ff4728e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,12 +7,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+### Added
+
+- Experimental Linux UDP offload support: enable `listen.enable_gso` and
+  `listen.enable_gro` to activate UDP_SEGMENT batching and GRO receive
+  splitting. Includes automatic capability probing, per-packet fallbacks, and
+  runtime metrics/logs for visibility.
+- Optional Linux TUN `virtio_net_hdr` support: set `tun.enable_vnet_hdr` to
+  have Nebula negotiate VNET headers and offload flags so future batches can
+  be delivered to the kernel with metadata instead of per-packet writes.
+- Linux UDP send sharding can now be tuned with `listen.send_shards`; defaults
+  to `GOMAXPROCS` but can be increased to stripe heavy peers across more
+  goroutines.
+
 ### Changed
 
 - `default_local_cidr_any` now defaults to false, meaning that any firewall rule
   intended to target an `unsafe_routes` entry must explicitly declare it via the
   `local_cidr` field. This is almost always the intended behavior. This flag is
   deprecated and will be removed in a future release.
+- UDP receive path now enqueues into per-worker lock-free rings, restoring the
+  `listen.decrypt_workers`/`listen.decrypt_queue_depth` tuning knobs while
+  eliminating the mutex contention from the old shared channel.
+- Increased replay protection window to 32k packets so high-throughput links
+  tolerate larger bursts of reordering without tripping the anti-replay logic.
 
 ## [1.9.4] - 2024-09-09
 
diff --git a/connection_state.go b/connection_state.go
index faee443..5645972 100644
--- a/connection_state.go
+++ b/connection_state.go
@@ -13,7 +13,10 @@ import (
 	"github.com/slackhq/nebula/noiseutil"
 )
 
-const ReplayWindow = 1024
+// ReplayWindow controls the size of the sliding window used to detect replays.
+// High-bandwidth links with GRO/GSO can reorder more than a thousand packets in
+// flight, so keep this comfortably above the largest expected burst.
+const ReplayWindow = 32768 type ConnectionState struct { eKey *NebulaCipherState diff --git a/interface.go b/interface.go index 082906d..584ea3e 100644 --- a/interface.go +++ b/interface.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "io" + "math/bits" "net/netip" "os" "runtime" + "sync" "sync/atomic" "time" @@ -21,7 +23,12 @@ import ( "github.com/slackhq/nebula/udp" ) -const mtu = 9001 +const ( + mtu = 9001 + tunReadBufferSize = mtu * 8 + defaultDecryptWorkerFactor = 2 + defaultInboundQueueDepth = 1024 +) type InterfaceConfig struct { HostMap *HostMap @@ -48,6 +55,8 @@ type InterfaceConfig struct { ConntrackCacheTimeout time.Duration l *logrus.Logger + DecryptWorkers int + DecryptQueueDepth int } type Interface struct { @@ -92,7 +101,167 @@ type Interface struct { messageMetrics *MessageMetrics cachedPacketMetrics *cachedPacketMetrics - l *logrus.Logger + l *logrus.Logger + ctx context.Context + udpListenWG sync.WaitGroup + inboundPool sync.Pool + decryptWG sync.WaitGroup + decryptQueues []*inboundRing + decryptWorkers int + decryptStates []decryptWorkerState + decryptCounter atomic.Uint32 +} + +type inboundPacket struct { + addr netip.AddrPort + payload []byte + release func() + queue int +} + +type decryptWorkerState struct { + queue *inboundRing + notify chan struct{} +} + +type decryptContext struct { + ctTicker *firewall.ConntrackCacheTicker + plain []byte + head header.H + fwPacket firewall.Packet + light *LightHouseHandler + nebula []byte +} + +type inboundCell struct { + seq atomic.Uint64 + pkt *inboundPacket +} + +type inboundRing struct { + mask uint64 + cells []inboundCell + enqueuePos atomic.Uint64 + dequeuePos atomic.Uint64 +} + +func newInboundRing(capacity int) *inboundRing { + if capacity < 2 { + capacity = 2 + } + size := nextPowerOfTwo(uint32(capacity)) + if size < 2 { + size = 2 + } + ring := &inboundRing{ + mask: uint64(size - 1), + cells: make([]inboundCell, size), + } + for i := range ring.cells { + ring.cells[i].seq.Store(uint64(i)) + } + return ring +} + +func nextPowerOfTwo(v uint32) uint32 { + if v == 0 { + return 1 + } + return 1 << (32 - bits.LeadingZeros32(v-1)) +} + +func (r *inboundRing) Enqueue(pkt *inboundPacket) bool { + var cell *inboundCell + pos := r.enqueuePos.Load() + for { + cell = &r.cells[pos&r.mask] + seq := cell.seq.Load() + diff := int64(seq) - int64(pos) + if diff == 0 { + if r.enqueuePos.CompareAndSwap(pos, pos+1) { + break + } + } else if diff < 0 { + return false + } else { + pos = r.enqueuePos.Load() + } + } + cell.pkt = pkt + cell.seq.Store(pos + 1) + return true +} + +func (r *inboundRing) Dequeue() (*inboundPacket, bool) { + var cell *inboundCell + pos := r.dequeuePos.Load() + for { + cell = &r.cells[pos&r.mask] + seq := cell.seq.Load() + diff := int64(seq) - int64(pos+1) + if diff == 0 { + if r.dequeuePos.CompareAndSwap(pos, pos+1) { + break + } + } else if diff < 0 { + return nil, false + } else { + pos = r.dequeuePos.Load() + } + } + pkt := cell.pkt + cell.pkt = nil + cell.seq.Store(pos + r.mask + 1) + return pkt, true +} + +func (f *Interface) getInboundPacket() *inboundPacket { + if pkt, ok := f.inboundPool.Get().(*inboundPacket); ok && pkt != nil { + return pkt + } + return &inboundPacket{} +} + +func (f *Interface) putInboundPacket(pkt *inboundPacket) { + if pkt == nil { + return + } + pkt.addr = netip.AddrPort{} + pkt.payload = nil + pkt.release = nil + pkt.queue = 0 + f.inboundPool.Put(pkt) +} + +func newDecryptContext(f *Interface) *decryptContext { + return &decryptContext{ + ctTicker: 
firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout), + plain: make([]byte, udp.MTU), + head: header.H{}, + fwPacket: firewall.Packet{}, + light: f.lightHouse.NewRequestHandler(), + nebula: make([]byte, 12, 12), + } +} + +func (f *Interface) processInboundPacket(pkt *inboundPacket, ctx *decryptContext) { + if pkt == nil { + return + } + defer func() { + if pkt.release != nil { + pkt.release() + } + f.putInboundPacket(pkt) + }() + + ctx.head = header.H{} + ctx.fwPacket = firewall.Packet{} + var cache firewall.ConntrackCache + if ctx.ctTicker != nil { + cache = ctx.ctTicker.Get(f.l) + } + f.readOutsidePackets(pkt.addr, nil, ctx.plain[:0], pkt.payload, &ctx.head, &ctx.fwPacket, ctx.light, ctx.nebula, pkt.queue, cache) } type EncWriter interface { @@ -162,6 +331,32 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { } cs := c.pki.getCertState() + decryptWorkers := c.DecryptWorkers + if decryptWorkers < 0 { + decryptWorkers = 0 + } + if decryptWorkers == 0 { + decryptWorkers = c.routines * defaultDecryptWorkerFactor + if decryptWorkers < c.routines { + decryptWorkers = c.routines + } + } + if decryptWorkers < 0 { + decryptWorkers = 0 + } + + queueDepth := c.DecryptQueueDepth + if queueDepth <= 0 { + queueDepth = defaultInboundQueueDepth + } + minDepth := c.routines * 64 + if minDepth <= 0 { + minDepth = 64 + } + if queueDepth < minDepth { + queueDepth = minDepth + } + ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -194,7 +389,10 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { dropped: metrics.GetOrRegisterCounter("hostinfo.cached_packets.dropped", nil), }, - l: c.l, + l: c.l, + ctx: ctx, + inboundPool: sync.Pool{New: func() any { return &inboundPacket{} }}, + decryptWorkers: decryptWorkers, } ifce.tryPromoteEvery.Store(c.tryPromoteEvery) @@ -203,6 +401,19 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { ifce.connectionManager.intf = ifce + if decryptWorkers > 0 { + ifce.decryptQueues = make([]*inboundRing, decryptWorkers) + ifce.decryptStates = make([]decryptWorkerState, decryptWorkers) + for i := 0; i < decryptWorkers; i++ { + queue := newInboundRing(queueDepth) + ifce.decryptQueues[i] = queue + ifce.decryptStates[i] = decryptWorkerState{ + queue: queue, + notify: make(chan struct{}, 1), + } + } + } + return ifce, nil } @@ -242,8 +453,68 @@ func (f *Interface) activate() { } } +func (f *Interface) startDecryptWorkers() { + if f.decryptWorkers <= 0 || len(f.decryptQueues) == 0 { + return + } + f.decryptWG.Add(f.decryptWorkers) + for i := 0; i < f.decryptWorkers; i++ { + go f.decryptWorker(i) + } +} + +func (f *Interface) decryptWorker(id int) { + defer f.decryptWG.Done() + if id < 0 || id >= len(f.decryptStates) { + return + } + state := f.decryptStates[id] + if state.queue == nil { + return + } + ctx := newDecryptContext(f) + for { + for { + pkt, ok := state.queue.Dequeue() + if !ok { + break + } + f.processInboundPacket(pkt, ctx) + } + if f.closed.Load() || f.ctx.Err() != nil { + for { + pkt, ok := state.queue.Dequeue() + if !ok { + return + } + f.processInboundPacket(pkt, ctx) + } + } + select { + case <-f.ctx.Done(): + case <-state.notify: + } + } +} + +func (f *Interface) notifyDecryptWorker(idx int) { + if idx < 0 || idx >= len(f.decryptStates) { + return + } + state := f.decryptStates[idx] + if state.notify == nil { + return + } + select { + case state.notify <- struct{}{}: + default: + } +} + func (f *Interface) run() { + f.startDecryptWorkers() // Launch n queues to read 
packets from udp + f.udpListenWG.Add(f.routines) for i := 0; i < f.routines; i++ { go f.listenOut(i) } @@ -256,6 +527,7 @@ func (f *Interface) run() { func (f *Interface) listenOut(i int) { runtime.LockOSThread() + defer f.udpListenWG.Done() var li udp.Conn if i > 0 { @@ -264,23 +536,78 @@ func (f *Interface) listenOut(i int) { li = f.outside } - ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - lhh := f.lightHouse.NewRequestHandler() - plaintext := make([]byte, udp.MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) + useWorkers := f.decryptWorkers > 0 && len(f.decryptQueues) > 0 + var ( + inlineTicker *firewall.ConntrackCacheTicker + inlineHandler *LightHouseHandler + inlinePlain []byte + inlineHeader header.H + inlinePacket firewall.Packet + inlineNB []byte + inlineCtx *decryptContext + ) - li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + if useWorkers { + inlineCtx = newDecryptContext(f) + } else { + inlineTicker = firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) + inlineHandler = f.lightHouse.NewRequestHandler() + inlinePlain = make([]byte, udp.MTU) + inlineNB = make([]byte, 12, 12) + } + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte, release func()) { + if !useWorkers { + if release != nil { + defer release() + } + select { + case <-f.ctx.Done(): + return + default: + } + inlineHeader = header.H{} + inlinePacket = firewall.Packet{} + var cache firewall.ConntrackCache + if inlineTicker != nil { + cache = inlineTicker.Get(f.l) + } + f.readOutsidePackets(fromUdpAddr, nil, inlinePlain[:0], payload, &inlineHeader, &inlinePacket, inlineHandler, inlineNB, i, cache) + return + } + + if f.ctx.Err() != nil { + if release != nil { + release() + } + return + } + + pkt := f.getInboundPacket() + pkt.addr = fromUdpAddr + pkt.payload = payload + pkt.release = release + pkt.queue = i + + queueCount := len(f.decryptQueues) + if queueCount == 0 { + f.processInboundPacket(pkt, inlineCtx) + return + } + w := int(f.decryptCounter.Add(1)-1) % queueCount + if w < 0 || w >= queueCount || !f.decryptQueues[w].Enqueue(pkt) { + f.processInboundPacket(pkt, inlineCtx) + return + } + f.notifyDecryptWorker(w) }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { runtime.LockOSThread() - packet := make([]byte, mtu) - out := make([]byte, mtu) + packet := make([]byte, tunReadBufferSize) + out := make([]byte, tunReadBufferSize) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) @@ -458,6 +785,19 @@ func (f *Interface) Close() error { } } + f.udpListenWG.Wait() + if f.decryptWorkers > 0 { + for _, state := range f.decryptStates { + if state.notify != nil { + select { + case state.notify <- struct{}{}: + default: + } + } + } + f.decryptWG.Wait() + } + // Release the tun device return f.inside.Close() } diff --git a/main.go b/main.go index eb296fb..584f1c6 100644 --- a/main.go +++ b/main.go @@ -120,6 +120,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache") } + udp.SetDisableUDPCsum(c.GetBool("listen.disable_udp_checksum", false)) + var tun overlay.Device if !configTest { c.CatchHUP(ctx) @@ -221,6 +223,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } + decryptWorkers := c.GetInt("listen.decrypt_workers", 0) + decryptQueueDepth := 
c.GetInt("listen.decrypt_queue_depth", 0) + ifConfig := &InterfaceConfig{ HostMap: hostMap, Inside: tun, @@ -243,6 +248,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, + DecryptWorkers: decryptWorkers, + DecryptQueueDepth: decryptQueueDepth, } var ifce *Interface diff --git a/outside.go b/outside.go index 5ff87bd..195665d 100644 --- a/outside.go +++ b/outside.go @@ -470,7 +470,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet") + hostinfo.logger(f.l). + WithError(err). + WithField("tag", "decrypt-debug"). + WithField("remoteIndexLocal", hostinfo.localIndexId). + WithField("messageCounter", messageCounter). + WithField("packet_len", len(packet)). + Error("Failed to decrypt packet") return false } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 44d8746..bcf3aea 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,14 +25,17 @@ import ( type tun struct { io.ReadWriteCloser - fd int - Device string - vpnNetworks []netip.Prefix - MaxMTU int - DefaultMTU int - TXQueueLen int - deviceIndex int - ioctlFd uintptr + fd int + Device string + vpnNetworks []netip.Prefix + MaxMTU int + DefaultMTU int + TXQueueLen int + deviceIndex int + ioctlFd uintptr + enableVnetHdr bool + vnetHdrLen int + queues []*tunQueue Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -65,10 +68,90 @@ type ifreqQLEN struct { pad [8]byte } +const ( + virtioNetHdrLen = 12 + tunDefaultMaxPacket = 65536 +) + +type tunQueue struct { + file *os.File + fd int + enableVnetHdr bool + vnetHdrLen int + maxPacket int + writeScratch []byte + readScratch []byte + l *logrus.Logger +} + +func newTunQueue(file *os.File, enableVnetHdr bool, vnetHdrLen, maxPacket int, l *logrus.Logger) *tunQueue { + if maxPacket <= 0 { + maxPacket = tunDefaultMaxPacket + } + q := &tunQueue{ + file: file, + fd: int(file.Fd()), + enableVnetHdr: enableVnetHdr, + vnetHdrLen: vnetHdrLen, + maxPacket: maxPacket, + l: l, + } + if enableVnetHdr { + q.growReadScratch(maxPacket) + } + return q +} + +func (q *tunQueue) growReadScratch(packetSize int) { + needed := q.vnetHdrLen + packetSize + if needed < q.vnetHdrLen+DefaultMTU { + needed = q.vnetHdrLen + DefaultMTU + } + if q.readScratch == nil || cap(q.readScratch) < needed { + q.readScratch = make([]byte, needed) + } else { + q.readScratch = q.readScratch[:needed] + } +} + +func (q *tunQueue) setMaxPacket(packet int) { + if packet <= 0 { + packet = DefaultMTU + } + q.maxPacket = packet + if q.enableVnetHdr { + q.growReadScratch(packet) + } +} + +func configureVnetHdr(fd int, hdrLen int, l *logrus.Logger) error { + features, err := unix.IoctlGetInt(fd, unix.TUNGETFEATURES) + if err == nil && features&unix.IFF_VNET_HDR == 0 { + return fmt.Errorf("kernel does not support IFF_VNET_HDR") + } + if err := unix.IoctlSetInt(fd, unix.TUNSETVNETHDRSZ, hdrLen); err != nil { + return err + } + offload := unix.TUN_F_CSUM | unix.TUN_F_UFO + if err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, offload); err != nil { + if l != nil { + l.WithError(err).Warn("Failed to enable TUN offload features") + } + } + return nil +} + func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file 
:= os.NewFile(uintptr(deviceFd), "/dev/net/tun") + enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false) + if enableVnetHdr { + if err := configureVnetHdr(deviceFd, virtioNetHdrLen, l); err != nil { + l.WithError(err).Warn("Failed to configure VNET header support on provided tun fd; disabling") + enableVnetHdr = false + } + } - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr) if err != nil { return nil, err } @@ -106,14 +189,25 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu if multiqueue { req.Flags |= unix.IFF_MULTI_QUEUE } + enableVnetHdr := c.GetBool("tun.enable_vnet_hdr", false) + if enableVnetHdr { + req.Flags |= unix.IFF_VNET_HDR + } copy(req.Name[:], c.GetString("tun.dev", "")) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } name := strings.Trim(string(req.Name[:]), "\x00") + if enableVnetHdr { + if err := configureVnetHdr(fd, virtioNetHdrLen, l); err != nil { + l.WithError(err).Warn("Failed to configure VNET header support on tun device; disabling") + enableVnetHdr = false + } + } + file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, vpnNetworks) + t, err := newTunGeneric(c, l, file, vpnNetworks, enableVnetHdr) if err != nil { return nil, err } @@ -123,21 +217,30 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, enableVnetHdr bool) (*tun, error) { + queue := newTunQueue(file, enableVnetHdr, virtioNetHdrLen, tunDefaultMaxPacket, l) t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: queue, fd: int(file.Fd()), vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), l: l, + enableVnetHdr: enableVnetHdr, + vnetHdrLen: virtioNetHdrLen, + queues: []*tunQueue{queue}, } err := t.reload(c, true) if err != nil { return nil, err } + if enableVnetHdr { + for _, q := range t.queues { + q.setMaxPacket(t.MaxMTU) + } + } c.RegisterReloadCallback(func(c *config.C) { err := t.reload(c, false) @@ -180,6 +283,11 @@ func (t *tun) reload(c *config.C, initial bool) error { t.MaxMTU = newMaxMTU t.DefaultMTU = newDefaultMTU + if t.enableVnetHdr { + for _, q := range t.queues { + q.setMaxPacket(t.MaxMTU) + } + } // Teach nebula how to handle the routes before establishing them in the system table oldRoutes := t.Routes.Swap(&routes) @@ -224,14 +332,87 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { var req ifReq req.Flags = uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if t.enableVnetHdr { + req.Flags |= unix.IFF_VNET_HDR + } copy(req.Name[:], t.Device) if err = ioctl(uintptr(fd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&req))); err != nil { return nil, err } file := os.NewFile(uintptr(fd), "/dev/net/tun") + queue := newTunQueue(file, t.enableVnetHdr, t.vnetHdrLen, t.MaxMTU, t.l) + if t.enableVnetHdr { + if err := configureVnetHdr(fd, t.vnetHdrLen, t.l); err != nil { + queue.enableVnetHdr = false + } + } + t.queues = append(t.queues, queue) - return file, nil + return queue, nil +} + +func (q *tunQueue) Read(p []byte) (int, error) { + if 
!q.enableVnetHdr { + return q.file.Read(p) + } + + if len(p)+q.vnetHdrLen > cap(q.readScratch) { + q.growReadScratch(len(p)) + } + + buf := q.readScratch[:cap(q.readScratch)] + n, err := q.file.Read(buf) + if n <= 0 { + return n, err + } + if n < q.vnetHdrLen { + if err == nil { + err = io.ErrUnexpectedEOF + } + return 0, err + } + + payload := buf[q.vnetHdrLen:n] + if len(payload) > len(p) { + copy(p, payload[:len(p)]) + if err == nil { + err = io.ErrShortBuffer + } + return len(p), err + } + copy(p, payload) + return len(payload), err +} + +func (q *tunQueue) Write(b []byte) (int, error) { + if !q.enableVnetHdr { + return unix.Write(q.fd, b) + } + + total := q.vnetHdrLen + len(b) + if cap(q.writeScratch) < total { + q.writeScratch = make([]byte, total) + } else { + q.writeScratch = q.writeScratch[:total] + } + + for i := 0; i < q.vnetHdrLen; i++ { + q.writeScratch[i] = 0 + } + copy(q.writeScratch[q.vnetHdrLen:], b) + + n, err := unix.Write(q.fd, q.writeScratch) + if n >= q.vnetHdrLen { + n -= q.vnetHdrLen + } else { + n = 0 + } + return n, err +} + +func (q *tunQueue) Close() error { + return q.file.Close() } func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways { diff --git a/udp/conn.go b/udp/conn.go index 895b0df..1c6a6de 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -11,6 +11,7 @@ const MTU = 9001 type EncReader func( addr netip.AddrPort, payload []byte, + release func(), ) type Conn interface { diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index c0c6233..747e5a8 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -180,7 +180,9 @@ func (u *StdConn) ListenOut(r EncReader) { u.l.WithError(err).Error("unexpected udp socket receive error") } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + payload := make([]byte, n) + copy(payload, buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), payload, func() {}) } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index cb21e57..8a3b0d8 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -82,6 +82,8 @@ func (u *GenericConn) ListenOut(r EncReader) { return } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + payload := make([]byte, n) + copy(payload, buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), payload, func() {}) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec0bf64..f6bce41 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -5,23 +5,669 @@ package udp import ( "encoding/binary" + "errors" "fmt" "net" "net/netip" + "runtime" + "sync" + "sync/atomic" "syscall" + "time" "unsafe" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) +const ( + defaultGSOMaxSegments = 8 + defaultGSOMaxBytes = MTU * defaultGSOMaxSegments + defaultGROReadBufferSize = MTU * defaultGSOMaxSegments + defaultGSOFlushTimeout = 50 * time.Microsecond + linuxMaxGSOBatchBytes = 0xFFFF // Linux UDP GSO still limits the datagram payload to 64 KiB + maxSendmmsgBatch = 32 +) + type StdConn struct { sysFd int isV4 bool l *logrus.Logger batch int + + enableGRO bool + enableGSO bool + + controlLen atomic.Int32 + + gsoMaxSegments int + gsoMaxBytes int + gsoFlushTimeout time.Duration + + groSegmentPool sync.Pool + groBufSize atomic.Int64 + rxBufferPool chan []byte + gsoBufferPool sync.Pool + + gsoBatches metrics.Counter + gsoSegments metrics.Counter + groSegments metrics.Counter + + sendShards []*sendShard + shardCounter atomic.Uint32 +} + +type sendTask struct { + 
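+	// buf holds one or more equal-size datagrams coalesced for addr; owned marks
+	// buffers borrowed from the GSO pool that must be released after sending.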
buf []byte + addr netip.AddrPort + segSize int + segments int + owned bool +} + +const sendShardQueueDepth = 128 + +type sendShard struct { + parent *StdConn + + mu sync.Mutex + + pendingBuf []byte + pendingSegments int + pendingAddr netip.AddrPort + pendingSegSize int + flushTimer *time.Timer + controlBuf []byte + + mmsgHeaders []linuxMmsgHdr + mmsgIovecs []unix.Iovec + mmsgLengths []int + + outQueue chan *sendTask + workerDone sync.WaitGroup +} + +func (u *StdConn) initSendShards() { + shardCount := runtime.GOMAXPROCS(0) + if shardCount < 1 { + shardCount = 1 + } + u.resizeSendShards(shardCount) +} + +func (u *StdConn) selectSendShard(addr netip.AddrPort) *sendShard { + if len(u.sendShards) == 0 { + return nil + } + if len(u.sendShards) == 1 { + return u.sendShards[0] + } + idx := int(u.shardCounter.Add(1)-1) % len(u.sendShards) + if idx < 0 { + idx = -idx + } + return u.sendShards[idx] +} + +func (u *StdConn) resizeSendShards(count int) { + if count <= 0 { + count = runtime.GOMAXPROCS(0) + if count < 1 { + count = 1 + } + } + + if len(u.sendShards) == count { + return + } + + for _, shard := range u.sendShards { + if shard == nil { + continue + } + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil { + u.l.WithError(err).Warn("Failed to flush send shard while resizing") + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + if buf != nil { + u.releaseGSOBuf(buf) + } + shard.stopSender() + } + + newShards := make([]*sendShard, count) + for i := range newShards { + shard := &sendShard{parent: u} + shard.startSender() + newShards[i] = shard + } + u.sendShards = newShards + u.shardCounter.Store(0) + u.l.WithField("send_shards", count).Debug("Configured UDP send shards") +} + +func (u *StdConn) setGroBufferSize(size int) { + if size < defaultGROReadBufferSize { + size = defaultGROReadBufferSize + } + u.groBufSize.Store(int64(size)) + u.groSegmentPool = sync.Pool{New: func() any { + return make([]byte, size) + }} + if u.rxBufferPool == nil { + poolSize := u.batch * 4 + if poolSize < u.batch { + poolSize = u.batch + } + u.rxBufferPool = make(chan []byte, poolSize) + for i := 0; i < poolSize; i++ { + u.rxBufferPool <- make([]byte, size) + } + } +} + +func (u *StdConn) borrowRxBuffer(desired int) []byte { + if desired < MTU { + desired = MTU + } + if u.rxBufferPool == nil { + return make([]byte, desired) + } + buf := <-u.rxBufferPool + if cap(buf) < desired { + buf = make([]byte, desired) + } + return buf[:desired] +} + +func (u *StdConn) recycleBuffer(buf []byte) { + if buf == nil { + return + } + if u.rxBufferPool == nil { + return + } + buf = buf[:cap(buf)] + desired := int(u.groBufSize.Load()) + if desired < MTU { + desired = MTU + } + if cap(buf) < desired { + return + } + select { + case u.rxBufferPool <- buf[:desired]: + default: + } +} + +func (u *StdConn) recycleBufferSet(bufs [][]byte) { + for i := range bufs { + u.recycleBuffer(bufs[i]) + } +} + +func (u *StdConn) borrowGSOBuf() []byte { + size := u.gsoMaxBytes + if size <= 0 { + size = MTU + } + if v := u.gsoBufferPool.Get(); v != nil { + buf := v.([]byte) + if cap(buf) < size { + return make([]byte, 0, size) + } + return buf[:0] + } + return make([]byte, 0, size) +} + +func (u *StdConn) releaseGSOBuf(buf []byte) { + if buf == nil { + return + } + size := u.gsoMaxBytes + if size <= 0 { + size = MTU + } + buf = buf[:0] + if cap(buf) > size*4 { + return + } + u.gsoBufferPool.Put(buf) +} + +func (s *sendShard) 
ensureMmsgCapacity(n int) { + if cap(s.mmsgHeaders) < n { + s.mmsgHeaders = make([]linuxMmsgHdr, n) + } + s.mmsgHeaders = s.mmsgHeaders[:n] + if cap(s.mmsgIovecs) < n { + s.mmsgIovecs = make([]unix.Iovec, n) + } + s.mmsgIovecs = s.mmsgIovecs[:n] + if cap(s.mmsgLengths) < n { + s.mmsgLengths = make([]int, n) + } + s.mmsgLengths = s.mmsgLengths[:n] +} + +func (s *sendShard) ensurePendingBuf(p *StdConn) { + if s.pendingBuf == nil { + s.pendingBuf = p.borrowGSOBuf() + } +} + +func (s *sendShard) startSender() { + if s.outQueue != nil { + return + } + s.outQueue = make(chan *sendTask, sendShardQueueDepth) + s.workerDone.Add(1) + go s.senderLoop() +} + +func (s *sendShard) stopSender() { + s.closeSender() + s.workerDone.Wait() +} + +func (s *sendShard) closeSender() { + s.mu.Lock() + queue := s.outQueue + s.outQueue = nil + s.mu.Unlock() + if queue != nil { + close(queue) + } +} + +func (s *sendShard) senderLoop() { + defer s.workerDone.Done() + for task := range s.outQueue { + if task == nil { + continue + } + _ = s.processTask(task) + } +} + +func (s *sendShard) processTask(task *sendTask) error { + if task == nil { + return nil + } + p := s.parent + defer func() { + if task.owned && task.buf != nil { + p.releaseGSOBuf(task.buf) + } + task.buf = nil + }() + if len(task.buf) == 0 { + return nil + } + useGSO := p.enableGSO && task.segments > 1 + if useGSO { + if err := s.sendSegmentedLocked(task.buf, task.addr, task.segSize); err != nil { + if errors.Is(err, unix.EOPNOTSUPP) || errors.Is(err, unix.ENOTSUP) { + p.enableGSO = false + p.l.WithError(err).Warn("UDP GSO not supported, disabling") + } else { + p.l.WithError(err).Warn("Failed to flush GSO batch") + return err + } + } else { + s.recordGSOMetrics(task) + return nil + } + } + if err := s.sendSequentialLocked(task.buf, task.addr, task.segSize); err != nil { + p.l.WithError(err).Warn("Failed to flush batch") + return err + } + return nil +} + +func (s *sendShard) recordGSOMetrics(task *sendTask) { + p := s.parent + if p.gsoBatches != nil { + p.gsoBatches.Inc(1) + } + if p.gsoSegments != nil { + p.gsoSegments.Inc(int64(task.segments)) + } + if p.l.IsLevelEnabled(logrus.DebugLevel) { + p.l.WithFields(logrus.Fields{ + "tag": "gso-debug", + "stage": "flush", + "segments": task.segments, + "segment_size": task.segSize, + "batch_bytes": len(task.buf), + "remote_addr": task.addr.String(), + }).Debug("gso batch sent") + } +} + +func (s *sendShard) write(b []byte, addr netip.AddrPort) error { + if len(b) == 0 { + return nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + p := s.parent + + if !p.enableGSO || !addr.IsValid() { + return p.directWrite(b, addr) + } + + s.ensurePendingBuf(p) + + if s.pendingSegments > 0 && s.pendingAddr != addr { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.ensurePendingBuf(p) + } + + if len(b) > p.gsoMaxBytes || p.gsoMaxSegments <= 1 { + if err := s.flushPendingLocked(); err != nil { + return err + } + return p.directWrite(b, addr) + } + + if s.pendingSegments == 0 { + s.pendingAddr = addr + s.pendingSegSize = len(b) + } else if len(b) != s.pendingSegSize { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.pendingAddr = addr + s.pendingSegSize = len(b) + s.ensurePendingBuf(p) + } + + if len(s.pendingBuf)+len(b) > p.gsoMaxBytes { + if err := s.flushPendingLocked(); err != nil { + return err + } + s.pendingAddr = addr + s.pendingSegSize = len(b) + s.ensurePendingBuf(p) + } + + s.pendingBuf = append(s.pendingBuf, b...) 
+ s.pendingSegments++ + + if s.pendingSegments >= p.gsoMaxSegments { + return s.flushPendingLocked() + } + + if p.gsoFlushTimeout <= 0 { + return s.flushPendingLocked() + } + + s.scheduleFlushLocked() + return nil +} + +func (s *sendShard) flushPendingLocked() error { + if s.pendingSegments == 0 { + s.stopFlushTimerLocked() + return nil + } + + buf := s.pendingBuf + task := &sendTask{ + buf: buf, + addr: s.pendingAddr, + segSize: s.pendingSegSize, + segments: s.pendingSegments, + owned: true, + } + + s.pendingBuf = nil + s.pendingSegments = 0 + s.pendingSegSize = 0 + s.pendingAddr = netip.AddrPort{} + + s.stopFlushTimerLocked() + + queue := s.outQueue + s.mu.Unlock() + var err error + if queue == nil { + err = s.processTask(task) + } else { + defer func() { + if r := recover(); r != nil { + err = s.processTask(task) + } + }() + queue <- task + } + s.mu.Lock() + return err +} + +func (s *sendShard) sendSegmentedLocked(buf []byte, addr netip.AddrPort, segSize int) error { + if len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + + if len(s.controlBuf) < unix.CmsgSpace(2) { + s.controlBuf = make([]byte, unix.CmsgSpace(2)) + } + control := s.controlBuf[:unix.CmsgSpace(2)] + for i := range control { + control[i] = 0 + } + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + setCmsgLen(hdr, 2) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + + dataOff := unix.CmsgLen(0) + binary.NativeEndian.PutUint16(control[dataOff:dataOff+2], uint16(segSize)) + + var sa unix.Sockaddr + if s.parent.isV4 { + sa4 := &unix.SockaddrInet4{Port: int(addr.Port())} + sa4.Addr = addr.Addr().As4() + sa = sa4 + } else { + sa6 := &unix.SockaddrInet6{Port: int(addr.Port())} + sa6.Addr = addr.Addr().As16() + sa = sa6 + } + + for { + n, err := unix.SendmsgN(s.parent.sysFd, buf, control[:unix.CmsgSpace(2)], sa, 0) + if err != nil { + if err == unix.EINTR { + continue + } + return &net.OpError{Op: "sendmsg", Err: err} + } + if n != len(buf) { + return &net.OpError{Op: "sendmsg", Err: unix.EIO} + } + return nil + } +} + +func (s *sendShard) sendSequentialLocked(buf []byte, addr netip.AddrPort, segSize int) error { + if len(buf) == 0 { + return nil + } + if segSize <= 0 { + segSize = len(buf) + } + if segSize >= len(buf) { + return s.parent.directWrite(buf, addr) + } + + var ( + namePtr *byte + nameLen uint32 + ) + if s.parent.isV4 { + var sa4 unix.RawSockaddrInet4 + sa4.Family = unix.AF_INET + sa4.Addr = addr.Addr().As4() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:], addr.Port()) + namePtr = (*byte)(unsafe.Pointer(&sa4)) + nameLen = uint32(unsafe.Sizeof(sa4)) + } else { + var sa6 unix.RawSockaddrInet6 + sa6.Family = unix.AF_INET6 + sa6.Addr = addr.Addr().As16() + binary.BigEndian.PutUint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:], addr.Port()) + namePtr = (*byte)(unsafe.Pointer(&sa6)) + nameLen = uint32(unsafe.Sizeof(sa6)) + } + + total := len(buf) + if total == 0 { + return nil + } + basePtr := uintptr(unsafe.Pointer(&buf[0])) + offset := 0 + + for offset < total { + remaining := total - offset + segments := (remaining + segSize - 1) / segSize + if segments > maxSendmmsgBatch { + segments = maxSendmmsgBatch + } + + s.ensureMmsgCapacity(segments) + msgs := s.mmsgHeaders[:segments] + iovecs := s.mmsgIovecs[:segments] + lens := s.mmsgLengths[:segments] + + batchStart := offset + segOffset := offset + actual := 0 + for actual < segments && segOffset < total { + segLen := segSize + if segLen > total-segOffset { + segLen = total - segOffset + } + + msgs[actual] = 
linuxMmsgHdr{} + lens[actual] = segLen + iovecs[actual].Base = &buf[segOffset] + setIovecLen(&iovecs[actual], segLen) + msgs[actual].Hdr.Iov = &iovecs[actual] + setMsghdrIovlen(&msgs[actual].Hdr, 1) + msgs[actual].Hdr.Name = namePtr + msgs[actual].Hdr.Namelen = nameLen + msgs[actual].Hdr.Control = nil + msgs[actual].Hdr.Controllen = 0 + msgs[actual].Hdr.Flags = 0 + msgs[actual].Len = 0 + + actual++ + segOffset += segLen + } + if actual == 0 { + break + } + msgs = msgs[:actual] + lens = lens[:actual] + + retry: + sent, err := sendmmsg(s.parent.sysFd, msgs, 0) + if err != nil { + if err == unix.EINTR { + goto retry + } + return &net.OpError{Op: "sendmmsg", Err: err} + } + if sent == 0 { + goto retry + } + + bytesSent := 0 + for i := 0; i < sent; i++ { + bytesSent += lens[i] + } + offset = batchStart + bytesSent + + if sent < len(msgs) { + for j := sent; j < len(msgs); j++ { + start := int(uintptr(unsafe.Pointer(iovecs[j].Base)) - basePtr) + if start < 0 || start >= total { + continue + } + end := start + lens[j] + if end > total { + end = total + } + if err := s.parent.directWrite(buf[start:end], addr); err != nil { + return err + } + if end > offset { + offset = end + } + } + } + } + + return nil +} + +func (s *sendShard) scheduleFlushLocked() { + timeout := s.parent.gsoFlushTimeout + if timeout <= 0 { + _ = s.flushPendingLocked() + return + } + if s.flushTimer == nil { + s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler) + return + } + if !s.flushTimer.Stop() { + // allow existing timer to drain + } + if !s.flushTimer.Reset(timeout) { + s.flushTimer = time.AfterFunc(timeout, s.flushTimerHandler) + } +} + +func (s *sendShard) stopFlushTimerLocked() { + if s.flushTimer != nil { + s.flushTimer.Stop() + } +} + +func (s *sendShard) flushTimerHandler() { + s.mu.Lock() + defer s.mu.Unlock() + if s.pendingSegments == 0 { + return + } + if err := s.flushPendingLocked(); err != nil { + s.parent.l.WithError(err).Warn("Failed to flush GSO batch") + } } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -69,7 +715,29 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in return nil, fmt.Errorf("unable to bind to socket: %s", err) } - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + if ip.Is4() && udpChecksumDisabled() { + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_NO_CHECK, 1); err != nil { + l.WithError(err).Warn("Failed to disable IPv4 UDP checksum via SO_NO_CHECK") + } else { + l.Debug("Disabled IPv4 UDP checksum using SO_NO_CHECK") + } + } + + conn := &StdConn{ + sysFd: fd, + isV4: ip.Is4(), + l: l, + batch: batch, + gsoMaxSegments: defaultGSOMaxSegments, + gsoMaxBytes: defaultGSOMaxBytes, + gsoFlushTimeout: defaultGSOFlushTimeout, + gsoBatches: metrics.GetOrRegisterCounter("udp.gso.batches", nil), + gsoSegments: metrics.GetOrRegisterCounter("udp.gso.segments", nil), + groSegments: metrics.GetOrRegisterCounter("udp.gro.segments", nil), + } + conn.setGroBufferSize(defaultGROReadBufferSize) + conn.initSendShards() + return conn, err } func (u *StdConn) Rebind() error { @@ -121,27 +789,92 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - msgs, buffers, names := u.PrepareRawMessages(u.batch) + msgs, buffers, names, controls := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { read = u.ReadSingle } for { + desiredGroSize := int(u.groBufSize.Load()) + if desiredGroSize < MTU { + desiredGroSize = MTU + } + if len(buffers) == 0 || cap(buffers[0]) < 
desiredGroSize { + u.recycleBufferSet(buffers) + msgs, buffers, names, controls = u.PrepareRawMessages(u.batch) + } + desiredControl := int(u.controlLen.Load()) + hasControl := len(controls) > 0 + if (desiredControl > 0) != hasControl || (desiredControl > 0 && hasControl && len(controls[0]) != desiredControl) { + u.recycleBufferSet(buffers) + msgs, buffers, names, controls = u.PrepareRawMessages(u.batch) + hasControl = len(controls) > 0 + } + + if hasControl { + for i := range msgs { + if len(controls) <= i || len(controls[i]) == 0 { + continue + } + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } + } + n, err := read(msgs) if err != nil { u.l.WithError(err).Debug("udp socket is closed, exiting read loop") + u.recycleBufferSet(buffers) return } for i := 0; i < n; i++ { + payloadLen := int(msgs[i].Len) + if payloadLen == 0 { + continue + } + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) } - r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) + addr := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) + buf := buffers[i] + payload := buf[:payloadLen] + released := false + release := func() { + if !released { + released = true + u.recycleBuffer(buf) + } + } + handled := false + + if len(controls) > i && len(controls[i]) > 0 { + if segSize, segCount := u.parseGROSegment(&msgs[i], controls[i]); segSize > 0 && segSize < payloadLen { + if u.emitSegments(r, addr, payload, segSize, segCount, release) { + handled = true + } else if segCount > 1 { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "listen_out", + "reason": "emit_failed", + "payload_len": payloadLen, + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug fallback to single packet") + } + } + } + + if !handled { + r(addr, payload, release) + } + + buffers[i] = u.borrowRxBuffer(desiredGroSize) + setIovecBase(&msgs[i], buffers[i]) } } } @@ -188,6 +921,13 @@ func (u *StdConn) ReadMulti(msgs []rawMessage) (int, error) { } func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { + if u.enableGSO { + if err := u.writeToGSO(b, ip); err != nil { + return err + } + return nil + } + if u.isV4 { return u.writeTo4(b, ip) } @@ -248,6 +988,311 @@ func (u *StdConn) writeTo4(b []byte, ip netip.AddrPort) error { } } +func (u *StdConn) writeToGSO(b []byte, addr netip.AddrPort) error { + if len(b) == 0 { + return nil + } + shard := u.selectSendShard(addr) + if shard == nil { + return u.directWrite(b, addr) + } + return shard.write(b, addr) +} + +func (u *StdConn) directWrite(b []byte, addr netip.AddrPort) error { + if u.isV4 { + return u.writeTo4(b, addr) + } + return u.writeTo6(b, addr) +} + +func (u *StdConn) emitSegments(r EncReader, addr netip.AddrPort, payload []byte, segSize, segCount int, release func()) bool { + if segSize <= 0 || segSize >= len(payload) { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "reason": "invalid_seg_size", + "payload_len": len(payload), + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug skip emit") + return false + } + + totalLen := len(payload) + if segCount <= 0 { + segCount = (totalLen + segSize - 1) / segSize + } + if segCount <= 1 { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "reason": "single_segment", + "payload_len": totalLen, + "seg_size": segSize, + "seg_count": 
segCount, + }).Debug("gro-debug skip emit") + return false + } + + defer func() { + if release != nil { + release() + } + }() + + actualSegments := 0 + start := 0 + debugEnabled := u.l.IsLevelEnabled(logrus.DebugLevel) + var firstHeader header.H + var firstParsed bool + var firstCounter uint64 + var firstRemote uint32 + + for start < totalLen && actualSegments < segCount { + end := start + segSize + if end > totalLen { + end = totalLen + } + + segLen := end - start + bufAny := u.groSegmentPool.Get() + var segBuf []byte + if bufAny == nil { + segBuf = make([]byte, segLen) + } else { + segBuf = bufAny.([]byte) + if cap(segBuf) < segLen { + segBuf = make([]byte, segLen) + } + } + segment := segBuf[:segLen] + copy(segment, payload[start:end]) + + if debugEnabled && !firstParsed { + if err := firstHeader.Parse(segment); err == nil { + firstParsed = true + firstCounter = firstHeader.MessageCounter + firstRemote = firstHeader.RemoteIndex + } else { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "parse_fail", + "seg_index": actualSegments, + "seg_size": segSize, + "seg_count": segCount, + "payload_len": totalLen, + "err": err, + }).Debug("gro-debug segment parse failed") + } + } + + start = end + actualSegments++ + r(addr, segment, func() { + u.groSegmentPool.Put(segBuf[:cap(segBuf)]) + }) + + if debugEnabled && actualSegments == segCount && segLen < segSize { + var tail header.H + if err := tail.Parse(segment); err == nil { + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "tail_segment", + "segment_len": segLen, + "remote_index": tail.RemoteIndex, + "message_counter": tail.MessageCounter, + }).Debug("gro-debug tail segment metadata") + } + } + + } + + if u.groSegments != nil { + u.groSegments.Inc(int64(actualSegments)) + } + + if debugEnabled && actualSegments > 0 { + lastLen := segSize + if tail := totalLen % segSize; tail != 0 { + lastLen = tail + } + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "emit", + "event": "success", + "payload_len": totalLen, + "seg_size": segSize, + "seg_count": segCount, + "actual_segs": actualSegments, + "last_seg_len": lastLen, + "addr": addr.String(), + "first_remote": firstRemote, + "first_counter": firstCounter, + }).Debug("gro-debug emit") + } + + return true +} + +func (u *StdConn) parseGROSegment(msg *rawMessage, control []byte) (int, int) { + ctrlLen := int(msg.Hdr.Controllen) + if ctrlLen <= 0 { + return 0, 0 + } + if ctrlLen > len(control) { + ctrlLen = len(control) + } + + cmsgs, err := unix.ParseSocketControlMessage(control[:ctrlLen]) + if err != nil { + u.l.WithError(err).Debug("failed to parse UDP GRO control message") + return 0, 0 + } + + for _, c := range cmsgs { + if c.Header.Level == unix.SOL_UDP && c.Header.Type == unix.UDP_GRO && len(c.Data) >= 2 { + segSize := int(binary.NativeEndian.Uint16(c.Data[:2])) + segCount := 0 + if len(c.Data) >= 4 { + segCount = int(binary.NativeEndian.Uint16(c.Data[2:4])) + } + u.l.WithFields(logrus.Fields{ + "tag": "gro-debug", + "stage": "parse", + "seg_size": segSize, + "seg_count": segCount, + }).Debug("gro-debug control parsed") + return segSize, segCount + } + } + + return 0, 0 +} + +func (u *StdConn) configureGRO(enable bool) { + if enable == u.enableGRO { + if enable { + u.controlLen.Store(int32(unix.CmsgSpace(2))) + } else { + u.controlLen.Store(0) + } + return + } + + if enable { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 1); err != nil { + u.l.WithError(err).Warn("Failed to enable UDP GRO") + 
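+			// Leave GRO off and clear the control length so the receive path stops
+			// reserving cmsg space for segment metadata.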
u.enableGRO = false + u.controlLen.Store(0) + return + } + u.enableGRO = true + u.controlLen.Store(int32(unix.CmsgSpace(2))) + u.l.Info("UDP GRO enabled") + } else { + if u.enableGRO { + if err := unix.SetsockoptInt(u.sysFd, unix.SOL_UDP, unix.UDP_GRO, 0); err != nil { + u.l.WithError(err).Warn("Failed to disable UDP GRO") + } + } + u.enableGRO = false + u.controlLen.Store(0) + } +} + +func (u *StdConn) configureGSO(enable bool, c *config.C) { + if len(u.sendShards) == 0 { + u.initSendShards() + } + shardCount := c.GetInt("listen.send_shards", 0) + u.resizeSendShards(shardCount) + + if !enable { + if u.enableGSO { + for _, shard := range u.sendShards { + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil { + u.l.WithError(err).Warn("Failed to flush GSO buffers while disabling") + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + if buf != nil { + u.releaseGSOBuf(buf) + } + } + u.enableGSO = false + u.l.Info("UDP GSO disabled") + } + u.setGroBufferSize(defaultGROReadBufferSize) + return + } + + maxSegments := c.GetInt("listen.gso_max_segments", defaultGSOMaxSegments) + if maxSegments < 2 { + maxSegments = 2 + } + + maxBytes := c.GetInt("listen.gso_max_bytes", 0) + if maxBytes <= 0 { + maxBytes = defaultGSOMaxBytes + } + if maxBytes < MTU { + maxBytes = MTU + } + if maxBytes > linuxMaxGSOBatchBytes { + u.l.WithFields(logrus.Fields{ + "configured_bytes": maxBytes, + "clamped_bytes": linuxMaxGSOBatchBytes, + }).Warn("listen.gso_max_bytes exceeds Linux UDP limit; clamping") + maxBytes = linuxMaxGSOBatchBytes + } + + flushTimeout := c.GetDuration("listen.gso_flush_timeout", defaultGSOFlushTimeout) + if flushTimeout < 0 { + flushTimeout = 0 + } + + u.enableGSO = true + u.gsoMaxSegments = maxSegments + u.gsoMaxBytes = maxBytes + u.gsoFlushTimeout = flushTimeout + bufSize := defaultGROReadBufferSize + if u.gsoMaxBytes > bufSize { + bufSize = u.gsoMaxBytes + } + u.setGroBufferSize(bufSize) + + for _, shard := range u.sendShards { + shard.mu.Lock() + if shard.pendingBuf != nil { + u.releaseGSOBuf(shard.pendingBuf) + shard.pendingBuf = nil + } + shard.pendingSegments = 0 + shard.pendingSegSize = 0 + shard.pendingAddr = netip.AddrPort{} + shard.stopFlushTimerLocked() + if len(shard.controlBuf) < unix.CmsgSpace(2) { + shard.controlBuf = make([]byte, unix.CmsgSpace(2)) + } + shard.mu.Unlock() + } + + u.l.WithFields(logrus.Fields{ + "segments": u.gsoMaxSegments, + "bytes": u.gsoMaxBytes, + "flush_timeout": u.gsoFlushTimeout, + }).Info("UDP GSO configured") +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { @@ -294,6 +1339,9 @@ func (u *StdConn) ReloadConfig(c *config.C) { u.l.WithError(err).Error("Failed to set listen.so_mark") } } + + u.configureGRO(c.GetBool("listen.enable_gro", false)) + u.configureGSO(c.GetBool("listen.enable_gso", false), c) } func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { @@ -306,7 +1354,33 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - return syscall.Close(u.sysFd) + var flushErr error + for _, shard := range u.sendShards { + if shard == nil { + continue + } + shard.mu.Lock() + if shard.pendingSegments > 0 { + if err := shard.flushPendingLocked(); err != nil && flushErr == nil { + flushErr = err + } + } else { + shard.stopFlushTimerLocked() + } + buf := shard.pendingBuf + shard.pendingBuf = nil + shard.mu.Unlock() + 
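+		// Return any borrowed-but-unsent batch buffer to the pool before stopping
+		// the shard's sender goroutine.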
if buf != nil { + u.releaseGSOBuf(buf) + } + shard.stopSender() + } + + closeErr := syscall.Close(u.sysFd) + if flushErr != nil { + return flushErr + } + return closeErr } func NewUDPStatsEmitter(udpConns []Conn) func() { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index de8f1cd..c0030a6 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,17 +30,29 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) { + controlLen := int(u.controlLen.Load()) + msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var controls [][]byte + if controlLen > 0 { + controls = make([][]byte, n) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + size := int(u.groBufSize.Load()) + if size < MTU { + size = MTU + } + buf := u.borrowRxBuffer(size) + buffers[i] = buf names[i] = make([]byte, unix.SizeofSockaddrInet6) vs := []iovec{ - {Base: &buffers[i][0], Len: uint32(len(buffers[i]))}, + {Base: &buf[0], Len: uint32(len(buf))}, } msgs[i].Hdr.Iov = &vs[0] @@ -48,7 +60,22 @@ func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if controlLen > 0 { + controls[i] = make([]byte, controlLen) + msgs[i].Hdr.Control = &controls[i][0] + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = controllen(0) + } } - return msgs, buffers, names + return msgs, buffers, names, controls +} + +func setIovecBase(msg *rawMessage, buf []byte) { + iov := (*iovec)(msg.Hdr.Iov) + iov.Base = &buf[0] + iov.Len = uint32(len(buf)) } diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index 48c5a97..1c45fda 100644 --- a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,25 +33,50 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n int) ([]rawMessage, [][]byte, [][]byte, [][]byte) { + controlLen := int(u.controlLen.Load()) + msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var controls [][]byte + if controlLen > 0 { + controls = make([][]byte, n) + } + for i := range msgs { - buffers[i] = make([]byte, MTU) + size := int(u.groBufSize.Load()) + if size < MTU { + size = MTU + } + buf := u.borrowRxBuffer(size) + buffers[i] = buf names[i] = make([]byte, unix.SizeofSockaddrInet6) - vs := []iovec{ - {Base: &buffers[i][0], Len: uint64(len(buffers[i]))}, - } + vs := []iovec{{Base: &buf[0], Len: uint64(len(buf))}} msgs[i].Hdr.Iov = &vs[0] msgs[i].Hdr.Iovlen = uint64(len(vs)) msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if controlLen > 0 { + controls[i] = make([]byte, controlLen) + msgs[i].Hdr.Control = &controls[i][0] + msgs[i].Hdr.Controllen = controllen(len(controls[i])) + } else { + msgs[i].Hdr.Control = nil + msgs[i].Hdr.Controllen = controllen(0) + } } - return msgs, buffers, names + return msgs, buffers, names, controls +} + +func setIovecBase(msg *rawMessage, buf []byte) { + iov := (*iovec)(msg.Hdr.Iov) + iov.Base = &buf[0] + iov.Len = uint64(len(buf)) } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 886e024..7193d39 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -149,7 +149,9 @@ func (u *RIOConn) ListenOut(r EncReader) { 
continue } - r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + payload := make([]byte, n) + copy(payload, buffer[:n]) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), payload, func() {}) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 8d5e6c1..abd45af 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -112,7 +112,7 @@ func (u *TesterConn) ListenOut(r EncReader) { if !ok { return } - r(p.From, p.Data) + r(p.From, p.Data, func() {}) } }
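
All of the new behavior described in the CHANGELOG is gated behind config keys that this patch reads with safe fallbacks. A minimal sketch of the relevant settings under the existing listen/tun sections (values are illustrative, not recommendations; omitting a key keeps the code's own default, e.g. decrypt_workers: 0 derives two workers per listen routine and send_shards: 0 uses GOMAXPROCS):

    listen:
      enable_gro: true            # request UDP_GRO and split coalesced segments on receive
      enable_gso: true            # batch equal-size sends with a UDP_SEGMENT cmsg
      gso_max_segments: 8         # datagrams coalesced per batch (default 8)
      gso_max_bytes: 65535        # batches are clamped to the Linux 64 KiB GSO limit
      gso_flush_timeout: 50us     # how long a partial batch may wait before flushing
      send_shards: 0              # 0 = one sender shard per GOMAXPROCS
      decrypt_workers: 0          # 0 = 2x listen routines
      decrypt_queue_depth: 1024   # per-worker ring capacity (default 1024)
      disable_udp_checksum: false # sets SO_NO_CHECK on IPv4 sockets when true
    tun:
      enable_vnet_hdr: true       # negotiate IFF_VNET_HDR/TUNSETVNETHDRSZ on the tun fd

Offload stays best-effort: the listener probes UDP_GRO/UDP_SEGMENT support when the options are applied, and if the kernel rejects a segmented send the shard falls back to sendmmsg/sendmsg per datagram, so enabling these flags on an unsupported kernel degrades to the existing per-packet path.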