From 1fd24a19c7b172a8612cdfc59b533e0abc15654d Mon Sep 17 00:00:00 2001
From: JackDoan
Date: Fri, 17 Apr 2026 14:56:18 -0500
Subject: [PATCH] holy crap 2x

---
 interface.go                 |  13 +-
 outside.go                   |   2 +-
 overlay/device.go            |  19 ++
 overlay/tun_linux.go         |  68 ++++++
 overlay/tun_linux_offload.go |  12 +
 tcp_coalesce.go              | 436 +++++++++++++++++++++++++++++++++++
 tcp_coalesce_test.go         | 356 ++++++++++++++++++++++++++++
 udp/conn.go                  |   9 +-
 udp/udp_darwin.go            |   3 +-
 udp/udp_generic.go           |   3 +-
 udp/udp_linux.go             |  14 +-
 udp/udp_rio_windows.go       |   3 +-
 udp/udp_tester.go            |   3 +-
 13 files changed, 928 insertions(+), 13 deletions(-)
 create mode 100644 tcp_coalesce.go
 create mode 100644 tcp_coalesce_test.go

diff --git a/interface.go b/interface.go
index f4a66c6f..883afebd 100644
--- a/interface.go
+++ b/interface.go
@@ -86,7 +86,11 @@ type Interface struct {
 	writers []udp.Conn
 	readers []overlay.Queue
 
-	wg sync.WaitGroup
+	// tunCoalescers is one tcpCoalescer per tun queue, wrapping readers[i].
+	// decryptToTun sends plaintext into the coalescer; listenOut calls its
+	// Flush at the end of each UDP recvmmsg batch.
+	tunCoalescers []*tcpCoalescer
+	wg            sync.WaitGroup
 
 	// fatalErr holds the first unexpected reader error that caused shutdown.
 	// nil means "no fatal error" (yet)
@@ -184,6 +188,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]overlay.Queue, c.routines),
+		tunCoalescers:      make([]*tcpCoalescer, c.routines),
 		myVpnNetworks:      cs.myVpnNetworks,
 		myVpnNetworksTable: cs.myVpnNetworksTable,
 		myVpnAddrs:         cs.myVpnAddrs,
@@ -247,6 +252,7 @@ func (f *Interface) activate() error {
 			}
 		}
 		f.readers[i] = reader
+		f.tunCoalescers[i] = newTCPCoalescer(reader)
 	}
 
 	f.wg.Add(1) // for us to wait on Close() to return
@@ -308,8 +314,13 @@ func (f *Interface) listenOut(i int) {
 	fwPacket := &firewall.Packet{}
 	nb := make([]byte, 12, 12)
 
+	coalescer := f.tunCoalescers[i]
 	err := li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
 		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+	}, func() {
+		if err := coalescer.Flush(); err != nil {
+			f.l.WithError(err).Error("Failed to flush tun coalescer")
+		}
 	})
 
 	if err != nil && !f.closed.Load() {
diff --git a/outside.go b/outside.go
index eba9d887..41fa5dd4 100644
--- a/outside.go
+++ b/outside.go
@@ -535,7 +535,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	}
 
 	f.connectionManager.In(hostinfo)
-	_, err = f.readers[q].Write(out)
+	err = f.tunCoalescers[q].Add(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
diff --git a/overlay/device.go b/overlay/device.go
index 420fa8d2..dc58bcfe 100644
--- a/overlay/device.go
+++ b/overlay/device.go
@@ -30,3 +30,22 @@ type Device interface {
 	SupportsMultiqueue() bool
 	NewMultiQueueReader() (Queue, error)
 }
+
+// GSOWriter is implemented by Queues that can write a TCP TSO superpacket as
+// a single virtio_net_hdr + payload writev, letting the kernel segment on
+// egress. Callers type-assert on it; backends that don't support GSO return
+// false from GSOSupported and all coalescing logic is skipped.
+//
+// pkt must contain the IPv4/IPv6 + TCP header plus the concatenated
+// coalesced payload. hdrLen is the total L3+L4 header length (where the
+// payload starts). csumStart is the byte offset where the TCP header
+// begins (= IP header length).
+// gsoSize is the MSS — every segment except possibly the last must be
+// exactly this many payload bytes. isV6 selects GSO_TCPV4 vs GSO_TCPV6.
+//
+// pkt's TCP checksum field must already hold the pseudo-header partial
+// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
+type GSOWriter interface {
+	WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error
+	GSOSupported() bool
+}
diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go
index 271fa7be..41a8a1a0 100644
--- a/overlay/tun_linux.go
+++ b/overlay/tun_linux.go
@@ -48,6 +48,12 @@ type tunFile struct {
 	pending    [][]byte // segments waiting to be drained by Read
 	pendingIdx int
 	writeIovs  [2]unix.Iovec // preallocated iovecs for vnetHdr writes; iovs[0] is fixed to zeroVnetHdr
+
+	// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
+	// by WriteGSO. Separate from zeroVnetHdr so a concurrent non-GSO Write on
+	// another queue never observes a half-written header.
+	gsoHdrBuf [virtioNetHdrLen]byte
+	gsoIovs   [2]unix.Iovec
 }
 
 // zeroVnetHdr is the 10-byte virtio_net_hdr we prepend to every TUN write when
@@ -78,6 +84,8 @@ func (r *tunFile) newFriend(fd int) (*tunFile, error) {
 		out.segBuf = make([]byte, tunSegBufCap)
 		out.writeIovs[0].Base = &zeroVnetHdr[0]
 		out.writeIovs[0].SetLen(virtioNetHdrLen)
+		out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
+		out.gsoIovs[0].SetLen(virtioNetHdrLen)
 	}
 	return out, nil
 }
@@ -111,6 +119,8 @@ func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
 		out.segBuf = make([]byte, tunSegBufCap)
 		out.writeIovs[0].Base = &zeroVnetHdr[0]
 		out.writeIovs[0].SetLen(virtioNetHdrLen)
+		out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
+		out.gsoIovs[0].SetLen(virtioNetHdrLen)
 	}
 
 	return out, nil
@@ -331,6 +341,64 @@ func (r *tunFile) Write(buf []byte) (int, error) {
 	}
 }
 
+// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
+// can accept WriteGSO. When false, callers should fall back to per-segment
+// Write calls.
+func (r *tunFile) GSOSupported() bool { return r.vnetHdr }
+
+// WriteGSO emits pkt as a single TCP TSO superpacket via writev. pkt must
+// contain a full IPv4/IPv6 + TCP header prefix followed by the concatenated
+// coalesced payload. The TCP checksum field must already hold the
+// pseudo-header partial (NEEDS_CSUM semantics). gsoSize is the MSS; every
+// segment except the last must be exactly that many payload bytes.
+func (r *tunFile) WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error {
+	if !r.vnetHdr {
+		return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR")
+	}
+	if len(pkt) == 0 {
+		return nil
+	}
+	hdr := virtioNetHdr{
+		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+		HdrLen:     hdrLen,
+		GSOSize:    gsoSize,
+		CsumStart:  csumStart,
+		CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
+	}
+	if isV6 {
+		hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+	} else {
+		hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+	}
+	hdr.encode(r.gsoHdrBuf[:])
+
+	r.gsoIovs[1].Base = &pkt[0]
+	r.gsoIovs[1].SetLen(len(pkt))
+	iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
+	for {
+		n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
+		if errno == 0 {
+			runtime.KeepAlive(pkt)
+			// TUN writes are all-or-nothing; anything less than the full
+			// header+payload is a short write.
+			if int(n) != virtioNetHdrLen+len(pkt) {
+				return io.ErrShortWrite
+			}
+			return nil
+		}
+		if errno == unix.EAGAIN {
+			runtime.KeepAlive(pkt)
+			if err := r.blockOnWrite(); err != nil {
+				return err
+			}
+			continue
+		}
+		if errno == unix.EINTR {
+			continue
+		}
+		runtime.KeepAlive(pkt)
+		return errno
+	}
+}
+
 func (r *tunFile) wakeForShutdown() error {
 	var buf [8]byte
 	binary.NativeEndian.PutUint64(buf[:], 1)
diff --git a/overlay/tun_linux_offload.go b/overlay/tun_linux_offload.go
index f4ea672a..5e86833e 100644
--- a/overlay/tun_linux_offload.go
+++ b/overlay/tun_linux_offload.go
@@ -54,6 +54,18 @@ func (h *virtioNetHdr) decode(b []byte) {
 	h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
 }
 
+// encode is the inverse of decode: writes the virtio_net_hdr fields into b
+// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
+// on egress.
+func (h *virtioNetHdr) encode(b []byte) {
+	b[0] = h.Flags
+	b[1] = h.GSOType
+	binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
+	binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
+	binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
+	binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
+}
+
 // segmentInto splits a TUN-side packet described by hdr into one or more
 // IP packets, each appended to *out as a slice of scratch. scratch must be
 // sized to hold every segment (including replicated headers).
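For reference, this is the header the GSO path produces for a concrete case: three coalesced 1200-byte IPv4 segments behind a 20-byte IP header and a 20-byte TCP header. A sketch only, reusing the virtioNetHdr type and unix constants from the hunks above; it is not part of the patch:

	hdr := virtioNetHdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		HdrLen:     40,   // L3+L4 header prefix; payload starts here
		GSOSize:    1200, // MSS: the kernel cuts the payload every 1200 bytes
		CsumStart:  20,   // TCP header offset = IP header length
		CsumOffset: 16,   // checksum field offset within the TCP header
	}
	var scratch [virtioNetHdrLen]byte
	hdr.encode(scratch[:])
	// writev(fd, [scratch, superpacket]) then leaves the kernel to emit three
	// 1240-byte IP packets, fixing up lengths, seq numbers, and checksums for
	// each segment.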
diff --git a/tcp_coalesce.go b/tcp_coalesce.go
new file mode 100644
index 00000000..b16d6f2f
--- /dev/null
+++ b/tcp_coalesce.go
@@ -0,0 +1,436 @@
+package nebula
+
+import (
+	"encoding/binary"
+	"io"
+
+	"github.com/slackhq/nebula/overlay"
+)
+
+// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of
+// reaching for unix.IPPROTO_TCP because golang.org/x/sys/unix doesn't build
+// on Windows, which would break cross-compiles even though this file runs
+// unchanged on every platform.
+const ipProtoTCP = 6
+
+// tcpCoalesceBufSize bounds the largest coalesced superpacket we will buffer.
+// Linux caps sk_gso_max_size around 64KiB; 65535 bytes covers IP hdr + TCP
+// hdr + up to ~65KB of payload, which is the most the kernel's TSO can
+// segment in one shot.
+const tcpCoalesceBufSize = 65535
+
+// tcpCoalesceMaxSegs caps how many segments we are willing to coalesce into
+// a single superpacket regardless of byte budget. The kernel accepts at most
+// 64 segments for UDP GSO and 128 for many TSO engines; capping at 64 stays
+// inside both limits and keeps latency bounded.
+const tcpCoalesceMaxSegs = 64
+
+// tcpCoalescer accumulates adjacent in-flow TCP data segments into a single
+// TSO superpacket and emits them via overlay.GSOWriter in one writev. When
+// a packet fails admission or fails to extend the pending flow, the
+// pending superpacket is flushed and the non-matching packet is written
+// through as-is. Owns no locks — one coalescer per TUN write queue.
+type tcpCoalescer struct {
+	plainW io.Writer
+	gsoW   overlay.GSOWriter // nil when the queue doesn't support TSO
+
+	buf      []byte
+	bufLen   int  // valid bytes in buf — hdrLen plus accumulated payload
+	active   bool // a seed packet is present
+	numSeg   int
+	gsoSize  int // payload length of each segment (= MSS of the seed)
+	isV6     bool
+	ipHdrLen int
+	hdrLen   int    // ipHdrLen + tcpHdrLen, the offset where payload starts
+	nextSeq  uint32 // expected TCP seq of the next packet to coalesce
+	// psh indicates the last-accepted segment had PSH set. We accept a PSH
+	// packet as the final segment but reject any further Adds after that.
+	psh bool
+}
+
+func newTCPCoalescer(w io.Writer) *tcpCoalescer {
+	c := &tcpCoalescer{plainW: w, buf: make([]byte, tcpCoalesceBufSize)}
+	if gw, ok := w.(overlay.GSOWriter); ok && gw.GSOSupported() {
+		c.gsoW = gw
+	}
+	return c
+}
+
+// parsedTCP holds the byte offsets / values we extract from one admission
+// check so Add and canAppend don't re-parse the same header twice.
+type parsedTCP struct {
+	isV6      bool
+	ipHdrLen  int
+	tcpHdrLen int
+	hdrLen    int // ipHdrLen + tcpHdrLen
+	payLen    int
+	seq       uint32
+	flags     byte
+}
+
+// parseCoalesceable decides whether pkt is eligible for TCP coalescing. It
+// accepts IPv4 (no options, no fragmentation) and IPv6 (no extension
+// headers) carrying a TCP segment with flags in {ACK, ACK|PSH} and a
+// non-empty payload. On success it returns the parsed offsets.
+func parseCoalesceable(pkt []byte) (parsedTCP, bool) {
+	var p parsedTCP
+	if len(pkt) < 20 {
+		return p, false
+	}
+	v := pkt[0] >> 4
+	switch v {
+	case 4:
+		ihl := int(pkt[0]&0x0f) * 4
+		if ihl != 20 {
+			return p, false // reject IP options
+		}
+		if pkt[9] != ipProtoTCP {
+			return p, false
+		}
+		// Fragment check: MF=0 and frag offset=0. Accept DF=1 or DF=0 —
+		// just reject any actual fragmentation.
+		fragField := binary.BigEndian.Uint16(pkt[6:8])
+		if fragField&0x3fff != 0 {
+			return p, false
+		}
+		totalLen := int(binary.BigEndian.Uint16(pkt[2:4]))
+		if totalLen > len(pkt) || totalLen < ihl {
+			return p, false
+		}
+		p.isV6 = false
+		p.ipHdrLen = ihl
+		pkt = pkt[:totalLen]
+	case 6:
+		if len(pkt) < 40 {
+			return p, false
+		}
+		if pkt[6] != ipProtoTCP {
+			return p, false // reject ext headers
+		}
+		payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
+		if 40+payloadLen > len(pkt) {
+			return p, false
+		}
+		p.isV6 = true
+		p.ipHdrLen = 40
+		pkt = pkt[:40+payloadLen]
+	default:
+		return p, false
+	}
+
+	if len(pkt) < p.ipHdrLen+20 {
+		return p, false
+	}
+	tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4
+	if tcpOff < 20 || tcpOff > 60 {
+		return p, false
+	}
+	if len(pkt) < p.ipHdrLen+tcpOff {
+		return p, false
+	}
+	flags := pkt[p.ipHdrLen+13]
+	// Allow only ACK and ACK|PSH. In particular: no SYN/FIN/RST/URG/CWR/ECE.
+	const ack = 0x10
+	const psh = 0x08
+	if flags&^(ack|psh) != 0 || flags&ack == 0 {
+		return p, false
+	}
+	p.tcpHdrLen = tcpOff
+	p.hdrLen = p.ipHdrLen + tcpOff
+	p.payLen = len(pkt) - p.hdrLen
+	if p.payLen <= 0 {
+		return p, false
+	}
+	p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8])
+	p.flags = flags
+	return p, true
+}
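To make the admission rules concrete, here is a small sketch (not part of the patch; buildTCPv4 is the test helper defined in tcp_coalesce_test.go further down):

	// A data-bearing plain ACK is eligible...
	pkt := buildTCPv4(1000, 0x10 /* ACK */, []byte("hello"))
	if p, ok := parseCoalesceable(pkt); ok {
		_ = p // p.ipHdrLen == 20, p.hdrLen == 40, p.payLen == 5, p.seq == 1000
	}
	// ...but any flag outside {ACK, ACK|PSH} is rejected up front.
	if _, ok := parseCoalesceable(buildTCPv4(1000, 0x12 /* SYN|ACK */, []byte("x"))); ok {
		panic("SYN|ACK must never be admitted")
	}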
+
+// Add takes a plaintext inbound packet destined for the tun. If GSO is
+// unavailable or the packet isn't coalesceable, Add falls through to a
+// plain Write on the underlying queue (flushing any pending superpacket
+// first).
+func (c *tcpCoalescer) Add(pkt []byte) error {
+	if c.gsoW == nil {
+		_, err := c.plainW.Write(pkt)
+		return err
+	}
+
+	info, ok := parseCoalesceable(pkt)
+	if !ok {
+		if c.active {
+			if err := c.flushPending(); err != nil {
+				return err
+			}
+		}
+		_, err := c.plainW.Write(pkt)
+		return err
+	}
+
+	if c.active {
+		if c.canAppend(pkt, info) {
+			c.appendPayload(pkt, info)
+			if info.flags&0x08 != 0 { // PSH: close the chain after this segment
+				c.psh = true
+			}
+			return nil
+		}
+		if err := c.flushPending(); err != nil {
+			return err
+		}
+	}
+	return c.seed(pkt, info)
+}
+
+// Flush emits any pending superpacket. Called by the UDP read loop at
+// recvmmsg batch boundaries — "no more packets coming right now".
+func (c *tcpCoalescer) Flush() error {
+	if !c.active {
+		return nil
+	}
+	return c.flushPending()
+}
+
+func (c *tcpCoalescer) reset() {
+	c.active = false
+	c.bufLen = 0
+	c.numSeg = 0
+	c.gsoSize = 0
+	c.hdrLen = 0
+	c.ipHdrLen = 0
+	c.nextSeq = 0
+	c.psh = false
+}
+
+func (c *tcpCoalescer) seed(pkt []byte, info parsedTCP) error {
+	if info.hdrLen+info.payLen > len(c.buf) {
+		// Oversize single packet — flush (already done above) and passthrough.
+		_, err := c.plainW.Write(pkt)
+		return err
+	}
+	copy(c.buf, pkt[:info.hdrLen+info.payLen])
+	c.active = true
+	c.bufLen = info.hdrLen + info.payLen
+	c.numSeg = 1
+	c.gsoSize = info.payLen
+	c.isV6 = info.isV6
+	c.ipHdrLen = info.ipHdrLen
+	c.hdrLen = info.hdrLen
+	c.nextSeq = info.seq + uint32(info.payLen)
+	c.psh = info.flags&0x08 != 0
+	return nil
+}
+
+// canAppend reports whether info's packet extends the current seed: same
+// flow, adjacent seq, payload size rule, and no-PSH-mid-chain.
+func (c *tcpCoalescer) canAppend(pkt []byte, info parsedTCP) bool {
+	if c.psh {
+		return false // we already accepted a PSH — chain is closed
+	}
+	if info.isV6 != c.isV6 {
+		return false
+	}
+	if info.hdrLen != c.hdrLen {
+		return false
+	}
+	if info.seq != c.nextSeq {
+		return false
+	}
+	if c.numSeg >= tcpCoalesceMaxSegs {
+		return false
+	}
+	if c.bufLen+info.payLen > len(c.buf) {
+		return false
+	}
+	// Every mid-chain segment must be exactly gsoSize. A larger payload
+	// can't belong to this chain; a shorter one may — it becomes the final
+	// segment (appendPayload then closes the chain so nothing further is
+	// added).
+	if info.payLen > c.gsoSize {
+		return false
+	}
+	// Compare the stable parts of the header.
+	return headersMatch(c.buf[:c.hdrLen], pkt[:info.hdrLen], c.isV6, c.ipHdrLen)
+}
+
+func (c *tcpCoalescer) appendPayload(pkt []byte, info parsedTCP) {
+	copy(c.buf[c.bufLen:], pkt[info.hdrLen:info.hdrLen+info.payLen])
+	c.bufLen += info.payLen
+	c.numSeg++
+	c.nextSeq = info.seq + uint32(info.payLen)
+	// If this was a sub-gsoSize last segment, mark chain as closed.
+	if info.payLen < c.gsoSize {
+		c.psh = true
+	}
+}
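Taken together, the intended lifecycle within one UDP batch looks like this sketch (q stands for any GSO-capable overlay queue, seg1 through seg3 for three in-order 1200-byte ACK segments of a single flow; all names here are illustrative):

	c := newTCPCoalescer(q)
	_ = c.Add(seg1) // seeds: numSeg=1, gsoSize=1200, nextSeq=seq1+1200
	_ = c.Add(seg2) // appends: numSeg=2, nextSeq advances by 1200
	_ = c.Add(seg3) // appends: numSeg=3
	_ = c.Flush()   // one WriteGSO call carrying a 40+3600-byte superpacket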
+
+// headersMatch compares two IP+TCP header prefixes for byte-for-byte
+// equality on every field that must be identical across coalesced
+// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out.
+func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	if isV6 {
+		// IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop,
+		// [8:40] = src+dst. Skip [4:6] payload length.
+		if !bytesEq(a[0:4], b[0:4]) {
+			return false
+		}
+		if !bytesEq(a[6:40], b[6:40]) {
+			return false
+		}
+	} else {
+		// IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto,
+		// [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum.
+		if !bytesEq(a[0:2], b[0:2]) {
+			return false
+		}
+		if !bytesEq(a[6:10], b[6:10]) {
+			return false
+		}
+		if !bytesEq(a[12:20], b[12:20]) {
+			return false
+		}
+	}
+	// TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window,
+	// [18:tcpHdrLen] options (incl. urgent).
+	tcp := ipHdrLen
+	if !bytesEq(a[tcp:tcp+4], b[tcp:tcp+4]) {
+		return false
+	}
+	if !bytesEq(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) {
+		return false
+	}
+	if !bytesEq(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) {
+		return false
+	}
+	if !bytesEq(a[tcp+18:], b[tcp+18:]) {
+		return false
+	}
+	return true
+}
+
+func bytesEq(a, b []byte) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+// flushPending emits the pending superpacket: a single segment goes out as a
+// plain Write, everything else as one WriteGSO.
+func (c *tcpCoalescer) flushPending() error {
+	// Guarantee the coalescer is empty on exit regardless of how we leave.
+	defer c.reset()
+
+	if c.numSeg <= 1 {
+		_, err := c.plainW.Write(c.buf[:c.bufLen])
+		return err
+	}
+
+	total := c.bufLen
+	l4Len := total - c.ipHdrLen
+
+	// Fix IP header length field.
+	if c.isV6 {
+		if l4Len > 0xffff {
+			// Shouldn't happen given buffer size, but guard against it.
+			return c.flushOversize()
+		}
+		binary.BigEndian.PutUint16(c.buf[4:6], uint16(l4Len))
+	} else {
+		if total > 0xffff {
+			return c.flushOversize()
+		}
+		binary.BigEndian.PutUint16(c.buf[2:4], uint16(total))
+		// Recompute IPv4 header checksum.
+		c.buf[10] = 0
+		c.buf[11] = 0
+		binary.BigEndian.PutUint16(c.buf[10:12], ipv4HdrChecksum(c.buf[:c.ipHdrLen]))
+	}
+
+	// Write the virtio NEEDS_CSUM pseudo-header partial into the TCP csum field.
+	var psum uint32
+	if c.isV6 {
+		psum = pseudoSumIPv6(c.buf[8:24], c.buf[24:40], ipProtoTCP, l4Len)
+	} else {
+		psum = pseudoSumIPv4(c.buf[12:16], c.buf[16:20], ipProtoTCP, l4Len)
+	}
+	tcsum := c.ipHdrLen + 16
+	binary.BigEndian.PutUint16(c.buf[tcsum:tcsum+2], foldOnceNoInvert(psum))
+
+	return c.gsoW.WriteGSO(c.buf[:total], uint16(c.gsoSize), c.isV6, uint16(c.hdrLen), uint16(c.ipHdrLen))
+}
+
+// flushOversize is a defensive fallback used if the coalesced superpacket
+// somehow exceeds a 16-bit length field. It writes the buffer as-is through
+// the plain writer (the kernel will reject it, but that's a visible error
+// rather than silent corruption).
+func (c *tcpCoalescer) flushOversize() error {
+	_, err := c.plainW.Write(c.buf[:c.bufLen])
+	return err
+}
+
+// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must
+// already have its checksum field zeroed) and returns the folded/inverted
+// 16-bit value to store.
+func ipv4HdrChecksum(hdr []byte) uint16 {
+	var sum uint32
+	for i := 0; i+1 < len(hdr); i += 2 {
+		sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2]))
+	}
+	if len(hdr)%2 == 1 {
+		sum += uint32(hdr[len(hdr)-1]) << 8
+	}
+	for sum>>16 != 0 {
+		sum = (sum & 0xffff) + (sum >> 16)
+	}
+	return ^uint16(sum)
+}
+
+// pseudoSumIPv4 / pseudoSumIPv6 build the TCP pseudo-header partial sum
+// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator
+// before folding.
+func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 {
+	var sum uint32
+	sum += uint32(binary.BigEndian.Uint16(src[0:2]))
+	sum += uint32(binary.BigEndian.Uint16(src[2:4]))
+	sum += uint32(binary.BigEndian.Uint16(dst[0:2]))
+	sum += uint32(binary.BigEndian.Uint16(dst[2:4]))
+	sum += uint32(proto)
+	sum += uint32(l4Len)
+	return sum
+}
+
+func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 {
+	var sum uint32
+	for i := 0; i < 16; i += 2 {
+		sum += uint32(binary.BigEndian.Uint16(src[i : i+2]))
+		sum += uint32(binary.BigEndian.Uint16(dst[i : i+2]))
+	}
+	// The IPv6 pseudo-header carries a 32-bit upper-layer length, so sum
+	// both halves.
+	sum += uint32(l4Len >> 16)
+	sum += uint32(l4Len & 0xffff)
+	sum += uint32(proto)
+	return sum
+}
+
+// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it
+// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in
+// the L4 checksum field — the kernel will add the payload sum and invert.
+func foldOnceNoInvert(sum uint32) uint16 {
+	for sum>>16 != 0 {
+		sum = (sum & 0xffff) + (sum >> 16)
+	}
+	return uint16(sum)
+}
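A worked example of the seed value these helpers produce (arithmetic only, not part of the patch): for 10.0.0.1 -> 10.0.0.2 carrying 3600 bytes of TCP header plus payload:

	sum := pseudoSumIPv4([]byte{10, 0, 0, 1}, []byte{10, 0, 0, 2}, ipProtoTCP, 3600)
	// 0x0a00 + 0x0001 + 0x0a00 + 0x0002 + 0x0006 + 0x0e10 = 0x2219; no carry,
	// so folding is a no-op here.
	seed := foldOnceNoInvert(sum) // 0x2219, stored in the TCP checksum field
	_ = seed
	// The kernel later sums everything from CsumStart onward, adds it to this
	// seed, and writes the inverted result back into the same field.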
diff --git a/tcp_coalesce_test.go b/tcp_coalesce_test.go
new file mode 100644
index 00000000..c70d28e3
--- /dev/null
+++ b/tcp_coalesce_test.go
@@ -0,0 +1,356 @@
+package nebula
+
+import (
+	"encoding/binary"
+	"testing"
+)
+
+// A minimal stub writer that records each plain Write and each WriteGSO
+// call without touching a real TUN fd.
+type fakeTunWriter struct {
+	gsoEnabled bool
+	writes     [][]byte
+	gsoWrites  []fakeGSOWrite
+}
+
+type fakeGSOWrite struct {
+	pkt       []byte
+	gsoSize   uint16
+	isV6      bool
+	hdrLen    uint16
+	csumStart uint16
+}
+
+func (w *fakeTunWriter) Write(p []byte) (int, error) {
+	buf := make([]byte, len(p))
+	copy(buf, p)
+	w.writes = append(w.writes, buf)
+	return len(p), nil
+}
+
+func (w *fakeTunWriter) WriteGSO(pkt []byte, gsoSize uint16, isV6 bool, hdrLen, csumStart uint16) error {
+	buf := make([]byte, len(pkt))
+	copy(buf, pkt)
+	w.gsoWrites = append(w.gsoWrites, fakeGSOWrite{pkt: buf, gsoSize: gsoSize, isV6: isV6, hdrLen: hdrLen, csumStart: csumStart})
+	return nil
+}
+
+func (w *fakeTunWriter) GSOSupported() bool { return w.gsoEnabled }
+
+// buildTCPv4 constructs a minimal IPv4+TCP packet with the given payload,
+// seq, and flags. Assumes no IP options and a 20-byte TCP header.
+func buildTCPv4(seq uint32, flags byte, payload []byte) []byte {
+	const ipHdrLen = 20
+	const tcpHdrLen = 20
+	total := ipHdrLen + tcpHdrLen + len(payload)
+	pkt := make([]byte, total)
+
+	// IPv4 header.
+	pkt[0] = 0x45 // version 4, IHL 5
+	pkt[1] = 0x00 // TOS
+	binary.BigEndian.PutUint16(pkt[2:4], uint16(total))
+	binary.BigEndian.PutUint16(pkt[4:6], 0)      // id
+	binary.BigEndian.PutUint16(pkt[6:8], 0x4000) // DF
+	pkt[8] = 64 // TTL
+	pkt[9] = ipProtoTCP
+	// csum left zero — coalescer recomputes on emit.
+	copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
+	copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
+
+	// TCP header.
+	binary.BigEndian.PutUint16(pkt[20:22], 1000) // sport
+	binary.BigEndian.PutUint16(pkt[22:24], 2000) // dport
+	binary.BigEndian.PutUint32(pkt[24:28], seq)
+	binary.BigEndian.PutUint32(pkt[28:32], 12345) // ack
+	pkt[32] = 0x50 // data offset = 5 << 4
+	pkt[33] = flags
+	binary.BigEndian.PutUint16(pkt[34:36], 0xffff) // window
+	// tcp csum zero
+	// urgent zero
+
+	copy(pkt[40:], payload)
+	return pkt
+}
+
+const (
+	tcpAck    = 0x10
+	tcpPsh    = 0x08
+	tcpSyn    = 0x02
+	tcpFin    = 0x01
+	tcpAckPsh = tcpAck | tcpPsh
+)
+
+func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: false}
+	c := newTCPCoalescer(w)
+	pkt := buildTCPv4(1000, tcpAck, []byte("hello"))
+	if err := c.Add(pkt); err != nil {
+		t.Fatal(err)
+	}
+	if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
+		t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
+	}
+}
+
+func TestCoalescerNonTCPPassthrough(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	// ICMP packet: proto=1.
+	pkt := make([]byte, 28)
+	pkt[0] = 0x45
+	binary.BigEndian.PutUint16(pkt[2:4], 28)
+	pkt[9] = 1
+	copy(pkt[12:16], []byte{10, 0, 0, 1})
+	copy(pkt[16:20], []byte{10, 0, 0, 2})
+	if err := c.Add(pkt); err != nil {
+		t.Fatal(err)
+	}
+	if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
+		t.Fatalf("ICMP should pass through unchanged")
+	}
+}
+
+func TestCoalescerSeedThenFlushAlone(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000))
+	if err := c.Add(pkt); err != nil {
+		t.Fatal(err)
+	}
+	// No flush yet — still pending.
+	if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
+		t.Fatalf("unexpected output before flush")
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	// Single segment — should use plain write, not gso.
+	if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
+		t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
+	}
+}
+
+func TestCoalescerCoalescesAdjacentACKs(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pay := make([]byte, 1200)
+	if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Add(buildTCPv4(2200, tcpAck, pay)); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	if len(w.gsoWrites) != 1 {
+		t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes))
+	}
+	g := w.gsoWrites[0]
+	if g.gsoSize != 1200 {
+		t.Errorf("gsoSize=%d want 1200", g.gsoSize)
+	}
+	if g.hdrLen != 40 {
+		t.Errorf("hdrLen=%d want 40", g.hdrLen)
+	}
+	if g.csumStart != 20 {
+		t.Errorf("csumStart=%d want 20", g.csumStart)
+	}
+	if len(g.pkt) != 40+3*1200 {
+		t.Errorf("superpacket len=%d want %d", len(g.pkt), 40+3*1200)
+	}
+	// IP total length should reflect superpacket.
+	if tot := binary.BigEndian.Uint16(g.pkt[2:4]); int(tot) != len(g.pkt) {
+		t.Errorf("ip total_length=%d want %d", tot, len(g.pkt))
+	}
+}
+
+func TestCoalescerRejectsSeqGap(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pay := make([]byte, 1200)
+	if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
+		t.Fatal(err)
+	}
+	// seq should be 2200; use 3000 to simulate a gap.
+ if err := c.Add(buildTCPv4(3000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // First packet should have been flushed as a plain write (single seg), + // then second packet seeded and flushed likewise. + if len(w.writes) != 2 || len(w.gsoWrites) != 0 { + t.Fatalf("seq gap: want 2 plain writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFlagMismatch(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // SYN flag — not admissible at all. Should flush first packet + plain-write second. + syn := buildTCPv4(2200, tcpSyn|tcpAck, pay) + if err := c.Add(syn); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 2 || len(w.gsoWrites) != 0 { + t.Fatalf("flag mismatch: want 2 plain writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFIN(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x")) + if err := c.Add(fin); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("FIN should be passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerShortLastSegmentClosesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + full := make([]byte, 1200) + half := make([]byte, 500) + if err := c.Add(buildTCPv4(1000, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Add(buildTCPv4(2200, tcpAck, half)); err != nil { + t.Fatal(err) + } + // Next full-size would have to start at 2700 but chain is closed — + // should flush + seed. + if err := c.Add(buildTCPv4(2700, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Expect: one gso write (first two coalesced) + one plain write (the + // third, flushed alone). + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write got %d", len(w.gsoWrites)) + } + if len(w.writes) != 1 { + t.Fatalf("want 1 plain write got %d", len(w.writes)) + } + if w.gsoWrites[0].gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", w.gsoWrites[0].gsoSize) + } + if got, want := len(w.gsoWrites[0].pkt), 40+1200+500; got != want { + t.Errorf("super len=%d want %d", got, want) + } +} + +func TestCoalescerPSHFinalizesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := newTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Last full-size segment with PSH — admitted but chain is now closed. + if err := c.Add(buildTCPv4(2200, tcpAckPsh, pay)); err != nil { + t.Fatal(err) + } + // Further full-size would not coalesce. 
+	if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	if len(w.gsoWrites) != 1 {
+		t.Fatalf("want 1 gso write got %d", len(w.gsoWrites))
+	}
+	if len(w.writes) != 1 {
+		t.Fatalf("want 1 plain write got %d", len(w.writes))
+	}
+}
+
+func TestCoalescerRejectsDifferentFlow(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pay := make([]byte, 1200)
+	p1 := buildTCPv4(1000, tcpAck, pay)
+	p2 := buildTCPv4(2200, tcpAck, pay)
+	// Mutate p2's source port to break flow match.
+	binary.BigEndian.PutUint16(p2[20:22], 9999)
+	if err := c.Add(p1); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Add(p2); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	// Both flushed as plain writes.
+	if len(w.writes) != 2 || len(w.gsoWrites) != 0 {
+		t.Fatalf("diff flow: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
+	}
+}
+
+func TestCoalescerRejectsIPOptions(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pay := make([]byte, 500)
+	pkt := buildTCPv4(1000, tcpAck, pay)
+	// Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add
+	// bytes — parser should bail before it matters.
+	pkt[0] = 0x46
+	if err := c.Add(pkt); err != nil {
+		t.Fatal(err)
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
+		t.Fatalf("IP options should passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
+	}
+}
+
+func TestCoalescerCapBySegments(t *testing.T) {
+	w := &fakeTunWriter{gsoEnabled: true}
+	c := newTCPCoalescer(w)
+	pay := make([]byte, 512) // small so we can fit many before byte cap
+	seq := uint32(1000)
+	for i := 0; i < tcpCoalesceMaxSegs+5; i++ {
+		if err := c.Add(buildTCPv4(seq, tcpAck, pay)); err != nil {
+			t.Fatal(err)
+		}
+		seq += uint32(len(pay))
+	}
+	if err := c.Flush(); err != nil {
+		t.Fatal(err)
+	}
+	// The first tcpCoalesceMaxSegs segments fill one superpacket; the next
+	// Add fails canAppend on the segment cap, flushes that super as a gso
+	// write, and seeds a second super that the remaining adds join. Assert
+	// only the invariant we care about: no super exceeds the cap.
+	for _, g := range w.gsoWrites {
+		segs := (len(g.pkt) - int(g.hdrLen)) / int(g.gsoSize)
+		if rem := (len(g.pkt) - int(g.hdrLen)) % int(g.gsoSize); rem != 0 {
+			segs++
+		}
+		if segs > tcpCoalesceMaxSegs {
+			t.Fatalf("super exceeded seg cap: %d > %d", segs, tcpCoalesceMaxSegs)
+		}
+	}
+}
diff --git a/udp/conn.go b/udp/conn.go
index f66deba9..c80cad8b 100644
--- a/udp/conn.go
+++ b/udp/conn.go
@@ -22,7 +22,12 @@ type EncReader func(
 type Conn interface {
 	Rebind() error
 	LocalAddr() (netip.AddrPort, error)
-	ListenOut(r EncReader) error
+	// ListenOut invokes r for each received packet. On batch-capable
+	// backends (recvmmsg), flush is called after each batch is fully
+	// delivered — callers use it to flush per-batch accumulators such as
+	// TUN write coalescers. Single-packet backends call flush after each
+	// packet. flush must not be nil.
+	ListenOut(r EncReader, flush func()) error
 	WriteTo(b []byte, addr netip.AddrPort) error
 	// WriteBatch sends a contiguous batch of packets, each with its own
 	// destination. bufs and addrs must have the same length.
Linux uses @@ -53,7 +58,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader) error { +func (NoopConn) ListenOut(_ EncReader, _ func()) error { return nil } func (NoopConn) SupportsMultipleReaders() bool { diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 00c88203..e4bdd659 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -185,7 +185,7 @@ func NewUDPStatsEmitter(udpConns []Conn) func() { return func() {} } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -200,6 +200,7 @@ func (u *StdConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_generic.go b/udp/udp_generic.go index f29bbc1f..8e71fa18 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -91,7 +91,7 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader) error { +func (u *GenericConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) for { @@ -102,6 +102,7 @@ func (u *GenericConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ca9988c9..40fd0463 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -249,7 +249,7 @@ func recvmmsg(fd uintptr, msgs []rawMessage) (int, bool, error) { return int(n), true, nil } -func (u *StdConn) listenOutSingle(r EncReader) error { +func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { var err error var n int var from netip.AddrPort @@ -262,10 +262,11 @@ func (u *StdConn) listenOutSingle(r EncReader) error { } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) r(from, buffer[:n]) + flush() } } -func (u *StdConn) listenOutBatch(r EncReader) error { +func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var ip netip.Addr var n int var operr error @@ -297,14 +298,17 @@ func (u *StdConn) listenOutBatch(r EncReader) error { } r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } + // End-of-batch: let callers (e.g. TUN write coalescer) flush any + // state they accumulated across this batch. 
+ flush() } } -func (u *StdConn) ListenOut(r EncReader) error { +func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { - return u.listenOutSingle(r) + return u.listenOutSingle(r, flush) } else { - return u.listenOutBatch(r) + return u.listenOutBatch(r, flush) } } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index fc15acbb..1ee85165 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -140,7 +140,7 @@ func (u *RIOConn) bind(l *logrus.Logger, sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader) error { +func (u *RIOConn) ListenOut(r EncReader, flush func()) error { buffer := make([]byte, MTU) var lastRecvErr time.Time @@ -162,6 +162,7 @@ func (u *RIOConn) ListenOut(r EncReader) error { } r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + flush() } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 522f95f7..e7ef5c05 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -127,13 +127,14 @@ func (u *TesterConn) WriteSegmented(bufs [][]byte, addr netip.AddrPort, _ int) e func (u *TesterConn) SupportsGSO() bool { return false } -func (u *TesterConn) ListenOut(r EncReader) error { +func (u *TesterConn) ListenOut(r EncReader, flush func()) error { for { p, ok := <-u.RxPackets if !ok { return os.ErrClosed } r(p.From, p.Data) + flush() } }
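For backend authors, the new ListenOut contract is easy to satisfy even without batching. Below is a sketch of a hypothetical single-packet backend; exampleConn and its pc field are invented for illustration, and the pattern simply mirrors udp_darwin.go and udp_generic.go above:

	type exampleConn struct{ pc *net.UDPConn } // hypothetical backend

	func (u *exampleConn) ListenOut(r EncReader, flush func()) error {
		buffer := make([]byte, MTU)
		for {
			n, rua, err := u.pc.ReadFromUDPAddrPort(buffer)
			if err != nil {
				return err
			}
			r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n])
			// No recvmmsg here, so each packet is its own "batch": flush
			// immediately so the coalescer never holds data across reads.
			flush()
		}
	}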