GSO/GRO offloads, with TCP+ECN and UDP support

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent f95857b4c3
commit 5d35351437
60 changed files with 6915 additions and 283 deletions

View File

@@ -14,20 +14,15 @@ type RxBatcher interface {
}
type TxBatcher interface {
// Next returns a zero-length slice with slotCap capacity over the next unused
// slot's backing bytes. The caller writes into the returned slice and then
// calls Commit with the final length and destination. Next returns nil when
// the batch is full.
Next() []byte
// Commit records the slot just returned by Next as a packet of length n
// destined for dst.
Commit(n int, dst netip.AddrPort)
// Reset clears committed slots; backing storage is retained for reuse.
Reset()
// Len returns the number of committed packets.
Len() int
// Cap returns the maximum number of slots in the batch.
Cap() int
// Get returns the buffers needed to send the batch
Get() ([][]byte, []netip.AddrPort)
// Reserve returns a length-sz slice from the batch's backing arena for the
// caller to fill; the filled slice is then handed to Commit.
Reserve(sz int) []byte
// Commit borrows pkt and records its destination plus the 2-bit
// IP-level ECN codepoint to set on the outer (carrier) header. The
// caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT)
// to leave the outer ECN field unset.
Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
// Flush emits every queued packet via the underlying batch writer in
// arrival order. Returns the first error observed. After Flush returns,
// borrowed payload slices may be recycled.
Flush() error
}
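// Illustrative usage sketch (not part of this change): the intended
// Reserve → Commit → Flush sequence against a TxBatcher. The function and
// variable names below (sendEncrypted, tx, ciphertexts, dst) are
// hypothetical placeholders, not symbols from this commit.
func sendEncrypted(tx TxBatcher, ciphertexts [][]byte, dst netip.AddrPort) error {
	for _, ct := range ciphertexts {
		buf := tx.Reserve(len(ct)) // borrow a slot from the batch arena
		copy(buf, ct)
		tx.Commit(buf, dst, 0) // 0 = Not-ECT: leave the outer ECN field unset
	}
	// Borrowed slices stay valid until Flush returns, then may be recycled.
	return tx.Flush()
}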

View File

@@ -0,0 +1,163 @@
package batch
import (
"bytes"
"encoding/binary"
)
// flowKey identifies a transport flow by {src, dst, sport, dport, family}.
// Comparable, so map lookups and linear scans over the slot list stay tight.
// Shared by the TCP and UDP coalescers; each coalescer keeps its own
// openSlots map, so a TCP and UDP flow on the same 5-tuple-without-proto
// never alias.
type flowKey struct {
src, dst [16]byte
sport, dport uint16
isV6 bool
}
// initialSlots is the starting capacity of the slot pool. One flow per
// packet is the worst case, so this matches a typical carrier-side
// recvmmsg batch on the encrypted UDP socket.
const initialSlots = 64
// parsedIP is the IP-level result of parseIPPrologue. The caller layers
// L4-specific parsing (TCP / UDP) on top.
type parsedIP struct {
fk flowKey
ipHdrLen int
// pkt is the original buffer trimmed to the IP-declared total length.
// Transport-layer parsers working past the IP header should slice into
// pkt rather than the unbounded original.
pkt []byte
}
// parseIPPrologue extracts the IP-level fields the coalescers care about:
// IHL/payload length, version, src/dst addresses, and the L4 protocol byte.
// Returns ok=false for malformed input, IPv4 with options or fragmentation,
// or IPv6 with extension headers (all rejected by both coalescers in
// identical ways before this refactor).
//
// On success, p.pkt is len-trimmed to the IP-declared length so callers
// don't have to repeat the trim. wantProto is the IANA protocol number to
// require (6 for TCP, 17 for UDP); ok=false for any other value.
func parseIPPrologue(pkt []byte, wantProto byte) (parsedIP, bool) {
var p parsedIP
if len(pkt) < 20 {
return p, false
}
v := pkt[0] >> 4
switch v {
case 4:
ihl := int(pkt[0]&0x0f) * 4
if ihl != 20 {
return p, false
}
if pkt[9] != wantProto {
return p, false
}
// Reject actual fragmentation (MF or non-zero frag offset).
if binary.BigEndian.Uint16(pkt[6:8])&0x3fff != 0 {
return p, false
}
totalLen := int(binary.BigEndian.Uint16(pkt[2:4]))
if totalLen > len(pkt) || totalLen < ihl {
return p, false
}
p.ipHdrLen = 20
p.fk.isV6 = false
copy(p.fk.src[:4], pkt[12:16])
copy(p.fk.dst[:4], pkt[16:20])
p.pkt = pkt[:totalLen]
case 6:
if len(pkt) < 40 {
return p, false
}
if pkt[6] != wantProto {
return p, false
}
payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
if 40+payloadLen > len(pkt) {
return p, false
}
p.ipHdrLen = 40
p.fk.isV6 = true
copy(p.fk.src[:], pkt[8:24])
copy(p.fk.dst[:], pkt[24:40])
p.pkt = pkt[:40+payloadLen]
default:
return p, false
}
return p, true
}
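// Illustrative example (an assumption for illustration, not a test from this
// commit): the kind of header shape parseIPPrologue rejects.
//
//	withOptions := make([]byte, 44)
//	withOptions[0] = 0x46 // IPv4, IHL=6 → 24-byte header, i.e. options present
//	withOptions[9] = ipProtoTCP
//	_, ok := parseIPPrologue(withOptions, ipProtoTCP)
//	// ok == false: only option-free, unfragmented IPv4 (IHL == 20) is admissible.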
// ipHeadersMatch compares the IP portion of two packet header prefixes for
// byte-for-byte equality on every field that must be identical across
// coalesced segments. Size/IPID/IPCsum and the 2-bit IP-level ECN field are
// masked out — the appendPayload step merges CE into the seed.
//
// The transport (L4) portion of the header is checked separately by the
// per-protocol matcher.
func ipHeadersMatch(a, b []byte, isV6 bool) bool {
if isV6 {
// IPv6: byte 0 = version/TC[7:4], byte 1 = TC[3:0]/flow[19:16],
// bytes [2:4] = flow[15:0], [6:8] = next_hdr/hop, [8:40] = src+dst.
// ECN lives in TC[1:0] = byte 1 mask 0x30. Skip [4:6] payload_len.
if a[0] != b[0] {
return false
}
if a[1]&^0x30 != b[1]&^0x30 {
return false
}
if !bytes.Equal(a[2:4], b[2:4]) {
return false
}
if !bytes.Equal(a[6:40], b[6:40]) {
return false
}
return true
}
// IPv4: byte 0 = version/IHL, byte 1 = DSCP(6)|ECN(2),
// [6:10] flags/fragoff/TTL/proto, [12:20] src+dst.
// Skip [2:4] total len, [4:6] id, [10:12] csum.
if a[0] != b[0] {
return false
}
if a[1]&^0x03 != b[1]&^0x03 {
return false
}
if !bytes.Equal(a[6:10], b[6:10]) {
return false
}
if !bytes.Equal(a[12:20], b[12:20]) {
return false
}
return true
}
// mergeECNIntoSeed ORs the 2-bit IP-level ECN field of pkt's IP header
// onto the seed's IP header, so a CE mark on any coalesced segment
// propagates to the final superpacket. (CE is 0b11; ORing yields CE if
// any segment carried it.) Used by both TCP and UDP coalescers, so the
// invariant lives in one place.
func mergeECNIntoSeed(seedHdr, pktHdr []byte, isV6 bool) {
if isV6 {
seedHdr[1] |= pktHdr[1] & 0x30
} else {
seedHdr[1] |= pktHdr[1] & 0x03
}
}
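// Worked example (illustrative, not part of this change): an ECT(0) seed
// merged with a CE-marked segment keeps the congestion signal.
//
//	seedHdr := []byte{0x45, 0x02} // IPv4 version/IHL, TOS = DSCP 0 | ECT(0)
//	pktHdr := []byte{0x45, 0x03}  // same DSCP, ECN = CE
//	mergeECNIntoSeed(seedHdr, pktHdr, false)
//	// seedHdr[1]&0x03 == 0x03: the CE mark now rides on the superpacket.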
// reserveFromBacking implements the Reserve half of the RxBatcher contract
// shared by TCP and UDP coalescers. The backing slice grows on demand;
// already-committed slices reference the old array and remain valid until
// Flush resets backing.
func reserveFromBacking(backing *[]byte, sz int) []byte {
if len(*backing)+sz > cap(*backing) {
newCap := max(cap(*backing)*2, sz)
*backing = make([]byte, 0, newCap)
}
start := len(*backing)
*backing = (*backing)[:start+sz]
return (*backing)[start : start+sz : start+sz]
}
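// Illustrative sketch (not part of this change) of the grow invariant:
// committed slices keep aliasing the old array, so nothing already borrowed
// is invalidated mid-batch.
//
//	backing := make([]byte, 0, 8)
//	a := reserveFromBacking(&backing, 8) // fills the arena
//	b := reserveFromBacking(&backing, 8) // doesn't fit: a fresh array is allocated
//	// a still points at the original array; a and b never overlap.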

View File

@@ -0,0 +1,133 @@
package batch
import (
"io"
)
// MultiCoalescer fans plaintext packets out to lane-specific batchers based
// on the IP/L4 protocol of the packet, sharing a single Reserve arena
// across lanes so the caller's allocation pattern is unchanged.
//
// Lanes are processed independently: the TCP coalescer only sees TCP, the
// UDP coalescer only sees UDP, and the passthrough lane handles everything
// else. Per-flow arrival order is preserved because a single 5-tuple only
// ever lands in one lane and each lane preserves its own slot order.
//
// Cross-lane order is NOT preserved across the TCP/UDP/passthrough split.
// This is acceptable because the carrier-side recvmmsg path already
// stable-sorts by (peer, message counter) before delivering plaintext
// here, so replay-window invariants are unaffected, and apps observe
// correct per-flow ordering — which is all the IP layer guarantees anyway.
// Do not "fix" this by interleaving lane outputs at flush time; that
// negates the entire point of coalescing (each lane needs to see runs of
// adjacent same-flow packets to coalesce them).
type MultiCoalescer struct {
tcp *TCPCoalescer
udp *UDPCoalescer
pt *Passthrough
// arena shared across all lanes so a single Reserve grows one backing
// slice; lane Commit calls borrow into this same arena.
backing []byte
}
// NewMultiCoalescer builds a multi-lane batcher. tcpEnabled lets the caller
// opt out of TCP coalescing (e.g. when the queue can't do TSO); udpEnabled
// likewise gates UDP coalescing (only enable when USO was negotiated).
// Disabling either lane redirects its traffic into the passthrough lane.
func NewMultiCoalescer(w io.Writer, tcpEnabled, udpEnabled bool) *MultiCoalescer {
m := &MultiCoalescer{
pt: NewPassthrough(w),
backing: make([]byte, 0, initialSlots*65535),
}
if tcpEnabled {
m.tcp = NewTCPCoalescer(w)
}
if udpEnabled {
m.udp = NewUDPCoalescer(w)
}
return m
}
func (m *MultiCoalescer) Reserve(sz int) []byte {
if len(m.backing)+sz > cap(m.backing) {
newCap := max(cap(m.backing)*2, sz)
m.backing = make([]byte, 0, newCap)
}
start := len(m.backing)
m.backing = m.backing[:start+sz]
return m.backing[start : start+sz : start+sz]
}
// Commit dispatches pkt to the appropriate lane based on IP version + L4
// proto. Borrowed slice contract is identical to the single-lane batchers
// — pkt must remain valid until the next Flush.
//
// On the success path the IP/TCP-or-UDP parse happens here once and the
// parsed struct is handed to the lane via commitParsed so the lane doesn't
// re-walk the header. On a parse failure we fall through to the lane's
// public Commit, which re-runs the parse before passthrough — that path
// only fires for malformed/unsupported packets so the duplicated parse is
// not on the hot path. The lane's public Commit still works for direct
// callers.
func (m *MultiCoalescer) Commit(pkt []byte) error {
if len(pkt) < 20 {
return m.pt.Commit(pkt)
}
v := pkt[0] >> 4
var proto byte
switch v {
case 4:
proto = pkt[9]
case 6:
if len(pkt) < 40 {
return m.pt.Commit(pkt)
}
proto = pkt[6]
default:
return m.pt.Commit(pkt)
}
switch proto {
case ipProtoTCP:
if m.tcp != nil {
info, ok := parseTCPBase(pkt)
if !ok {
// Malformed/unsupported TCP shape (IP options, fragments, ...)
// — the TCP lane handles this as passthrough.
return m.tcp.Commit(pkt)
}
return m.tcp.commitParsed(pkt, info)
}
case ipProtoUDP:
if m.udp != nil {
info, ok := parseUDP(pkt)
if !ok {
return m.udp.Commit(pkt)
}
return m.udp.commitParsed(pkt, info)
}
}
return m.pt.Commit(pkt)
}
// Flush drains every lane in a fixed order: TCP, UDP, passthrough. Errors
// from a lane do not stop subsequent lanes from flushing — we keep
// draining and return the first observed error so a single bad packet
// doesn't strand the others.
func (m *MultiCoalescer) Flush() error {
var first error
keep := func(err error) {
if err != nil && first == nil {
first = err
}
}
if m.tcp != nil {
keep(m.tcp.Flush())
}
if m.udp != nil {
keep(m.udp.Flush())
}
keep(m.pt.Flush())
m.backing = m.backing[:0]
return first
}

View File

@@ -0,0 +1,94 @@
package batch
import (
"testing"
)
// TestMultiCoalescerRoutesByProto confirms TCP/UDP/other land in the right
// lane: TCP and UDP get coalesced when their lanes are enabled, anything
// else (ICMP here) falls through to plain Write.
func TestMultiCoalescerRoutesByProto(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
m := NewMultiCoalescer(w, true, true)
tcpPay := make([]byte, 1200)
udpPay := make([]byte, 1200)
icmp := make([]byte, 28)
icmp[0] = 0x45
icmp[2] = 0
icmp[3] = 28
icmp[9] = 1
if err := m.Commit(buildTCPv4(1000, tcpAck, tcpPay)); err != nil {
t.Fatal(err)
}
if err := m.Commit(buildTCPv4(2200, tcpAck, tcpPay)); err != nil {
t.Fatal(err)
}
if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil {
t.Fatal(err)
}
if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil {
t.Fatal(err)
}
if err := m.Commit(icmp); err != nil {
t.Fatal(err)
}
if err := m.Flush(); err != nil {
t.Fatal(err)
}
// 1 TCP super (2 segments) + 1 UDP super (2 segments) = 2 gso writes.
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes (one TCP + one UDP), got %d", len(w.gsoWrites))
}
if len(w.writes) != 1 {
t.Fatalf("want 1 plain write (ICMP), got %d", len(w.writes))
}
}
// TestMultiCoalescerDisabledUDPFallsThrough verifies that when the UDP lane
// is disabled (e.g. kernel doesn't support USO), UDP packets still reach
// the kernel via the passthrough lane rather than being lost.
func TestMultiCoalescerDisabledUDPFallsThrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
m := NewMultiCoalescer(w, true, false) // TSO on, USO off
if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
t.Fatal(err)
}
if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
t.Fatal(err)
}
if err := m.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 0 {
t.Errorf("UDP must NOT be coalesced when USO disabled, got %d gso writes", len(w.gsoWrites))
}
if len(w.writes) != 2 {
t.Errorf("UDP must pass through as 2 plain writes, got %d", len(w.writes))
}
}
// TestMultiCoalescerDisabledTCPFallsThrough mirrors the TSO=off case.
func TestMultiCoalescerDisabledTCPFallsThrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
m := NewMultiCoalescer(w, false, true) // TSO off, USO on
pay := make([]byte, 1200)
if err := m.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := m.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := m.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 0 {
t.Errorf("TCP must NOT be coalesced when TSO disabled, got %d gso writes", len(w.gsoWrites))
}
if len(w.writes) != 2 {
t.Errorf("TCP must pass through as 2 plain writes, got %d", len(w.writes))
}
}

View File

@@ -0,0 +1,722 @@
package batch
import (
"bytes"
"encoding/binary"
"io"
"log/slog"
"net/netip"
"sort"
"github.com/slackhq/nebula/overlay/tio"
)
// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of
// reaching for golang.org/x/sys/unix — that package doesn't define the
// constant on Windows, which would break cross-compiles even though this
// file runs unchanged on every platform.
const ipProtoTCP = 6
// tcpCoalesceBufSize caps total bytes per superpacket. Mirrors the kernel's
// sk_gso_max_size of ~64KiB; anything beyond this would be rejected anyway.
const tcpCoalesceBufSize = 65535
// tcpCoalesceMaxSegs caps how many segments we'll coalesce into a single
// superpacket. Keeping this well below the kernel's TSO ceiling bounds
// latency.
const tcpCoalesceMaxSegs = 64
// tcpCoalesceHdrCap is the scratch space we copy a seed's IP+TCP header
// into. IPv6 (40) + TCP with full options (60) = 100 bytes.
const tcpCoalesceHdrCap = 100
// coalesceSlot is one entry in the coalescer's ordered event queue. When
// passthrough is true the slot holds a single borrowed packet that must be
// emitted verbatim (non-TCP, non-admissible TCP, or oversize seed). When
// passthrough is false the slot is an in-progress coalesced superpacket:
// hdrBuf is a mutable copy of the seed's IP+TCP header (we patch total
// length and pseudo-header partial at flush), and payIovs are *borrowed*
// slices from the caller's plaintext buffers — no payload is ever copied.
// The caller (listenOut) must keep those buffers alive until Flush.
type coalesceSlot struct {
passthrough bool
rawPkt []byte // borrowed when passthrough
fk flowKey
hdrBuf [tcpCoalesceHdrCap]byte
hdrLen int
ipHdrLen int
isV6 bool
gsoSize int
numSeg int
totalPay int
nextSeq uint32
// psh closes the chain: set when the last-accepted segment had PSH or
// was sub-gsoSize. No further appends after that.
psh bool
payIovs [][]byte
}
// TCPCoalescer accumulates adjacent in-flow TCP data segments across
// multiple concurrent flows and emits each flow's run as a single TSO
// superpacket via tio.GSOWriter. All output — coalesced or not — is
// deferred until Flush so arrival order is preserved on the wire. Owns
// no locks; one coalescer per TUN write queue.
type TCPCoalescer struct {
plainW io.Writer
gsoW tio.GSOWriter // nil when the queue doesn't support TSO
// slots is the ordered event queue. Flush walks it once and emits each
// entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough).
slots []*coalesceSlot
// openSlots maps a flow key to its most recent non-sealed slot, so new
// segments can extend an in-progress superpacket in O(1). Slots are
// removed from this map when they close (PSH or short-last-segment),
// when a non-admissible packet for that flow arrives, or in Flush.
openSlots map[flowKey]*coalesceSlot
// lastSlot caches the most recently touched open slot. Steady-state
// bulk traffic is dominated by a single flow, so comparing the
// incoming key against the cached slot's own fk lets the hot path
// skip the map lookup (and the aeshash of a 38-byte key) entirely.
// Kept in lockstep with openSlots: nil whenever the slot it pointed
// at is removed/sealed.
lastSlot *coalesceSlot
pool []*coalesceSlot // free list for reuse
backing []byte
}
func NewTCPCoalescer(w io.Writer) *TCPCoalescer {
c := &TCPCoalescer{
plainW: w,
slots: make([]*coalesceSlot, 0, initialSlots),
openSlots: make(map[flowKey]*coalesceSlot, initialSlots),
pool: make([]*coalesceSlot, 0, initialSlots),
backing: make([]byte, 0, initialSlots*65535),
}
if gw, ok := tio.SupportsGSO(w, tio.GSOProtoTCP); ok {
c.gsoW = gw
}
return c
}
// parsedTCP holds the fields extracted from a single parse so later steps
// (admission, slot lookup, canAppend) don't re-walk the header.
type parsedTCP struct {
fk flowKey
ipHdrLen int
tcpHdrLen int
hdrLen int
payLen int
seq uint32
flags byte
}
// parseTCPBase extracts the flow key and IP/TCP offsets for any TCP packet,
// regardless of whether it's admissible for coalescing. Returns ok=false
// for non-TCP or malformed input. Accepts IPv4 (no options, no fragmentation)
// and IPv6 (no extension headers).
func parseTCPBase(pkt []byte) (parsedTCP, bool) {
var p parsedTCP
ip, ok := parseIPPrologue(pkt, ipProtoTCP)
if !ok {
return p, false
}
pkt = ip.pkt
p.fk = ip.fk
p.ipHdrLen = ip.ipHdrLen
if len(pkt) < p.ipHdrLen+20 {
return p, false
}
tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4
if tcpOff < 20 || tcpOff > 60 {
return p, false
}
if len(pkt) < p.ipHdrLen+tcpOff {
return p, false
}
p.tcpHdrLen = tcpOff
p.hdrLen = p.ipHdrLen + tcpOff
p.payLen = len(pkt) - p.hdrLen
p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8])
p.flags = pkt[p.ipHdrLen+13]
p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2])
p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4])
return p, true
}
// TCP flag bits (byte 13 of the TCP header). Only the bits actually consulted
// by the coalescer are named; FIN/SYN/RST/URG/CWR are rejected via the
// negative mask in coalesceable, not by name.
const (
tcpFlagPsh = 0x08
tcpFlagAck = 0x10
tcpFlagEce = 0x40
)
// coalesceable reports whether a parsed TCP segment is eligible for
// coalescing. Accepts ACK, ACK|PSH, ACK|ECE, ACK|PSH|ECE with a
// non-empty payload. CWR is excluded because it marks a one-shot
// congestion-window-reduced transition the receiver must observe at a
// segment boundary.
func (p parsedTCP) coalesceable() bool {
if p.flags&tcpFlagAck == 0 {
return false
}
if p.flags&^(tcpFlagAck|tcpFlagPsh|tcpFlagEce) != 0 {
return false
}
return p.payLen > 0
}
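// Illustrative admission table (not part of this change), assuming a
// non-empty payload unless noted:
//
//	ACK             → coalesceable
//	ACK|PSH         → coalesceable (PSH closes the chain after this segment)
//	ACK|ECE         → coalesceable (ECE must then stay stable; see canAppend)
//	SYN|ACK         → passthrough (handshake segments are never coalesced)
//	ACK, empty pay  → passthrough (a pure ack carries nothing to merge)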
func (c *TCPCoalescer) Reserve(sz int) []byte {
return reserveFromBacking(&c.backing, sz)
}
// Commit borrows pkt. The caller must keep pkt valid until the next Flush,
// whether or not the packet was coalesced — passthrough (non-admissible)
// packets are queued and written at Flush time, not synchronously.
func (c *TCPCoalescer) Commit(pkt []byte) error {
if c.gsoW == nil {
c.addPassthrough(pkt)
return nil
}
info, ok := parseTCPBase(pkt)
if !ok {
c.addPassthrough(pkt)
return nil
}
return c.commitParsed(pkt, info)
}
// commitParsed is the post-parse half of Commit. The caller must have
// already verified parseTCPBase succeeded (info is a valid TCP parse).
// Used by MultiCoalescer.Commit to avoid re-walking the IP/TCP header
// after the dispatcher has already done so.
func (c *TCPCoalescer) commitParsed(pkt []byte, info parsedTCP) error {
if c.gsoW == nil {
c.addPassthrough(pkt)
return nil
}
if !info.coalesceable() {
// TCP but not admissible (SYN/FIN/RST/URG/CWR or zero-payload).
// Seal this flow's open slot so later in-flow packets don't extend
// it and accidentally reorder past this passthrough.
if last := c.lastSlot; last != nil && last.fk == info.fk {
c.lastSlot = nil
}
delete(c.openSlots, info.fk)
c.addPassthrough(pkt)
return nil
}
// Single-flow fast path: with only one open flow the cache hits every
// packet, and len(openSlots)==1 lets us skip the 38-byte fk compare
// when there are multiple flows in flight (where the hit rate would
// be ~0 and the compare is pure overhead).
var open *coalesceSlot
if last := c.lastSlot; last != nil && len(c.openSlots) == 1 && last.fk == info.fk {
open = last
} else {
open = c.openSlots[info.fk]
}
if open != nil {
if c.canAppend(open, pkt, info) {
c.appendPayload(open, pkt, info)
if open.psh {
delete(c.openSlots, info.fk)
c.lastSlot = nil
} else {
c.lastSlot = open
}
return nil
}
// Can't extend — seal it and fall through to seed a fresh slot.
delete(c.openSlots, info.fk)
if c.lastSlot == open {
c.lastSlot = nil
}
}
c.seed(pkt, info)
return nil
}
// Flush emits every queued event in (per-flow) seq order. Coalesced slots
// go out via WriteGSO; passthrough slots go out via plainW.Write.
// reorderForFlush first sorts each flow's slots into TCP-seq order within
// passthrough-bounded segments and merges contiguous adjacent slots, so
// any wire-side reorder that crossed an rxOrder batch boundary doesn't
// get amplified into kernel-visible reorder by the slot machinery.
// Returns the first error observed; keeps draining so one bad packet
// doesn't hold up the rest. After Flush returns, borrowed payload slices
// may be recycled.
func (c *TCPCoalescer) Flush() error {
c.reorderForFlush()
var first error
for _, s := range c.slots {
var err error
if s.passthrough {
_, err = c.plainW.Write(s.rawPkt)
} else {
err = c.flushSlot(s)
}
if err != nil && first == nil {
first = err
}
c.release(s)
}
for i := range c.slots {
c.slots[i] = nil
}
c.slots = c.slots[:0]
for k := range c.openSlots {
delete(c.openSlots, k)
}
c.lastSlot = nil
c.backing = c.backing[:0]
return first
}
func (c *TCPCoalescer) addPassthrough(pkt []byte) {
s := c.take()
s.passthrough = true
s.rawPkt = pkt
c.slots = append(c.slots, s)
}
func (c *TCPCoalescer) seed(pkt []byte, info parsedTCP) {
if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize {
// Pathological shape — can't fit our scratch, emit as-is.
c.addPassthrough(pkt)
return
}
s := c.take()
s.passthrough = false
s.rawPkt = nil
copy(s.hdrBuf[:], pkt[:info.hdrLen])
s.hdrLen = info.hdrLen
s.ipHdrLen = info.ipHdrLen
s.isV6 = info.fk.isV6
s.fk = info.fk
s.gsoSize = info.payLen
s.numSeg = 1
s.totalPay = info.payLen
s.nextSeq = info.seq + uint32(info.payLen)
s.psh = info.flags&tcpFlagPsh != 0
s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen])
c.slots = append(c.slots, s)
if !s.psh {
c.openSlots[info.fk] = s
c.lastSlot = s
} else if last := c.lastSlot; last != nil && last.fk == info.fk {
// PSH-on-seed seals the slot immediately. Any stale cached slot for
// this flow has just been superseded by this already-sealed seed, so
// drop the cache too.
c.lastSlot = nil
}
}
// canAppend reports whether info's packet extends the slot's seed: same
// header shape and stable contents, adjacent seq, not oversized, chain not
// closed.
func (c *TCPCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool {
if s.psh {
return false
}
if info.hdrLen != s.hdrLen {
return false
}
if info.seq != s.nextSeq {
return false
}
if s.numSeg >= tcpCoalesceMaxSegs {
return false
}
if info.payLen > s.gsoSize {
return false
}
if s.hdrLen+s.totalPay+info.payLen > tcpCoalesceBufSize {
return false
}
// ECE state must be stable across a burst — receivers expect the
// flag set on every segment of a CE-echoing window or none.
seedFlags := s.hdrBuf[s.ipHdrLen+13]
if (seedFlags^info.flags)&tcpFlagEce != 0 {
return false
}
if !headersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) {
return false
}
return true
}
func (c *TCPCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) {
s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen])
s.numSeg++
s.totalPay += info.payLen
s.nextSeq = info.seq + uint32(info.payLen)
if info.flags&tcpFlagPsh != 0 {
// Propagate PSH into the seed header so kernel TSO sets it on the
// last segment. Without this the sender's push signal is dropped.
s.hdrBuf[s.ipHdrLen+13] |= tcpFlagPsh
}
// Merge IP-level CE marks into the seed: headersMatch ignores ECN, so
// this is the one place the signal is preserved.
mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6)
if info.payLen < s.gsoSize || info.flags&tcpFlagPsh != 0 {
s.psh = true
}
}
func (c *TCPCoalescer) take() *coalesceSlot {
if n := len(c.pool); n > 0 {
s := c.pool[n-1]
c.pool[n-1] = nil
c.pool = c.pool[:n-1]
return s
}
return &coalesceSlot{}
}
func (c *TCPCoalescer) release(s *coalesceSlot) {
s.passthrough = false
s.rawPkt = nil
for i := range s.payIovs {
s.payIovs[i] = nil
}
s.payIovs = s.payIovs[:0]
s.numSeg = 0
s.totalPay = 0
s.psh = false
c.pool = append(c.pool, s)
}
// flushSlot patches the header and calls WriteGSO. Does not remove the
// slot from c.slots.
func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error {
total := s.hdrLen + s.totalPay
l4Len := total - s.ipHdrLen
hdr := s.hdrBuf[:s.hdrLen]
if s.isV6 {
binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len))
} else {
binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
hdr[10] = 0
hdr[11] = 0
binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen]))
}
var psum uint32
if s.isV6 {
psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoTCP, l4Len)
} else {
psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoTCP, l4Len)
}
tcsum := s.ipHdrLen + 16
binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum))
return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoTCP)
}
// headersMatch compares two IP+TCP header prefixes for byte-for-byte
// equality on every field that must be identical across coalesced
// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out, as is the
// 2-bit IP-level ECN field — appendPayload merges CE into the seed.
func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
if len(a) != len(b) {
return false
}
if !ipHeadersMatch(a, b, isV6) {
return false
}
// TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window,
// [18:tcpHdrLen] options (incl. urgent).
tcp := ipHdrLen
if !bytes.Equal(a[tcp:tcp+4], b[tcp:tcp+4]) {
return false
}
if !bytes.Equal(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) {
return false
}
if !bytes.Equal(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) {
return false
}
if !bytes.Equal(a[tcp+18:], b[tcp+18:]) {
return false
}
return true
}
// reorderForFlush neutralizes wire-side reorder that the rxOrder buffer
// couldn't catch (anything crossing a recvmmsg batch boundary). Without
// this pass a small wire reorder — counter 250 arriving in batch K when
// 200..249 are coming in batch K+1 — would seed an out-of-seq slot first
// and emit it ahead of the lower-seq slot, manifesting at the inner TCP
// receiver as a much larger reorder than the wire actually had.
//
// Two phases:
// 1. Sort each passthrough-bounded segment of c.slots by (flow, seq).
// Cross-flow ordering inside a segment isn't preserved (it never was
// and doesn't matter for any single flow's TCP correctness).
// 2. Sweep once and merge adjacent same-flow slots whose ranges are now
// contiguous AND whose tail is gsoSize-aligned. The tail constraint
// matters because the kernel TSO splitter chops at gsoSize from the
// start of the merged payload — a short segment in the middle would
// desynchronize every later segment.
//
// Passthrough slots act as barriers: the merge check skips them on either
// side, so a SYN/FIN/RST/CWR is never reordered relative to its flow's
// data.
func (c *TCPCoalescer) reorderForFlush() {
if len(c.slots) <= 1 {
return
}
runStart := 0
for i := 0; i <= len(c.slots); i++ {
if i < len(c.slots) && !c.slots[i].passthrough {
continue
}
c.sortRun(c.slots[runStart:i])
runStart = i + 1
}
out := c.slots[:0]
logged := false
for _, s := range c.slots {
if n := len(out); n > 0 {
prev := out[n-1]
if !prev.passthrough && !s.passthrough && prev.fk == s.fk {
// Same-flow neighbors after sort. If they aren't seq-
// contiguous it's a real gap — packets the wire reordered
// across batches, or actual loss before nebula. Log it so
// the operator can quantify how often it happens; the data
// itself still emits in seq order, kernel TCP handles the
// gap via its OOO queue.
if prev.nextSeq != slotSeedSeq(s) {
logged = true
gap := int64(slotSeedSeq(s)) - int64(prev.nextSeq)
slog.Default().Warn("tcp coalesce: cross-slot seq gap",
"src", flowKeyAddr(s.fk, false),
"dst", flowKeyAddr(s.fk, true),
"sport", s.fk.sport,
"dport", s.fk.dport,
"prev_seed_seq", slotSeedSeq(prev),
"prev_next_seq", prev.nextSeq,
"this_seed_seq", slotSeedSeq(s),
"gap_bytes", gap,
"prev_seg_count", prev.numSeg,
"prev_total_pay", prev.totalPay,
)
}
if canMergeSlots(prev, s) {
mergeSlots(prev, s)
c.release(s)
continue
}
}
}
out = append(out, s)
}
if logged {
slog.Default().Warn("==== end of batch ====")
}
c.slots = out
}
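// Worked example (illustrative, not part of this change): if one flow's
// slots sit in c.slots as [seed seq 2200, two 1200-byte segments] followed
// by [seed seq 1000, one 1200-byte segment] because the wire delivered the
// higher run first, sortRun reorders them to [1000, 2200]; the merge sweep
// then sees prev.nextSeq == 2200 == slotSeedSeq(next) with a gsoSize-aligned
// tail and folds them into a single three-segment slot.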
// flowKeyAddr returns the src or dst address from fk as a netip.Addr for
// logging. Only used on the cold gap-log path so the netip allocation
// doesn't matter.
func flowKeyAddr(fk flowKey, dst bool) netip.Addr {
src := fk.src
if dst {
src = fk.dst
}
if fk.isV6 {
return netip.AddrFrom16(src)
}
var v4 [4]byte
copy(v4[:], src[:4])
return netip.AddrFrom4(v4)
}
// sortRun stable-sorts run by (flowKey, seedSeq) so each flow's slots
// cluster together in seq order, ready for the merge sweep. Stable so
// equal-key slots keep their original relative position (defensive — a
// duplicate seedSeq would already mean something's wrong upstream).
func (c *TCPCoalescer) sortRun(run []*coalesceSlot) {
if len(run) <= 1 {
return
}
sort.SliceStable(run, func(i, j int) bool {
a, b := run[i], run[j]
if cmp := flowKeyCompare(a.fk, b.fk); cmp != 0 {
return cmp < 0
}
return tcpSeqLess(slotSeedSeq(a), slotSeedSeq(b))
})
}
// slotSeedSeq returns the TCP seq of the slot's seed (first segment).
// nextSeq tracks the seq just past the last appended byte; subtracting
// totalPay walks back to the seed. uint32 wraparound is the right TCP
// arithmetic so no special-casing is needed.
func slotSeedSeq(s *coalesceSlot) uint32 {
return s.nextSeq - uint32(s.totalPay)
}
// tcpSeqLess reports whether a precedes b in TCP serial-number arithmetic
// (RFC 1323 §2.3). The signed int32 cast turns the modular subtraction
// into the right comparison even across the 2^32 wrap.
func tcpSeqLess(a, b uint32) bool {
return int32(a-b) < 0
}
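// Worked example (illustrative, not part of this change): serial-number
// comparison across the 2^32 wrap.
//
//	tcpSeqLess(0xFFFFFF00, 0x00000100) // true: 0xFFFFFF00 precedes post-wrap
//	tcpSeqLess(0x00000100, 0xFFFFFF00) // false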
// flowKeyCompare orders flowKeys deterministically. The exact ordering
// is irrelevant — only that same-flow slots cluster together so the
// post-sort sweep can merge contiguous pairs.
func flowKeyCompare(a, b flowKey) int {
if c := bytes.Compare(a.src[:], b.src[:]); c != 0 {
return c
}
if c := bytes.Compare(a.dst[:], b.dst[:]); c != 0 {
return c
}
if a.sport != b.sport {
if a.sport < b.sport {
return -1
}
return 1
}
if a.dport != b.dport {
if a.dport < b.dport {
return -1
}
return 1
}
if a.isV6 != b.isV6 {
if !a.isV6 {
return -1
}
return 1
}
return 0
}
// canMergeSlots reports whether s can fold into prev as one merged TSO
// superpacket. Same flow, contiguous TCP byte range, equal gsoSize, and
// fits within the kernel TSO limits. The tail-of-prev check rejects any
// merge whose first slot ended on a sub-gsoSize segment — kernel TSO
// would split the merged skb at gsoSize boundaries from the start, so a
// short segment in the middle would corrupt every later segment. PSH and
// ECE state must agree across both slots: PSH is a semantic delimiter
// (preserving the sender's push boundary) and ECE state must be uniform
// across a window (the same rule canAppend enforces for in-flow appends).
//
// Note: a slot sealed by reorder (canAppend returned false on seq
// mismatch) keeps psh=false, so this restriction does not block the
// reorder-fix merge — only legitimate PSH-set seals.
func canMergeSlots(prev, s *coalesceSlot) bool {
if prev.psh {
return false
}
if prev.fk != s.fk {
return false
}
if prev.gsoSize != s.gsoSize {
return false
}
if prev.nextSeq != slotSeedSeq(s) {
return false
}
if prev.numSeg+s.numSeg > tcpCoalesceMaxSegs {
return false
}
if prev.hdrLen+prev.totalPay+s.totalPay > tcpCoalesceBufSize {
return false
}
if len(prev.payIovs[len(prev.payIovs)-1]) != prev.gsoSize {
return false
}
prevFlags := prev.hdrBuf[prev.ipHdrLen+13]
sFlags := s.hdrBuf[s.ipHdrLen+13]
if (prevFlags^sFlags)&tcpFlagEce != 0 {
return false
}
if !headersMatch(prev.hdrBuf[:prev.hdrLen], s.hdrBuf[:s.hdrLen], prev.isV6, prev.ipHdrLen) {
return false
}
return true
}
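// Worked example (illustrative, not part of this change): with gsoSize 1200,
// a prev slot whose payloads are [1200, 1200] can absorb a seq-contiguous
// next slot — the kernel splits the merged payload at 1200-byte marks and
// every boundary still lands on a real segment edge. A prev slot ending in a
// 700-byte tail cannot: that tail would shift every later 1200-byte chunk off
// its original segment boundary.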
// mergeSlots folds src into dst in place: payIovs concatenated, counters
// and totals updated, PSH and IP-level CE bits OR'd into the seed header
// so neither the push signal nor a CE mark is lost. The seed header's
// seq, gsoSize, and fk are unchanged. Caller is responsible for releasing
// src (it's no longer in c.slots after this call).
func mergeSlots(dst, src *coalesceSlot) {
dst.payIovs = append(dst.payIovs, src.payIovs...)
dst.numSeg += src.numSeg
dst.totalPay += src.totalPay
dst.nextSeq = src.nextSeq
if src.psh {
dst.psh = true
dst.hdrBuf[dst.ipHdrLen+13] |= tcpFlagPsh
}
mergeECNIntoSeed(dst.hdrBuf[:dst.ipHdrLen], src.hdrBuf[:src.ipHdrLen], dst.isV6)
}
// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must
// already have its checksum field zeroed) and returns the folded/inverted
// 16-bit value to store.
func ipv4HdrChecksum(hdr []byte) uint16 {
var sum uint32
for i := 0; i+1 < len(hdr); i += 2 {
sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2]))
}
if len(hdr)%2 == 1 {
sum += uint32(hdr[len(hdr)-1]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
// pseudoSumIPv4 / pseudoSumIPv6 build the L4 pseudo-header partial sum
// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator
// before folding. proto selects the L4 (TCP or UDP); the UDP coalescer
// reuses these helpers.
func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 {
var sum uint32
sum += uint32(binary.BigEndian.Uint16(src[0:2]))
sum += uint32(binary.BigEndian.Uint16(src[2:4]))
sum += uint32(binary.BigEndian.Uint16(dst[0:2]))
sum += uint32(binary.BigEndian.Uint16(dst[2:4]))
sum += uint32(proto)
sum += uint32(l4Len)
return sum
}
func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 {
var sum uint32
for i := 0; i < 16; i += 2 {
sum += uint32(binary.BigEndian.Uint16(src[i : i+2]))
sum += uint32(binary.BigEndian.Uint16(dst[i : i+2]))
}
sum += uint32(l4Len >> 16)
sum += uint32(l4Len & 0xffff)
sum += uint32(proto)
return sum
}
// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it
// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in
// the L4 checksum field — the kernel will add the payload sum and invert.
func foldOnceNoInvert(sum uint32) uint16 {
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(sum)
}
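// Worked example (illustrative, not part of this change): folding 0x0001FFFF
// gives 0x0001 + 0xFFFF = 0x10000, which folds again to 0x0001. No final
// invert — virtio NEEDS_CSUM expects the raw pseudo-header partial; the
// kernel adds the payload sum and inverts.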

View File

@@ -0,0 +1,173 @@
package batch
import (
"encoding/binary"
"testing"
"github.com/slackhq/nebula/overlay/tio"
)
// nopTunWriter is a zero-alloc tio.GSOWriter for benchmarks. Discards
// everything but satisfies the interface the coalescer detects.
type nopTunWriter struct{}
func (nopTunWriter) Write(p []byte) (int, error) { return len(p), nil }
func (nopTunWriter) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, _ tio.GSOProto) error {
return nil
}
func (nopTunWriter) Capabilities() tio.Capabilities {
return tio.Capabilities{TSO: true, USO: true}
}
// buildTCPv4BulkFlow returns a slice of N adjacent ACK-only TCP segments
// on a single 5-tuple, each carrying payloadLen bytes. Seq numbers are
// contiguous so every packet is coalesceable onto the previous one.
func buildTCPv4BulkFlow(n, payloadLen int) [][]byte {
pkts := make([][]byte, n)
pay := make([]byte, payloadLen)
seq := uint32(1000)
for i := range n {
pkts[i] = buildTCPv4(seq, tcpAck, pay)
seq += uint32(payloadLen)
}
return pkts
}
// buildTCPv4Interleaved returns nFlows * perFlow packets with per-flow
// seq continuity but round-robin across flows — worst case for any
// "last-slot" cache.
func buildTCPv4Interleaved(nFlows, perFlow, payloadLen int) [][]byte {
pay := make([]byte, payloadLen)
seqs := make([]uint32, nFlows)
for i := range seqs {
seqs[i] = uint32(1000 + i*1000000)
}
pkts := make([][]byte, 0, nFlows*perFlow)
for range perFlow {
for f := range nFlows {
sport := uint16(10000 + f)
pkts = append(pkts, buildTCPv4Ports(sport, 2000, seqs[f], tcpAck, pay))
seqs[f] += uint32(payloadLen)
}
}
return pkts
}
// buildICMPv4 returns a minimal non-TCP packet that takes the passthrough
// branch in Commit.
func buildICMPv4() []byte {
pkt := make([]byte, 28)
pkt[0] = 0x45
binary.BigEndian.PutUint16(pkt[2:4], 28)
pkt[9] = 1 // ICMP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
return pkt
}
// runCommitBench drives Commit over pkts batchSize at a time, flushing
// between batches, and reports per-packet cost.
func runCommitBench(b *testing.B, pkts [][]byte, batchSize int) {
b.Helper()
c := NewTCPCoalescer(nopTunWriter{})
b.ReportAllocs()
b.SetBytes(int64(len(pkts[0])))
b.ResetTimer()
for i := 0; i < b.N; i++ {
pkt := pkts[i%len(pkts)]
if err := c.Commit(pkt); err != nil {
b.Fatal(err)
}
if (i+1)%batchSize == 0 {
if err := c.Flush(); err != nil {
b.Fatal(err)
}
}
}
// Drain any trailing partial batch so slot state doesn't leak across runs.
_ = c.Flush()
}
// BenchmarkCommitSingleFlow is the bulk-TCP steady state: one flow,
// contiguous seq, 1200-byte payloads. Every packet past the seed should
// append onto the open slot. This is the case we most care about.
func BenchmarkCommitSingleFlow(b *testing.B) {
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
runCommitBench(b, pkts, tcpCoalesceMaxSegs)
}
// BenchmarkCommitInterleaved4 has 4 concurrent bulk flows round-robined.
// A single-entry fast-path cache will miss on every packet; an N-way
// cache or map lookup carries the weight.
func BenchmarkCommitInterleaved4(b *testing.B) {
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
runCommitBench(b, pkts, len(pkts))
}
// BenchmarkCommitInterleaved16 stresses the map at higher flow counts.
func BenchmarkCommitInterleaved16(b *testing.B) {
pkts := buildTCPv4Interleaved(16, tcpCoalesceMaxSegs, 1200)
runCommitBench(b, pkts, len(pkts))
}
// BenchmarkCommitPassthrough exercises the non-TCP branch: parseTCPBase
// bails early and addPassthrough is the only work.
func BenchmarkCommitPassthrough(b *testing.B) {
pkt := buildICMPv4()
pkts := make([][]byte, 64)
for i := range pkts {
pkts[i] = pkt
}
runCommitBench(b, pkts, 64)
}
// BenchmarkCommitNonCoalesceableTCP sends SYN|ACK packets on one flow.
// Each packet takes the "TCP but not admissible" branch which does a
// map delete + passthrough. Measures the seal-without-slot cost.
func BenchmarkCommitNonCoalesceableTCP(b *testing.B) {
pay := make([]byte, 0)
pkts := make([][]byte, 64)
for i := range pkts {
pkts[i] = buildTCPv4(uint32(1000+i), tcpSyn|tcpAck, pay)
}
runCommitBench(b, pkts, 64)
}
// runMultiCommitBench drives MultiCoalescer.Commit. The dispatcher does
// the IP/L4 parse once and passes the parsed struct to the lane, so this
// is the bench that shows the savings of skipping the lane's re-parse.
func runMultiCommitBench(b *testing.B, pkts [][]byte, batchSize int) {
b.Helper()
m := NewMultiCoalescer(nopTunWriter{}, true, true)
b.ReportAllocs()
b.SetBytes(int64(len(pkts[0])))
b.ResetTimer()
for i := 0; i < b.N; i++ {
pkt := pkts[i%len(pkts)]
if err := m.Commit(pkt); err != nil {
b.Fatal(err)
}
if (i+1)%batchSize == 0 {
if err := m.Flush(); err != nil {
b.Fatal(err)
}
}
}
_ = m.Flush()
}
// BenchmarkMultiCommitSingleFlow is the multi-lane analogue of
// BenchmarkCommitSingleFlow — same workload but routed through the
// dispatcher. The delta vs the single-lane bench measures dispatcher
// overhead.
func BenchmarkMultiCommitSingleFlow(b *testing.B) {
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
runMultiCommitBench(b, pkts, tcpCoalesceMaxSegs)
}
// BenchmarkMultiCommitInterleaved4 mirrors BenchmarkCommitInterleaved4
// through the dispatcher.
func BenchmarkMultiCommitInterleaved4(b *testing.B) {
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
runMultiCommitBench(b, pkts, len(pkts))
}

File diff suppressed because it is too large

View File

@@ -4,58 +4,63 @@ import "net/netip"
const SendBatchCap = 128
// SendBatch accumulates encrypted UDP packets for potential TX offloading.
// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush.
type batchWriter interface {
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
}
// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch.
// One SendBatch is owned by each listenIn goroutine; no locking is needed.
// The backing storage holds up to batchCap packets of slotCap bytes each;
// bufs and dsts are parallel slices of committed slots.
// The backing arena grows on demand: when there isn't room for the next slot
// we allocate a fresh backing array. Already-committed slices keep referencing
// the old array and remain valid until Flush drops them.
type SendBatch struct {
bufs [][]byte
dsts []netip.AddrPort
backing []byte
slotCap int
batchCap int
nextSlot int
out batchWriter
bufs [][]byte
dsts []netip.AddrPort
ecns []byte
backing []byte
}
func NewSendBatch(batchCap, slotCap int) *SendBatch {
func NewSendBatch(out batchWriter, batchCap, slotCap int) *SendBatch {
return &SendBatch{
bufs: make([][]byte, 0, batchCap),
dsts: make([]netip.AddrPort, 0, batchCap),
backing: make([]byte, batchCap*slotCap),
slotCap: slotCap,
batchCap: batchCap,
out: out,
bufs: make([][]byte, 0, batchCap),
dsts: make([]netip.AddrPort, 0, batchCap),
ecns: make([]byte, 0, batchCap),
backing: make([]byte, 0, batchCap*slotCap),
}
}
func (b *SendBatch) Next() []byte {
if b.nextSlot >= b.batchCap {
return nil
func (b *SendBatch) Reserve(sz int) []byte {
if len(b.backing)+sz > cap(b.backing) {
// Grow: allocate a fresh backing. Already-committed slices still
// reference the old array and remain valid until Flush drops them.
newCap := max(cap(b.backing)*2, sz)
b.backing = make([]byte, 0, newCap)
}
start := b.nextSlot * b.slotCap
return b.backing[start : start : start+b.slotCap] //set len to 0 but cap to slotCap
start := len(b.backing)
b.backing = b.backing[:start+sz]
return b.backing[start : start+sz : start+sz]
}
func (b *SendBatch) Commit(n int, dst netip.AddrPort) {
start := b.nextSlot * b.slotCap
b.bufs = append(b.bufs, b.backing[start:start+n])
func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) {
b.bufs = append(b.bufs, pkt)
b.dsts = append(b.dsts, dst)
b.nextSlot++
b.ecns = append(b.ecns, outerECN)
}
func (b *SendBatch) Reset() {
func (b *SendBatch) Flush() error {
var err error
if len(b.bufs) > 0 {
err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns)
}
for i := range b.bufs {
b.bufs[i] = nil
}
b.bufs = b.bufs[:0]
b.dsts = b.dsts[:0]
b.nextSlot = 0
}
func (b *SendBatch) Len() int {
return len(b.bufs)
}
func (b *SendBatch) Cap() int {
return b.batchCap
}
func (b *SendBatch) Get() ([][]byte, []netip.AddrPort) {
return b.bufs, b.dsts
b.ecns = b.ecns[:0]
b.backing = b.backing[:0]
return err
}

View File

@@ -5,65 +5,120 @@ import (
"testing"
)
func TestSendBatchBookkeeping(t *testing.T) {
b := NewSendBatch(4, 32)
if b.Len() != 0 || b.Cap() != 4 {
t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap())
type fakeBatchWriter struct {
bufs [][]byte
addrs []netip.AddrPort
ecns []byte
}
func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error {
// Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch
// returns, so tests must capture data before that happens.
w.bufs = make([][]byte, len(bufs))
for i, b := range bufs {
cp := make([]byte, len(b))
copy(cp, b)
w.bufs[i] = cp
}
w.addrs = append(w.addrs[:0], addrs...)
w.ecns = append(w.ecns[:0], ecns...)
return nil
}
func TestSendBatchReserveCommitFlush(t *testing.T) {
fw := &fakeBatchWriter{}
b := NewSendBatch(fw, 4, 32)
ap := netip.MustParseAddrPort("10.0.0.1:4242")
for i := 0; i < 4; i++ {
slot := b.Next()
if slot == nil {
t.Fatalf("slot %d: Next returned nil before cap", i)
slot := b.Reserve(32)
if cap(slot) != 32 {
t.Fatalf("slot %d: cap=%d want 32", i, cap(slot))
}
if cap(slot) != 32 || len(slot) != 0 {
t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot))
}
// Write a marker byte.
slot = append(slot, byte(i), byte(i+1), byte(i+2))
b.Commit(len(slot), ap)
pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2))
b.Commit(pkt, ap, 0)
}
if b.Next() != nil {
t.Fatalf("Next should return nil when full")
if err := b.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
if b.Len() != 4 {
t.Fatalf("Len=%d want 4", b.Len())
if len(fw.bufs) != 4 {
t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs))
}
for i, buf := range b.bufs {
for i, buf := range fw.bufs {
if len(buf) != 3 || buf[0] != byte(i) {
t.Errorf("buf %d: %x", i, buf)
}
if b.dsts[i] != ap {
t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap)
if fw.addrs[i] != ap {
t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap)
}
}
// Reset returns empty and Next works again.
b.Reset()
if b.Len() != 0 {
t.Fatalf("after Reset Len=%d want 0", b.Len())
// Flush again with nothing committed — should be a no-op.
fw.bufs = nil
if err := b.Flush(); err != nil {
t.Fatalf("empty Flush: %v", err)
}
slot := b.Next()
if slot == nil || cap(slot) != 32 {
t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot))
if fw.bufs != nil {
t.Fatalf("empty Flush triggered WriteBatch")
}
// Reuse after Flush.
slot := b.Reserve(32)
if cap(slot) != 32 {
t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot))
}
}
func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
b := NewSendBatch(3, 8)
fw := &fakeBatchWriter{}
b := NewSendBatch(fw, 3, 8)
ap := netip.MustParseAddrPort("10.0.0.1:80")
// Fill three slots, each with its own sentinel byte.
for i := 0; i < 3; i++ {
s := b.Next()
s = append(s, byte(0xA0+i), byte(0xB0+i))
b.Commit(len(s), ap)
s := b.Reserve(8)
pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i))
b.Commit(pkt, ap, 0)
}
if err := b.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
for i, buf := range b.bufs {
for i, buf := range fw.bufs {
if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) {
t.Errorf("slot %d corrupted: %x", i, buf)
}
}
}
func TestSendBatchGrowPreservesCommitted(t *testing.T) {
fw := &fakeBatchWriter{}
// Tiny initial backing forces a grow on the second Reserve.
b := NewSendBatch(fw, 1, 4)
ap := netip.MustParseAddrPort("10.0.0.1:80")
s1 := b.Reserve(4)
pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44)
b.Commit(pkt1, ap, 0)
s2 := b.Reserve(8) // exceeds remaining cap, triggers grow
pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE)
b.Commit(pkt2, ap, 0)
// pkt1 must still be intact even though backing reallocated.
if pkt1[0] != 0x11 || pkt1[3] != 0x44 {
t.Fatalf("first packet corrupted by grow: %x", pkt1)
}
if err := b.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
if len(fw.bufs) != 2 {
t.Fatalf("got %d bufs want 2", len(fw.bufs))
}
if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 {
t.Errorf("first packet on the wire: %x", fw.bufs[0])
}
if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE {
t.Errorf("second packet on the wire: %x", fw.bufs[1])
}
}

View File

@@ -0,0 +1,342 @@
package batch
import (
"encoding/binary"
"io"
"github.com/slackhq/nebula/overlay/tio"
)
// ipProtoUDP is the IANA protocol number for UDP.
const ipProtoUDP = 17
// udpCoalesceBufSize caps total bytes per UDP superpacket. Mirrors the
// kernel's gso_max_size; payloads beyond this are emitted as-is.
const udpCoalesceBufSize = 65535
// udpCoalesceMaxSegs caps how many segments we'll coalesce. Kernel UDP-GSO
// accepts up to 64 segments per skb (UDP_MAX_SEGMENTS); stay under that.
const udpCoalesceMaxSegs = 64
// udpCoalesceHdrCap is the scratch space we copy a seed's IP+UDP header
// into. IPv6 (40) + UDP (8) = 48; round up for safety.
const udpCoalesceHdrCap = 64
// udpSlot is one entry in the UDPCoalescer's ordered event queue. Same
// passthrough-vs-coalesced shape as the TCP coalescer's slot, but no
// seq/PSH/CWR bookkeeping — UDP segments only need 5-tuple + length
// matching to coalesce.
type udpSlot struct {
passthrough bool
rawPkt []byte // borrowed when passthrough
fk flowKey
hdrBuf [udpCoalesceHdrCap]byte
hdrLen int
ipHdrLen int
isV6 bool
gsoSize int // per-segment UDP payload length
numSeg int
totalPay int
// sealed closes the chain: set when a sub-gsoSize segment is appended
// (kernel UDP-GSO requires every segment but the last to be exactly
// gsoSize) or when limits are hit. No further appends after.
sealed bool
payIovs [][]byte
}
// UDPCoalescer accumulates adjacent in-flow UDP datagrams across multiple
// concurrent flows and emits each flow's run as a single GSO_UDP_L4
// superpacket via tio.GSOWriter. Falls back to per-packet writes when the
// underlying writer doesn't support USO.
//
// All output — coalesced or not — is deferred until Flush so per-flow
// arrival order is preserved on the wire. Cross-flow order is NOT preserved
// across the TCP/UDP/passthrough split when this coalescer runs alongside
// others — see multi_coalesce.go. Per-flow order is preserved because a
// single 5-tuple only ever lands in one lane and each lane preserves its
// own slot order.
//
// Owns no locks; one coalescer per TUN write queue.
type UDPCoalescer struct {
plainW io.Writer
gsoW tio.GSOWriter // nil when the queue can't accept GSO_UDP_L4
slots []*udpSlot
openSlots map[flowKey]*udpSlot
pool []*udpSlot
backing []byte
}
// NewUDPCoalescer wraps w. The caller is responsible for only constructing
// this when the underlying Queue's Capabilities advertise USO; otherwise
// the kernel may reject GSO_UDP_L4 writes. If w does not implement
// tio.GSOWriter at all (single-packet Queue), the coalescer degrades to
// plain Writes — same defensive shape as the TCP coalescer.
func NewUDPCoalescer(w io.Writer) *UDPCoalescer {
c := &UDPCoalescer{
plainW: w,
slots: make([]*udpSlot, 0, initialSlots),
openSlots: make(map[flowKey]*udpSlot, initialSlots),
pool: make([]*udpSlot, 0, initialSlots),
backing: make([]byte, 0, initialSlots*udpCoalesceBufSize),
}
if gw, ok := tio.SupportsGSO(w, tio.GSOProtoUDP); ok {
c.gsoW = gw
}
return c
}
// parsedUDP holds the fields extracted from a single parse so later steps
// (admission, slot lookup, canAppend) don't re-walk the header.
type parsedUDP struct {
fk flowKey
ipHdrLen int
hdrLen int // ipHdrLen + 8
payLen int
}
// parseUDP extracts the flow key and IP/UDP offsets for a UDP packet.
// Returns ok=false for non-UDP, malformed, or unsupported header shapes
// (IPv4 with options/fragmentation, IPv6 with extension headers).
func parseUDP(pkt []byte) (parsedUDP, bool) {
var p parsedUDP
ip, ok := parseIPPrologue(pkt, ipProtoUDP)
if !ok {
return p, false
}
pkt = ip.pkt
p.fk = ip.fk
p.ipHdrLen = ip.ipHdrLen
if len(pkt) < p.ipHdrLen+8 {
return p, false
}
p.hdrLen = p.ipHdrLen + 8
// UDP length field: must cover at least the UDP header and must not exceed
// the IP-declared payload; bytes past the UDP-declared length are ignored.
udpLen := int(binary.BigEndian.Uint16(pkt[p.ipHdrLen+4 : p.ipHdrLen+6]))
if udpLen < 8 || udpLen > len(pkt)-p.ipHdrLen {
return p, false
}
p.payLen = udpLen - 8
p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2])
p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4])
return p, true
}
func (c *UDPCoalescer) Reserve(sz int) []byte {
return reserveFromBacking(&c.backing, sz)
}
// Commit borrows pkt. The caller must keep pkt valid until the next Flush.
func (c *UDPCoalescer) Commit(pkt []byte) error {
if c.gsoW == nil {
c.addPassthrough(pkt)
return nil
}
info, ok := parseUDP(pkt)
if !ok {
c.addPassthrough(pkt)
return nil
}
return c.commitParsed(pkt, info)
}
// commitParsed is the post-parse half of Commit. The caller must have
// already verified parseUDP succeeded. Used by MultiCoalescer.Commit to
// avoid re-walking the IP/UDP header.
func (c *UDPCoalescer) commitParsed(pkt []byte, info parsedUDP) error {
if c.gsoW == nil {
c.addPassthrough(pkt)
return nil
}
if open := c.openSlots[info.fk]; open != nil {
if c.canAppend(open, pkt, info) {
c.appendPayload(open, pkt, info)
if open.sealed {
delete(c.openSlots, info.fk)
}
return nil
}
// Can't extend — seal it and fall through to seed a fresh slot.
delete(c.openSlots, info.fk)
}
c.seed(pkt, info)
return nil
}
func (c *UDPCoalescer) Flush() error {
var first error
for _, s := range c.slots {
var err error
if s.passthrough {
_, err = c.plainW.Write(s.rawPkt)
} else {
err = c.flushSlot(s)
}
if err != nil && first == nil {
first = err
}
c.release(s)
}
for i := range c.slots {
c.slots[i] = nil
}
c.slots = c.slots[:0]
for k := range c.openSlots {
delete(c.openSlots, k)
}
c.backing = c.backing[:0]
return first
}
func (c *UDPCoalescer) addPassthrough(pkt []byte) {
s := c.take()
s.passthrough = true
s.rawPkt = pkt
c.slots = append(c.slots, s)
}
func (c *UDPCoalescer) seed(pkt []byte, info parsedUDP) {
if info.hdrLen > udpCoalesceHdrCap || info.hdrLen+info.payLen > udpCoalesceBufSize {
c.addPassthrough(pkt)
return
}
s := c.take()
s.passthrough = false
s.rawPkt = nil
copy(s.hdrBuf[:], pkt[:info.hdrLen])
s.hdrLen = info.hdrLen
s.ipHdrLen = info.ipHdrLen
s.isV6 = info.fk.isV6
s.fk = info.fk
s.gsoSize = info.payLen
s.numSeg = 1
s.totalPay = info.payLen
s.sealed = false
s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen])
c.slots = append(c.slots, s)
c.openSlots[info.fk] = s
}
// canAppend reports whether info's packet extends the slot's seed.
// Kernel UDP-GSO requires every segment except possibly the last to be
// exactly gsoSize, and the last may be shorter (≤ gsoSize).
func (c *UDPCoalescer) canAppend(s *udpSlot, pkt []byte, info parsedUDP) bool {
if s.sealed {
return false
}
if info.hdrLen != s.hdrLen {
return false
}
if s.numSeg >= udpCoalesceMaxSegs {
return false
}
if info.payLen > s.gsoSize {
return false
}
if s.hdrLen+s.totalPay+info.payLen > udpCoalesceBufSize {
return false
}
if !udpHeadersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) {
return false
}
return true
}
func (c *UDPCoalescer) appendPayload(s *udpSlot, pkt []byte, info parsedUDP) {
s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen])
s.numSeg++
s.totalPay += info.payLen
// Merge IP-level CE marks into the seed (same trick TCP coalescer uses).
mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6)
if info.payLen < s.gsoSize {
// Last-segment-can-be-shorter: this seals the chain.
s.sealed = true
}
}
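// Worked example (illustrative, not part of this change): payloads of
// 1200, 1200, 800 bytes on one flow coalesce into a single slot with
// gsoSize 1200; the 800-byte append seals it, because kernel UDP-GSO only
// tolerates a short segment in the last position. A fourth 1200-byte
// datagram on the same flow would therefore seed a new slot.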
func (c *UDPCoalescer) take() *udpSlot {
if n := len(c.pool); n > 0 {
s := c.pool[n-1]
c.pool[n-1] = nil
c.pool = c.pool[:n-1]
return s
}
return &udpSlot{}
}
func (c *UDPCoalescer) release(s *udpSlot) {
s.passthrough = false
s.rawPkt = nil
for i := range s.payIovs {
s.payIovs[i] = nil
}
s.payIovs = s.payIovs[:0]
s.numSeg = 0
s.totalPay = 0
s.sealed = false
c.pool = append(c.pool, s)
}
// flushSlot patches the IP header total length / IPv6 payload length and
// the UDP length to the *total* across all coalesced segments, then seeds
// the UDP checksum field with the pseudo-header partial (single-fold, not
// inverted) per virtio NEEDS_CSUM. The kernel's ip_rcv_core (v4) and
// ip6_rcv_core (v6) trim the skb to those length fields, so per-segment
// values would silently drop everything but the first segment. The kernel
// then walks each segment in __udp_gso_segment, recomputing per-segment
// uh->len / iph->tot_len / IPv6 plen and adjusting the checksum via
// `check = csum16_add(csum16_sub(uh->check, uh->len), newlen)` — meaning
// our seed's uh->check must be consistent with the seed's uh->len, which
// is what passing the total to both pseudoSum and the UDP length field
// guarantees.
func (c *UDPCoalescer) flushSlot(s *udpSlot) error {
hdr := s.hdrBuf[:s.hdrLen]
total := s.hdrLen + s.totalPay // full IP+UDP+all_payloads bytes
l4Len := total - s.ipHdrLen // total UDP (8 + sum of payloads)
if s.isV6 {
binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len))
} else {
binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
hdr[10] = 0
hdr[11] = 0
binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen]))
}
// UDP length field (offset 4 inside the UDP header) = total UDP size.
binary.BigEndian.PutUint16(hdr[s.ipHdrLen+4:s.ipHdrLen+6], uint16(l4Len))
var psum uint32
if s.isV6 {
psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoUDP, l4Len)
} else {
psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoUDP, l4Len)
}
udpCsumOff := s.ipHdrLen + 6
binary.BigEndian.PutUint16(hdr[udpCsumOff:udpCsumOff+2], foldOnceNoInvert(psum))
return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoUDP)
}
// udpHeadersMatch compares two IP+UDP header prefixes for byte-equality on
// every field that must be identical across coalesced segments. Length
// fields and the ECN bits in IP TOS/TC are masked out — appendPayload
// merges CE into the seed; flushSlot rewrites lengths.
func udpHeadersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
if len(a) != len(b) {
return false
}
if !ipHeadersMatch(a, b, isV6) {
return false
}
// UDP: compare sport+dport ([0:4]). Skip length [4:6] and checksum [6:8] —
// length varies (we rewrite at flush) and the checksum will be redone.
udp := ipHdrLen
if a[udp] != b[udp] || a[udp+1] != b[udp+1] || a[udp+2] != b[udp+2] || a[udp+3] != b[udp+3] {
return false
}
return true
}
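// A minimal flush-cycle sketch (not part of this change): pkts stand in for
// decrypted inner IP packets and the coalescer is assumed to have been
// constructed elsewhere with the queue's GSO writer. Commit only parses and
// queues; the actual TUN writes happen in Flush.
func flushBurst(c *UDPCoalescer, pkts [][]byte) error {
	for _, p := range pkts {
		if err := c.Commit(p); err != nil {
			return err
		}
	}
	return c.Flush()
}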

View File

@@ -0,0 +1,383 @@
package batch
import (
"encoding/binary"
"testing"
)
// buildUDPv4 builds a minimal IPv4+UDP packet with the given payload and ports.
func buildUDPv4(sport, dport uint16, payload []byte) []byte {
const ipHdrLen = 20
const udpHdrLen = 8
total := ipHdrLen + udpHdrLen + len(payload)
pkt := make([]byte, total)
pkt[0] = 0x45
pkt[1] = 0x00
binary.BigEndian.PutUint16(pkt[2:4], uint16(total))
binary.BigEndian.PutUint16(pkt[4:6], 0)
binary.BigEndian.PutUint16(pkt[6:8], 0x4000)
pkt[8] = 64
pkt[9] = ipProtoUDP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
binary.BigEndian.PutUint16(pkt[20:22], sport)
binary.BigEndian.PutUint16(pkt[22:24], dport)
binary.BigEndian.PutUint16(pkt[24:26], uint16(udpHdrLen+len(payload)))
binary.BigEndian.PutUint16(pkt[26:28], 0)
copy(pkt[28:], payload)
return pkt
}
// buildUDPv6 builds a minimal IPv6+UDP packet.
func buildUDPv6(sport, dport uint16, payload []byte) []byte {
const ipHdrLen = 40
const udpHdrLen = 8
total := ipHdrLen + udpHdrLen + len(payload)
pkt := make([]byte, total)
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(udpHdrLen+len(payload)))
pkt[6] = ipProtoUDP
pkt[7] = 64
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
binary.BigEndian.PutUint16(pkt[40:42], sport)
binary.BigEndian.PutUint16(pkt[42:44], dport)
binary.BigEndian.PutUint16(pkt[44:46], uint16(udpHdrLen+len(payload)))
binary.BigEndian.PutUint16(pkt[46:48], 0)
copy(pkt[48:], payload)
return pkt
}
func TestUDPCoalescerPassthroughWhenGSOUnavailable(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: false}
c := NewUDPCoalescer(w)
pkt := buildUDPv4(1000, 53, make([]byte, 100))
if err := c.Commit(pkt); err != nil {
t.Fatal(err)
}
if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
t.Fatalf("no Add-time writes: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestUDPCoalescerNonUDPPassthrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
// ICMP packet
pkt := make([]byte, 28)
pkt[0] = 0x45
binary.BigEndian.PutUint16(pkt[2:4], 28)
pkt[9] = 1
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
if err := c.Commit(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("ICMP must pass through unchanged: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestUDPCoalescerSeedThenFlushAlone(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pkt := buildUDPv4(1000, 53, make([]byte, 800))
if err := c.Commit(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Single-segment flush goes through WriteGSO; the writer infers GSO_NONE
// from len(pays)==1 and the kernel fills in the UDP csum (NEEDS_CSUM).
if len(w.gsoWrites) != 1 || len(w.writes) != 0 {
t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestUDPCoalescerCoalescesEqualSized(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 1200)
for i := 0; i < 3; i++ {
if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil {
t.Fatal(err)
}
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 1 {
t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes))
}
g := w.gsoWrites[0]
if g.gsoSize != 1200 {
t.Errorf("gsoSize=%d want 1200", g.gsoSize)
}
if len(g.pays) != 3 {
t.Errorf("pay count=%d want 3", len(g.pays))
}
if g.csumStart != 20 {
t.Errorf("csumStart=%d want 20", g.csumStart)
}
// IP totalLen and UDP length must be the TOTAL across all segments —
// the kernel's ip_rcv_core trims skbs to iph->tot_len, so a per-segment
// value would silently drop everything but the first segment. Total =
// IP(20) + UDP(8) + 3*1200 = 3628.
gotTotalLen := binary.BigEndian.Uint16(g.hdr[2:4])
if gotTotalLen != 3628 {
t.Errorf("ipv4 total_len=%d want 3628 (must be total across segments)", gotTotalLen)
}
gotUDPLen := binary.BigEndian.Uint16(g.hdr[20+4 : 20+6])
if gotUDPLen != 8+3*1200 {
t.Errorf("udp len=%d want %d", gotUDPLen, 8+3*1200)
}
}
// Last segment may be shorter, sealing the chain.
func TestUDPCoalescerShortLastSegmentSeals(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
full := make([]byte, 1200)
tail := make([]byte, 600)
if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(1000, 53, tail)); err != nil {
t.Fatal(err)
}
// A 4th packet, even same-sized, must NOT join — chain is sealed.
if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes (sealed + new seed), got %d", len(w.gsoWrites))
}
if len(w.gsoWrites[0].pays) != 3 {
t.Errorf("first super: want 3 pays, got %d", len(w.gsoWrites[0].pays))
}
if len(w.gsoWrites[1].pays) != 1 {
t.Errorf("second super: want 1 pay (re-seed), got %d", len(w.gsoWrites[1].pays))
}
}
// A larger-than-gsoSize packet cannot extend the slot — it reseeds.
func TestUDPCoalescerLargerThanSeedReseeds(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
if err := c.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(1000, 53, make([]byte, 1200))); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 separate seeds, got %d", len(w.gsoWrites))
}
}
// Different 5-tuples must not coalesce.
func TestUDPCoalescerDifferentFlowsKeepSeparate(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 800)
if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(2000, 53, pay)); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil {
t.Fatal(err)
}
if err := c.Commit(buildUDPv4(2000, 53, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Two flows × 2 datagrams each = 2 superpackets of 2 segments.
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes (one per flow), got %d", len(w.gsoWrites))
}
for i, g := range w.gsoWrites {
if len(g.pays) != 2 {
t.Errorf("super %d: want 2 pays, got %d", i, len(g.pays))
}
}
}
// Caps at udpCoalesceMaxSegs.
func TestUDPCoalescerCapsAtMaxSegs(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 100)
for i := 0; i < udpCoalesceMaxSegs+5; i++ {
if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil {
t.Fatal(err)
}
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// First superpacket holds udpCoalesceMaxSegs segments; the spillover
// reseeds a new one.
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes (cap then reseed), got %d", len(w.gsoWrites))
}
if len(w.gsoWrites[0].pays) != udpCoalesceMaxSegs {
t.Errorf("first super: pays=%d want %d", len(w.gsoWrites[0].pays), udpCoalesceMaxSegs)
}
if len(w.gsoWrites[1].pays) != 5 {
t.Errorf("second super: pays=%d want 5", len(w.gsoWrites[1].pays))
}
}
// CE marks on appended segments must be merged into the seed's IP TOS.
func TestUDPCoalescerMergesCEMark(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 800)
pkt0 := buildUDPv4(1000, 53, pay) // ECN=00
pkt1 := buildUDPv4(1000, 53, pay)
pkt1[1] = 0x03 // CE
pkt2 := buildUDPv4(1000, 53, pay)
if err := c.Commit(pkt0); err != nil {
t.Fatal(err)
}
if err := c.Commit(pkt1); err != nil {
t.Fatal(err)
}
if err := c.Commit(pkt2); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 1 {
t.Fatalf("want 1 merged gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes))
}
if w.gsoWrites[0].hdr[1]&0x03 != 0x03 {
t.Errorf("CE not merged into seed (tos=%#x)", w.gsoWrites[0].hdr[1])
}
}
// IPv6 path: same flow, equal-sized → coalesced.
func TestUDPCoalescerIPv6Coalesces(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 1200)
for i := 0; i < 3; i++ {
if err := c.Commit(buildUDPv6(1000, 53, pay)); err != nil {
t.Fatal(err)
}
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 1 {
t.Fatalf("want 1 gso write, got %d", len(w.gsoWrites))
}
g := w.gsoWrites[0]
if !g.isV6 {
t.Errorf("expected v6 write")
}
if g.csumStart != 40 {
t.Errorf("csumStart=%d want 40", g.csumStart)
}
// IPv6 payload_len and UDP length must be TOTAL — kernel's
// ip6_rcv_core trims to payload_len + ipv6 hdr size. Total UDP = 8 +
// 3*1200 = 3608.
gotPlen := binary.BigEndian.Uint16(g.hdr[4:6])
if gotPlen != 8+3*1200 {
t.Errorf("ipv6 payload_len=%d want %d (must be total)", gotPlen, 8+3*1200)
}
gotUDPLen := binary.BigEndian.Uint16(g.hdr[40+4 : 40+6])
if gotUDPLen != 8+3*1200 {
t.Errorf("udp len=%d want %d", gotUDPLen, 8+3*1200)
}
}
// DSCP differences must reseed (headers don't match outside ECN).
func TestUDPCoalescerDSCPMismatchReseeds(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pay := make([]byte, 800)
pkt0 := buildUDPv4(1000, 53, pay)
pkt1 := buildUDPv4(1000, 53, pay)
pkt1[1] = 0xb8 // EF DSCP, ECN=0
if err := c.Commit(pkt0); err != nil {
t.Fatal(err)
}
if err := c.Commit(pkt1); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 separate seeds (different DSCP), got %d", len(w.gsoWrites))
}
}
// Fragmented IPv4 must not be coalesced.
func TestUDPCoalescerFragmentedIPv4PassesThrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pkt := buildUDPv4(1000, 53, make([]byte, 200))
binary.BigEndian.PutUint16(pkt[6:8], 0x2000) // MF=1
if err := c.Commit(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("frag must pass through plain, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
// IPv4 with options is not admissible (we require IHL=5).
func TestUDPCoalescerIPv4WithOptionsPassesThrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewUDPCoalescer(w)
pkt := buildUDPv4(1000, 53, make([]byte, 200))
pkt[0] = 0x46 // IHL = 6 (24-byte IPv4 header — has options)
if err := c.Commit(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("ipv4-with-options must pass through plain, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}

View File

@@ -18,7 +18,7 @@ type Device interface {
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool //todo remove?
SupportsMultiqueue() bool
NewMultiQueueReader() error
Readers() []tio.Queue
}

View File

@@ -31,7 +31,7 @@ func (NoopTun) Name() string {
return "noop"
}
func (NoopTun) Read() ([][]byte, error) {
func (NoopTun) Read() ([]tio.Packet, error) {
return nil, nil
}

View File

@@ -0,0 +1,79 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type offloadQueueSet struct {
pq []*Offload
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
// usoEnabled is true when newTun successfully negotiated TUN_F_USO4|6
// with the kernel. Queues created by Add inherit this and surface it
// via Offload.USOSupported so coalescers can gate USO emission.
usoEnabled bool
}
// NewOffloadQueueSet creates a QueueSet that uses virtio_net_hdr to do
// TSO segmentation in userspace. usoEnabled tells downstream queues whether
// the kernel agreed to deliver/accept GSO_UDP_L4 superpackets — coalescers
// should fall back to per-packet writes when this is false.
func NewOffloadQueueSet(usoEnabled bool) (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &offloadQueueSet{
pq: []*Offload{},
pqi: []Queue{},
shutdownFd: shutdownFd,
usoEnabled: usoEnabled,
}
return out, nil
}
func (c *offloadQueueSet) Queues() []Queue {
return c.pqi
}
func (c *offloadQueueSet) Add(fd int) error {
x, err := newOffload(fd, c.shutdownFd, c.usoEnabled)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *offloadQueueSet) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(c.shutdownFd, buf[:])
return err
}
func (c *offloadQueueSet) Close() error {
errs := []error{}
// Signal all readers blocked in poll to wake up and exit
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
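// A minimal lifecycle sketch, assuming tunFd was opened elsewhere with
// IFF_VNET_HDR and USO negotiation was already attempted; tunFd and
// usoEnabled are placeholders, not values this package produces.
func exampleOffloadQueueSet(tunFd int, usoEnabled bool) error {
	qs, err := NewOffloadQueueSet(usoEnabled)
	if err != nil {
		return err
	}
	defer qs.Close()
	if err := qs.Add(tunFd); err != nil {
		return err
	}
	for _, q := range qs.Queues() {
		// Inspect negotiated offloads to decide which coalescers to wire.
		_ = QueueCapabilities(q)
	}
	return nil
}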

View File

@@ -8,20 +8,20 @@ import (
"golang.org/x/sys/unix"
)
type pollContainer struct {
type pollQueueSet struct {
pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewPollContainer() (Container, error) {
func NewPollQueueSet() (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &pollContainer{
out := &pollQueueSet{
pq: []*Poll{},
pqi: []Queue{},
shutdownFd: shutdownFd,
@@ -30,11 +30,11 @@ func NewPollContainer() (Container, error) {
return out, nil
}
func (c *pollContainer) Queues() []Queue {
func (c *pollQueueSet) Queues() []Queue {
return c.pqi
}
func (c *pollContainer) Add(fd int) error {
func (c *pollQueueSet) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd)
if err != nil {
return err
@@ -45,14 +45,14 @@ func (c *pollContainer) Add(fd int) error {
return nil
}
func (c *pollContainer) wakeForShutdown() error {
func (c *pollQueueSet) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
return err
}
func (c *pollContainer) Close() error {
func (c *pollQueueSet) Close() error {
errs := []error{}
if err := c.wakeForShutdown(); err != nil {

View File

@@ -0,0 +1,65 @@
//go:build linux && !android && !e2e_testing
package tio
import "testing"
// fakeBatch stands in for batch.TxBatcher inside the bench — same shape
// of pointer-capturing closure that sendInsideMessage builds.
type fakeBatch struct{ buf [65536]byte }
func (b *fakeBatch) Reserve(sz int) []byte { return b.buf[:sz] }
func (b *fakeBatch) Commit([]byte) {}
type fakeHostInfo struct {
remoteIndexId uint32
counter uint64
}
type fakeIface struct {
rebindCount uint8
hi *fakeHostInfo
}
// BenchmarkSegmentSuperpacketAllocsTSO measures allocation per
// SegmentSuperpacket call when a closure captures pointer-bearing
// receivers — the realistic shape of sendInsideMessage's closure.
func BenchmarkSegmentSuperpacketAllocsTSO(b *testing.B) {
const mss = 1400
const numSeg = 32
pkt := buildTSOv6(mss*numSeg, mss)
gso := GSOInfo{
Size: mss,
HdrLen: 60, // 40 (IPv6) + 20 (TCP)
CsumStart: 40,
Proto: GSOProtoTCP,
}
p := Packet{Bytes: pkt, GSO: gso}
hi := &fakeHostInfo{remoteIndexId: 0xdeadbeef}
f := &fakeIface{rebindCount: 7, hi: hi}
fb := &fakeBatch{}
// SegmentSuperpacket consumes pkt destructively; refresh from a master
// copy each iter (matches the production pattern where every TUN read
// hands the segmenter a fresh kernel-supplied buffer).
master := append([]byte(nil), pkt...)
work := make([]byte, len(pkt))
p.Bytes = work
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
copy(work, master)
err := SegmentSuperpacket(p, func(seg []byte) error {
out := fb.Reserve(16 + len(seg) + 16)
out[0] = byte(f.rebindCount)
out[1] = byte(hi.counter)
hi.counter++
fb.Commit(out)
return nil
})
if err != nil {
b.Fatalf("SegmentSuperpacket: %v", err)
}
}
}

View File

@@ -0,0 +1,18 @@
//go:build !linux || android || e2e_testing
package tio
import "fmt"
// SegmentSuperpacket invokes fn once per segment of pkt. On non-Linux
// builds (and Android/e2e_testing) this package does not provide a Queue
// implementation, so any caller that does construct a Packet here can only
// be operating on non-superpacket bytes and the stub forwards them
// directly. A non-zero GSO field is a caller programming error, so the stub
// returns an explicit error rather than silently misbehaving.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
if pkt.GSO.IsSuperpacket() {
return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
}
return fn(pkt.Bytes)
}

View File

@@ -1,56 +1,170 @@
package tio
import "io"
import (
"io"
)
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
// Container holds one or many Queue objects and helps close them in an orderly way
type Container interface {
// QueueSet holds one or many Queue objects and helps close them in an orderly way.
type QueueSet interface {
io.Closer
Queues() []Queue
// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue
// Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
Add(fd int) error
}
// Capabilities advertises which kernel offload features a Queue
// successfully negotiated. Callers consult this to decide which coalescers
// to wire onto the write path — a Queue without TSO can't usefully accept a
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
type Capabilities struct {
// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
TSO bool
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
USO bool
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below).
// read goroutine plus a single writer (see Write below).
type Queue interface {
io.Closer
// Read returns one or more packets. The returned slices are borrowed
// from the Queue's internal buffer and are only valid until the next
// Read or Close on this Queue - callers must encrypt or copy each
// slice before the next call. Not safe for concurrent Reads.
Read() ([][]byte, error)
// Read returns one or more packets. The returned Packet.Bytes slices
// are borrowed from the Queue's internal buffer and are only valid
// until the next Read or Close on this Queue - callers must encrypt
// or copy each slice before the next call. A Packet may carry a
// GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
// true the caller must segment Bytes before treating it as a single
// IP datagram. Not safe for concurrent Reads.
Read() ([]Packet, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. Not safe for concurrent Writes.
Write(p []byte) (int, error)
}
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// Packet is the unit Queue.Read returns. Bytes points into the queue's
// internal buffer and is only valid until the next Read or Close on the
// queue that produced it. GSO is the zero value for an already-segmented
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
// superpacket the caller must segment before consuming.
type Packet struct {
Bytes []byte
GSO GSOInfo
}
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
// Size is the GSO segment size: max payload bytes per segment
// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
// not a superpacket.
Size uint16
// HdrLen is the total L3+L4 header length within Bytes (already
// corrected via correctHdrLen, so safe to slice on).
HdrLen uint16
// CsumStart is the L4 header offset inside Bytes (== L3 header
// length).
CsumStart uint16
// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
// which checksum/header layout to apply.
Proto GSOProto
}
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
// superpacket that needs segmentation before its bytes can be encrypted
// and sent on the wire.
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
// safe to retain past the next Read or Close on the originating Queue.
// GSO metadata is copied verbatim. Use this only when a caller genuinely
// needs to outlive the borrowed-slice contract — the hot path reads should
// continue to consume the borrow synchronously to avoid the allocation.
func (p Packet) Clone() Packet {
if p.Bytes == nil {
return p
}
cp := make([]byte, len(p.Bytes))
copy(cp, p.Bytes)
return Packet{Bytes: cp, GSO: p.GSO}
}
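// A borrow-contract sketch (hypothetical helper, not part of the API): the
// hot path should consume Packet.Bytes before the next Read; only a caller
// that genuinely defers work should pay for Clone.
func stashForLater(p Packet, held *[]Packet) {
	// p.Bytes is invalidated by the originating Queue's next Read or Close,
	// so keep a private copy.
	*held = append(*held, p.Clone())
}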
// CapsProvider is an optional interface implemented by Queues that
// successfully negotiated kernel offload features at open time. Callers
// pick a write-path coalescer based on the result. Queues that don't
// implement it are treated as having no offload capability — callers must
// fall back to plain per-packet writes.
type CapsProvider interface {
Capabilities() Capabilities
}
// QueueCapabilities returns q's negotiated offload capabilities, or the
// zero value when q does not advertise any.
func QueueCapabilities(q Queue) Capabilities {
if cp, ok := q.(CapsProvider); ok {
return cp.Capabilities()
}
return Capabilities{}
}
// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
// inside the transport header virtio NEEDS_CSUM expects.
type GSOProto uint8
const (
GSOProtoTCP GSOProto = iota
GSOProtoUDP
)
// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped.
// support do not implement this interface and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will
// have filled in total length and pseudo-header partial). pays are
// non-overlapping payload fragments whose concatenation is the full
// superpacket payload; they are read-only from the writer's perspective
// and must remain valid until the call returns. gsoSize is the MSS:
// every segment except possibly the last is exactly that many bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
// filled in total length and IP csum). transportHdr is the TCP or UDP
// header (mutable - the L4 checksum field must hold the pseudo-header
// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
// pays are non-overlapping payload fragments whose concatenation is the
// full superpacket payload; they are read-only from the writer's
// perspective and must remain valid until the call returns. Every segment
// in pays except possibly the last is exactly the same size. proto picks
// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
//
// # TODO fold into Queue
//
// hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
// Callers should also consult CapsProvider (via SupportsGSO or
// QueueCapabilities) for the per-protocol negotiated capability; an
// implementation of GSOWriter is necessary but not sufficient since USO
// may not have been negotiated even when TSO was.
type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
GSOSupported() bool
WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
}
// SupportsGSO reports whether w implements GSOWriter and the underlying
// queue advertises the negotiated capability for `want`. A writer that
// implements GSOWriter but not CapsProvider is treated as permissive
// (used by tests and fakes that don't negotiate).
func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) {
gw, ok := w.(GSOWriter)
if !ok {
return nil, false
}
cp, ok := w.(CapsProvider)
if !ok {
return gw, true
}
caps := cp.Capabilities()
switch want {
case GSOProtoTCP:
return gw, caps.TSO
case GSOProtoUDP:
return gw, caps.USO
}
return gw, false
}
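// A minimal wiring sketch (hypothetical write-path setup; the coalescer
// hand-off is elided because its constructors live outside this package):
func pickOffloadPath(q Queue) {
	if gw, ok := SupportsGSO(q, GSOProtoTCP); ok {
		_ = gw // safe to emit TCP superpackets via WriteGSO
	}
	if gw, ok := SupportsGSO(q, GSOProtoUDP); ok {
		_ = gw // kernel also agreed to USO; UDP coalescing is safe
	}
	// Otherwise fall back to plain per-packet q.Write calls.
}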

View File

@@ -0,0 +1,461 @@
package tio
import (
"fmt"
"io"
"log/slog"
"os"
"sync"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// tunRxBufSize is the per-Read worst-case footprint inside rxBuf: one
// kernel-supplied packet body, which is at most ~64 KiB (tunReadBufSize).
// Segmentation happens at encrypt time on a per-routine MTU-sized scratch
// (see SegmentSuperpacket), so rxBuf only holds raw kernel-supplied bytes.
// We round up to give comfortable margin for the drain headroom check
// below.
const tunRxBufSize = 64 * 1024
// tunRxBufCap is the total size we allocate for the per-reader rx
// buffer. With reads landing directly in rxBuf, each drain iteration
// consumes up to tunRxBufSize of headroom for the kernel-supplied bytes.
// Sized to two such iterations so the initial blocking read plus one
// drain read both fit without partial-drop.
const tunRxBufCap = tunRxBufSize * 2
// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64
// gsoMaxIovs caps the iovec budget WriteGSO assembles per call: 3 fixed
// entries (virtio_net_hdr, IP hdr, transport hdr) plus up to gsoMaxIovs-3
// payload fragments. Sized comfortably above the typical kernel GSO
// segment cap (Linux UDP_GRO is 64) so realistic coalesced bursts never
// touch the limit. iovecs are tiny (16 bytes), so the entire scratch is
// 4 KiB — fine to keep resident on every queue. WriteGSO returns an error
// rather than reallocating when a caller exceeds this budget.
const gsoMaxIovs = 256
// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write paths already carry
// a valid L4 checksum (either supplied by a remote peer whose ciphertext we
// AEAD-authenticated, produced by segmentTCPYield/segmentUDPYield during
// superpacket segmentation, or built locally by CreateRejectPacket), so
// trusting them is safe.
var validVnetHdr = [virtio.Size]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type Offload struct {
fd int
shutdownFd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
// writeLock serializes blockOnWrite's read+clear of writePoll[*].Revents.
// Any goroutine that calls Write may end up parked in poll(2); without
// the lock concurrent waiters could race the Revents reset and lose
// events.
writeLock sync.Mutex
closed atomic.Bool
rxBuf []byte // backing store for kernel-handed packets read this drain
rxOff int // cursor into rxBuf for the current Read drain
pending []Packet // packets returned from the most recent Read
// readVnetScratch holds the 10-byte virtio_net_hdr split off the front of
// every TUN read via readv(2). Decoupling the header from the packet body
// lets us read the body directly into rxBuf at the current rxOff with
// no userspace copy on the GSO_NONE fast path.
readVnetScratch [virtio.Size]byte
// readIovs is the readv(2) iovec scratch wired once at construction —
// iovec[0] points at readVnetScratch; iovec[1].Base/Len is updated per
// read to address the current rxBuf slot.
readIovs [2]unix.Iovec
// usoEnabled records whether the kernel agreed to TUN_F_USO* on this FD,
// so writers can decide whether emitting GSO_UDP_L4 superpackets is safe.
usoEnabled bool
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
// by WriteGSO. Kept separate from the read-only package-level validVnetHdr
// so non-GSO Writes can ship that constant directly while WriteGSO
// rewrites this scratch on every call.
gsoHdrBuf [virtio.Size]byte
// gsoIovs is the writev iovec scratch for WriteGSO. Pre-sized to
// gsoMaxIovs at construction; never grown. WriteGSO returns an error
// (and drops the call) if a caller hands it more fragments than fit.
gsoIovs []unix.Iovec
}
func newOffload(fd int, shutdownFd int, usoEnabled bool) (*Offload, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &Offload{
fd: fd,
shutdownFd: shutdownFd,
usoEnabled: usoEnabled,
closed: atomic.Bool{},
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writeLock: sync.Mutex{},
rxBuf: make([]byte, tunRxBufCap),
gsoIovs: make([]unix.Iovec, 2, gsoMaxIovs),
}
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
out.gsoIovs[0].SetLen(virtio.Size)
// readIovs[0] is wired once to the virtio_net_hdr scratch; per-read we
// only repoint readIovs[1] at the next rxBuf slot (see readPacket).
out.readIovs[0].Base = &out.readVnetScratch[0]
out.readIovs[0].SetLen(virtio.Size)
return out, nil
}
func (r *Offload) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *Offload) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
r.writeLock.Lock()
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
r.writeLock.Unlock()
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
// readPacket issues a single readv(2) splitting the virtio_net_hdr off
// into readVnetScratch and reading the packet body directly into rxBuf at
// the current rxOff. Returns the body length (zero virtio header bytes,
// just the IP packet/superpacket). block controls whether EAGAIN is
// retried via poll: the initial read of a drain blocks; subsequent drain
// reads do not.
//
// The body iovec capacity is always tunReadBufSize; callers (the Read
// drain loop) gate entry on tunRxBufCap-rxOff >= tunRxBufSize, sized to
// hold one worst-case kernel-supplied packet body. Without that gate the
// body iovec could be smaller than the next inbound packet and the
// kernel would truncate.
func (r *Offload) readPacket(block bool) (int, error) {
for {
r.readIovs[1].Base = &r.rxBuf[r.rxOff]
r.readIovs[1].SetLen(tunReadBufSize)
n, _, errno := syscall.Syscall(unix.SYS_READV, uintptr(r.fd), uintptr(unsafe.Pointer(&r.readIovs[0])), uintptr(len(r.readIovs)))
if errno == 0 {
		if int(n) < virtio.Size {
			// The kernel handed back fewer bytes than the virtio_net_hdr itself;
			// treat it as a truncated read rather than a short write.
			return 0, io.ErrUnexpectedEOF
		}
return int(n) - virtio.Size, nil
}
if errno == unix.EAGAIN {
if !block {
return 0, errno
}
if err := r.blockOnRead(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return 0, os.ErrClosed
}
return 0, errno
}
}
// Read returns one or more packets from the tun. Each Packet either
// carries a single ready-to-use IP datagram (GSO zero) or a TSO/USO
// superpacket plus the GSOInfo a caller needs to segment it (see
// SegmentSuperpacket). The first read blocks via poll; once the fd is
// known readable we drain additional packets non-blocking until the
// kernel queue is empty (EAGAIN), we've collected tunDrainCap packets,
// or we're out of rxBuf headroom. This amortizes the poll wake over
// bursts of small packets (e.g. TCP ACKs). Packet.Bytes slices point
// into the Offload's internal buffer and are only valid until the next
// Read or Close on this Queue.
func (r *Offload) Read() ([]Packet, error) {
r.pending = r.pending[:0]
r.rxOff = 0
// Initial (blocking) read. Retry on decode errors so a single bad
// packet does not stall the reader.
for {
n, err := r.readPacket(true)
if err != nil {
return nil, err
}
if err := r.decodeRead(n); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
break
}
// Drain: non-blocking reads until the kernel queue is empty, the drain
// cap is reached, or rxBuf no longer has room for another worst-case
// kernel-supplied packet (tunRxBufSize).
for len(r.pending) < tunDrainCap && tunRxBufCap-r.rxOff >= tunRxBufSize {
n, err := r.readPacket(false)
if err != nil {
// EAGAIN / EINTR / anything else: stop draining. We already
// have a valid batch from the first read.
break
}
if n <= 0 {
break
}
if err := r.decodeRead(n); err != nil {
// Drop this packet and stop the drain; we'd rather hand off
// what we have than keep spinning here.
break
}
}
return r.pending, nil
}
// decodeRead processes the packet sitting in rxBuf at rxOff (length
// pktLen). The bytes stay in rxBuf — for GSO_NONE we slice them as a
// regular IP datagram (running finishChecksum if NEEDS_CSUM is set);
// for TSO/USO superpackets we attach the corrected GSO metadata so the
// caller can segment lazily at encrypt time. rxOff advances past the
// kernel-supplied body and nothing else, since segmentation no longer
// writes back into rxBuf.
func (r *Offload) decodeRead(pktLen int) error {
if pktLen <= 0 {
return fmt.Errorf("short tun read: %d", pktLen)
}
var hdr virtio.Hdr
hdr.Decode(r.readVnetScratch[:])
body := r.rxBuf[r.rxOff : r.rxOff+pktLen]
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := virtio.FinishChecksum(body, hdr); err != nil {
return err
}
}
r.pending = append(r.pending, Packet{Bytes: body})
r.rxOff += pktLen
return nil
}
// GSO superpacket: validate, fix the kernel-supplied HdrLen on the
// FORWARD path (CorrectHdrLen), pick the L4 protocol, and attach
// the metadata. The bytes stay in rxBuf untouched, segmentation
// happens in SegmentSuperpacket at encrypt time.
if err := virtio.CheckValid(body, hdr); err != nil {
return err
}
if err := virtio.CorrectHdrLen(body, &hdr); err != nil {
return err
}
proto, err := protoFromGSOType(hdr.GSOType)
if err != nil {
return err
}
r.pending = append(r.pending, Packet{
Bytes: body,
GSO: GSOInfo{
Size: hdr.GSOSize,
HdrLen: hdr.HdrLen,
CsumStart: hdr.CsumStart,
Proto: proto,
},
})
r.rxOff += pktLen
return nil
}
func (r *Offload) Write(buf []byte) (int, error) {
	// Only the constant virtio_net_hdr is wired here; writeWithScratch points
	// iovs[1] at buf after its empty-buf check, so Write(nil) cannot panic on
	// &buf[0].
	iovs := [2]unix.Iovec{
		{Base: &validVnetHdr[0]},
	}
	iovs[0].SetLen(virtio.Size)
	return r.writeWithScratch(buf, &iovs)
}
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 {
return 0, nil
}
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
return r.rawWrite(unsafe.Slice(&iovs[0], len(iovs)))
}
func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) {
for {
n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs)))
if errno == 0 {
if int(n) < virtio.Size {
return 0, io.ErrShortWrite
}
return int(n) - virtio.Size, nil
}
if errno == unix.EAGAIN {
if err := r.blockOnWrite(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return 0, os.ErrClosed
}
return 0, errno
}
}
// Capabilities reports the offload features negotiated for this Queue. TSO
// is always true for Offload (we only construct it on IFF_VNET_HDR FDs);
// USO is true only when the kernel agreed to TUN_F_USO4|6 at open time
// (Linux ≥ 6.2).
func (r *Offload) Capabilities() Capabilities {
return Capabilities{TSO: true, USO: r.usoEnabled}
}
func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error {
if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 {
return nil
}
// L4 checksum offset inside transportHdr: TCP=16 (the `check` field after
// seq/ack/dataoff/flags/window), UDP=6 (after sport/dport/length).
var csumOff uint16
switch proto {
case GSOProtoUDP:
csumOff = 6
default:
csumOff = 16
}
vhdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
HdrLen: uint16(len(hdr) + len(transportHdr)),
GSOSize: uint16(len(pays[0])),
CsumStart: uint16(len(hdr)),
CsumOffset: csumOff,
}
if len(pays) > 1 {
ipVer := hdr[0] >> 4
switch {
case proto == GSOProtoUDP && (ipVer == 4 || ipVer == 6):
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
case ipVer == 6:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
case ipVer == 4:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
default:
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
vhdr.Encode(r.gsoHdrBuf[:])
// Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is
// wired to gsoHdrBuf at construction and never changes.
need := 3 + len(pays)
if need > cap(r.gsoIovs) {
slog.Default().Warn("tio: WriteGSO iovec budget exceeded; dropping superpacket",
"need", need, "cap", cap(r.gsoIovs), "segments", len(pays))
return fmt.Errorf("tio: WriteGSO needs %d iovecs but cap is %d", need, cap(r.gsoIovs))
}
r.gsoIovs = r.gsoIovs[:need]
r.gsoIovs[1].Base = &hdr[0]
r.gsoIovs[1].SetLen(len(hdr))
r.gsoIovs[2].Base = &transportHdr[0]
r.gsoIovs[2].SetLen(len(transportHdr))
for i, p := range pays {
r.gsoIovs[3+i].Base = &p[0]
r.gsoIovs[3+i].SetLen(len(p))
}
_, err := r.rawWrite(r.gsoIovs)
return err
}
func (r *Offload) Close() error {
if r.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if r.fd >= 0 {
err = unix.Close(r.fd)
r.fd = -1
}
return err
}

View File

@@ -21,7 +21,7 @@ type Poll struct {
closed atomic.Bool
readBuf []byte
batchRet [1][]byte
batchRet [1]Packet
}
func newPoll(fd int, shutdownFd int) (*Poll, error) {
@@ -97,12 +97,12 @@ func (t *Poll) blockOnWrite() error {
return nil
}
func (t *Poll) Read() ([][]byte, error) {
func (t *Poll) Read() ([]Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -15,7 +15,7 @@ import (
)
// newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newOffload / newFriend).
// The caller takes ownership of the read fd (pass it into a QueueSet).
func newReadPipe(t *testing.T) int {
t.Helper()
var fds [2]int
@@ -29,7 +29,7 @@ func newReadPipe(t *testing.T) int {
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t)
parent, err := NewPollContainer()
parent, err := NewPollQueueSet()
require.NoError(t, err)
require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2))

View File

@@ -0,0 +1,51 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"fmt"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// protoFromGSOType maps a virtio_net_hdr GSOType to the GSOProto value the
// segment-time helpers use. Returns an error for GSO_NONE or any unknown
// value — the caller should only invoke this on a confirmed superpacket.
func protoFromGSOType(t uint8) (GSOProto, error) {
switch t {
case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6:
return GSOProtoTCP, nil
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
return GSOProtoUDP, nil
default:
return 0, fmt.Errorf("unsupported virtio gso type: %d", t)
}
}
// SegmentSuperpacket invokes fn once per segment of pkt. For non-GSO pkts
// fn is called once with pkt.Bytes (no segmentation, no copy). For GSO/USO
// superpackets fn is called once per segment with a slice of pkt.Bytes
// holding that segment's plaintext (a freshly-patched L3+L4 header sliced
// in front of the original payload chunk). The segmentation is destructive:
// pkt is consumed by this call and its bytes are in an undefined state when
// SegmentSuperpacket returns. Callers must not retain pkt or any earlier
// seg slice past fn's return for that segment. Aborts and returns the first
// error from fn or from per-segment construction.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
if !pkt.GSO.IsSuperpacket() {
return fn(pkt.Bytes)
}
switch pkt.GSO.Proto {
case GSOProtoTCP:
return virtio.SegmentTCP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
case GSOProtoUDP:
return virtio.SegmentUDP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
default:
return fmt.Errorf("unsupported gso proto: %d", pkt.GSO.Proto)
}
}
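// A minimal consumption sketch, assuming encryptAndSend stands in for
// whatever the read path does with each ready-to-send plaintext segment
// (the real caller lives outside this package):
func drainOnce(q Queue, encryptAndSend func(seg []byte) error) error {
	pkts, err := q.Read()
	if err != nil {
		return err
	}
	for _, p := range pkts {
		// Borrowed bytes: segment and consume before the next Read. The
		// destructive segmentation is fine here because p is not retained.
		if err := SegmentSuperpacket(p, encryptAndSend); err != nil {
			return err
		}
	}
	return nil
}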

View File

@@ -0,0 +1,794 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"os"
"testing"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// testSegScratchSize is a generous segmentation scratch sized to fit any
// of the synthetic TSO/USO superpackets these tests generate (one
// worst-case 64 KiB superpacket plus replicated per-segment headers).
const testSegScratchSize = 192 * 1024
// verifyChecksum confirms that the one's-complement sum across `b`, seeded
// with a folded pseudo-header sum, equals all-ones (valid).
func verifyChecksum(b []byte, pseudo uint16) bool {
return checksum.Checksum(b, pseudo) == 0xffff
}
// segmentForTest is the test-only counterpart to the production
// SegmentSuperpacket path. It handles GSO_NONE (with optional
// finishChecksum) inline and dispatches GSO superpackets through
// SegmentSuperpacket, draining each yielded segment into a
// freshly-copied [][]byte slot so callers can iterate after the call
// returns. Tests pre-set hdr.HdrLen correctly, so correctHdrLen is not
// invoked here.
func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, scratch []byte) error {
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
cp := append([]byte(nil), pkt...)
if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
if err := virtio.FinishChecksum(cp, hdr); err != nil {
return err
}
}
*out = append(*out, cp)
return nil
}
proto, err := protoFromGSOType(hdr.GSOType)
if err != nil {
return err
}
gso := GSOInfo{
Size: hdr.GSOSize,
HdrLen: hdr.HdrLen,
CsumStart: hdr.CsumStart,
Proto: proto,
}
return SegmentSuperpacket(Packet{Bytes: pkt, GSO: gso}, func(seg []byte) error {
*out = append(*out, append([]byte(nil), seg...))
return nil
})
}
// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each.
func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 {
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
s += uint32(proto) + uint32(l4Len)
s = (s & 0xffff) + (s >> 16)
s = (s & 0xffff) + (s >> 16)
return uint16(s)
}
// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each.
func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 {
s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0))
s += uint32(l4Len>>16) + uint32(l4Len&0xffff) + uint32(proto)
s = (s & 0xffff) + (s >> 16)
s = (s & 0xffff) + (s >> 16)
return uint16(s)
}
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) {
t.Helper()
const ipLen = 20
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+payLen)
// IPv4 header
pkt[0] = 0x45 // version 4, IHL 5
// total length is meaningless for TSO but set it anyway
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
pkt[8] = 64 // TTL
pkt[9] = unix.IPPROTO_TCP
copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
// TCP header
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
binary.BigEndian.PutUint16(pkt[22:24], 80) // dport
binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
pkt[32] = 0x50 // data offset 5 words
pkt[33] = 0x18 // ACK | PSH
binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
// payload
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
}
return pkt, virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
}
func TestSegmentTCPv4(t *testing.T) {
const mss = 100
const numSeg = 3
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != 40+mss {
t.Errorf("seg %d: unexpected len %d", i, len(seg))
}
totalLen := binary.BigEndian.Uint16(seg[2:4])
if totalLen != uint16(40+mss) {
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
}
id := binary.BigEndian.Uint16(seg[4:6])
if id != 0x4242+uint16(i) {
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
}
seq := binary.BigEndian.Uint32(seg[24:28])
wantSeq := uint32(10000 + i*mss)
if seq != wantSeq {
t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
}
flags := seg[33]
wantFlags := byte(0x10) // ACK only, PSH cleared
if i == numSeg-1 {
wantFlags = 0x18 // ACK | PSH preserved on last
}
if flags != wantFlags {
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
}
// IPv4 header checksum must verify against itself.
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
// TCP checksum must verify against the pseudo-header.
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentTCPv4OddTail(t *testing.T) {
// Payload of 250 bytes with MSS 100 → segments of 100, 100, 50.
pkt, hdr := buildTSOv4(t, 250, 100)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 3 {
t.Fatalf("want 3 segments, got %d", len(out))
}
wantPayLens := []int{100, 100, 50}
for i, seg := range out {
if len(seg)-40 != wantPayLens[i] {
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i])
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentTCPv6(t *testing.T) {
const ipLen = 40
const tcpLen = 20
const mss = 120
const numSeg = 2
payLen := mss * numSeg
pkt := make([]byte, ipLen+tcpLen+payLen)
// IPv6 header
pkt[0] = 0x60 // version 6
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
pkt[6] = unix.IPPROTO_TCP
pkt[7] = 64
// src/dst fe80::1 / fe80::2
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
// TCP header
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 80)
binary.BigEndian.PutUint32(pkt[44:48], 7)
binary.BigEndian.PutUint32(pkt[48:52], 99)
pkt[52] = 0x50
pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
binary.BigEndian.PutUint16(pkt[54:56], 65535)
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("want %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != ipLen+tcpLen+mss {
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
}
pl := binary.BigEndian.Uint16(seg[4:6])
if pl != uint16(tcpLen+mss) {
t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
}
seq := binary.BigEndian.Uint32(seg[44:48])
if seq != uint32(7+i*mss) {
t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
}
flags := seg[53]
// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
// cleared on all but the last; ACK(0x10) always preserved.
wantFlags := byte(0x10)
if i == numSeg-1 {
wantFlags = 0x19
}
if flags != wantFlags {
t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
if !verifyChecksum(seg[ipLen:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func TestSegmentGSONonePassesThrough(t *testing.T) {
pkt, hdr := buildTSOv4(t, 100, 100)
hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 1 {
t.Fatalf("want 1 segment, got %d", len(out))
}
if len(out[0]) != len(pkt) {
t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
}
}
// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is
// still rejected; only modern GSO_UDP_L4 (USO) is supported.
func TestSegmentRejectsLegacyUDPGSO(t *testing.T) {
hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
var out [][]byte
if err := segmentForTest(nil, hdr, &out, nil); err == nil {
t.Fatalf("expected rejection for legacy UDP GSO")
}
}
// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of
// payLen bytes, segmented at gsoSize.
func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) {
t.Helper()
const ipLen = 20
const udpLen = 8
pkt := make([]byte, ipLen+udpLen+payLen)
// IPv4 header
pkt[0] = 0x45 // version 4, IHL 5
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+udpLen+payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
pkt[8] = 64
pkt[9] = unix.IPPROTO_UDP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
// UDP header (length + checksum filled in per segment by segmentUDPYield)
binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
binary.BigEndian.PutUint16(pkt[22:24], 53) // dport
for i := 0; i < payLen; i++ {
pkt[ipLen+udpLen+i] = byte(i & 0xff)
}
return pkt, virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: uint16(ipLen + udpLen),
GSOSize: uint16(gsoSize),
CsumStart: uint16(ipLen),
CsumOffset: 6,
}
}
func TestSegmentUDPv4(t *testing.T) {
const gso = 100
const numSeg = 3
pkt, hdr := buildUSOv4(t, gso*numSeg, gso)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != 28+gso {
t.Errorf("seg %d: len %d want %d", i, len(seg), 28+gso)
}
totalLen := binary.BigEndian.Uint16(seg[2:4])
if totalLen != uint16(28+gso) {
t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 28+gso)
}
// kernel UDP-GSO does NOT bump the IPv4 ID across segments; every
// segment carries the same ID as the seed.
id := binary.BigEndian.Uint16(seg[4:6])
if id != 0x4242 {
t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242)
}
udpLen := binary.BigEndian.Uint16(seg[24:26])
if udpLen != uint16(8+gso) {
t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+gso)
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+gso)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
func TestSegmentUDPv4OddTail(t *testing.T) {
// 250 bytes payload, gsoSize=100 → segments of 100, 100, 50.
pkt, hdr := buildUSOv4(t, 250, 100)
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 3 {
t.Fatalf("want 3 segments, got %d", len(out))
}
wantPay := []int{100, 100, 50}
for i, seg := range out {
if len(seg)-28 != wantPay[i] {
t.Errorf("seg %d: pay len %d want %d", i, len(seg)-28, wantPay[i])
}
udpLen := binary.BigEndian.Uint16(seg[24:26])
if udpLen != uint16(8+wantPay[i]) {
t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+wantPay[i])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+wantPay[i])
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
func TestSegmentUDPv6(t *testing.T) {
const ipLen = 40
const udpLen = 8
const gso = 120
const numSeg = 2
payLen := gso * numSeg
pkt := make([]byte, ipLen+udpLen+payLen)
// IPv6 header
pkt[0] = 0x60
binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen))
pkt[6] = unix.IPPROTO_UDP
pkt[7] = 64
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 53)
for i := 0; i < payLen; i++ {
pkt[ipLen+udpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
HdrLen: uint16(ipLen + udpLen),
GSOSize: uint16(gso),
CsumStart: uint16(ipLen),
CsumOffset: 6,
}
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("want %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
if len(seg) != ipLen+udpLen+gso {
t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+udpLen+gso)
}
pl := binary.BigEndian.Uint16(seg[4:6])
if pl != uint16(udpLen+gso) {
t.Errorf("seg %d: payload_length=%d want %d", i, pl, udpLen+gso)
}
ul := binary.BigEndian.Uint16(seg[ipLen+4 : ipLen+6])
if ul != uint16(udpLen+gso) {
t.Errorf("seg %d: udp len=%d want %d", i, ul, udpLen+gso)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_UDP, udpLen+gso)
if !verifyChecksum(seg[ipLen:], psum) {
t.Errorf("seg %d: bad UDP checksum", i)
}
}
}
// TestSegmentUDPCEPropagates confirms IP-level CE marks on the seed appear on
// every segment. UDP has no transport-level CWR/ECE: the IP TOS/TC byte is
// copied verbatim into every segment by the segment-prefix copy.
func TestSegmentUDPCEPropagates(t *testing.T) {
pkt, hdr := buildUSOv4(t, 200, 100)
pkt[1] = 0x03 // CE codepoint in IP-ECN
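// IP-ECN codepoints (low two bits of the TOS / Traffic Class byte):
// 00 Not-ECT, 01 ECT(1), 10 ECT(0), 11 CE.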
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != 2 {
t.Fatalf("want 2 segments, got %d", len(out))
}
for i, seg := range out {
if seg[1]&0x03 != 0x03 {
t.Errorf("seg %d: CE missing (tos=%#x)", i, seg[1])
}
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
}
}
// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO
// burst's seed has CWR set, only the first emitted segment carries CWR.
// ECE is preserved on every segment (different signal, persistent state).
func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) {
const mss = 100
const numSeg = 3
pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
// Seed flags: CWR | ECE | ACK | PSH.
pkt[33] = 0x80 | 0x40 | 0x10 | 0x08
scratch := make([]byte, testSegScratchSize)
var out [][]byte
if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
t.Fatalf("segmentForTest: %v", err)
}
if len(out) != numSeg {
t.Fatalf("expected %d segments, got %d", numSeg, len(out))
}
for i, seg := range out {
flags := seg[33]
hasCwr := flags&0x80 != 0
hasEce := flags&0x40 != 0
hasPsh := flags&0x08 != 0
wantCwr := i == 0
wantPsh := i == numSeg-1
if hasCwr != wantCwr {
t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", i, hasCwr, wantCwr, flags)
}
if !hasEce {
t.Errorf("seg %d: ECE missing (flags=%#x)", i, flags)
}
if hasPsh != wantPsh {
t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", i, hasPsh, wantPsh, flags)
}
// IP and TCP checksums must still verify after the flag rewrite.
if !verifyChecksum(seg[:20], 0) {
t.Errorf("seg %d: bad IPv4 header checksum", i)
}
psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
if !verifyChecksum(seg[20:], psum) {
t.Errorf("seg %d: bad TCP checksum", i)
}
}
}
func BenchmarkSegmentTCPv4(b *testing.B) {
sizes := []struct {
name string
payLen int
mss int
}{
{"64KiB_MSS1460", 65000, 1460},
{"16KiB_MSS1460", 16384, 1460},
{"4KiB_MSS1460", 4096, 1460},
}
for _, sz := range sizes {
b.Run(sz.name, func(b *testing.B) {
const ipLen = 20
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+sz.payLen)
pkt[0] = 0x45
binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
pkt[8] = 64
pkt[9] = unix.IPPROTO_TCP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
binary.BigEndian.PutUint16(pkt[20:22], 12345)
binary.BigEndian.PutUint16(pkt[22:24], 80)
binary.BigEndian.PutUint32(pkt[24:28], 10000)
binary.BigEndian.PutUint32(pkt[28:32], 20000)
pkt[32] = 0x50
pkt[33] = 0x18
binary.BigEndian.PutUint16(pkt[34:36], 65535)
for i := 0; i < sz.payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
GSOSize: uint16(sz.mss),
CsumStart: uint16(ipLen),
CsumOffset: 16,
}
scratch := make([]byte, testSegScratchSize)
out := make([][]byte, 0, 64)
// SegmentSuperpacket consumes its input destructively; restore
// pkt from a master copy each iteration. The restore mirrors the
// kernel→userspace copy that hands a fresh GSO blob to the
// segmenter in production, so it's representative cost rather
// than bench overhead.
master := append([]byte(nil), pkt...)
work := make([]byte, len(pkt))
b.SetBytes(int64(len(pkt)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
copy(work, master)
out = out[:0]
if err := segmentForTest(work, hdr, &out, scratch); err != nil {
b.Fatal(err)
}
}
})
}
}
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is
// allocation-free. We write to /dev/null so every call succeeds synchronously.
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
if err != nil {
t.Fatalf("open /dev/null: %v", err)
}
t.Cleanup(func() { _ = unix.Close(fd) })
tf := &Offload{fd: fd}
payload := make([]byte, 1400)
// Warm up (first call may trigger one-time internal allocations elsewhere).
if _, err := tf.Write(payload); err != nil {
t.Fatalf("Write: %v", err)
}
allocs := testing.AllocsPerRun(1000, func() {
if _, err := tf.Write(payload); err != nil {
t.Fatalf("Write: %v", err)
}
})
if allocs != 0 {
t.Fatalf("Write allocated %.1f times per call, want 0", allocs)
}
}
// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes
// of payload, segmented at gso. Returns the packet bytes only; the
// virtio_net_hdr is the caller's responsibility.
func buildTSOv6(payLen, gso int) []byte {
const ipLen = 40
const tcpLen = 20
pkt := make([]byte, ipLen+tcpLen+payLen)
pkt[0] = 0x60 // version 6
binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
pkt[6] = unix.IPPROTO_TCP
pkt[7] = 64
pkt[8] = 0xfe
pkt[9] = 0x80
pkt[23] = 1
pkt[24] = 0xfe
pkt[25] = 0x80
pkt[39] = 2
binary.BigEndian.PutUint16(pkt[40:42], 12345)
binary.BigEndian.PutUint16(pkt[42:44], 80)
binary.BigEndian.PutUint32(pkt[44:48], 7)
binary.BigEndian.PutUint32(pkt[48:52], 99)
pkt[52] = 0x50
pkt[53] = 0x10 // ACK only
binary.BigEndian.PutUint16(pkt[54:56], 65535)
for i := 0; i < payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
return pkt
}
// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is
// correct: when rxOff is at the maximum value the drain headroom check
// allows, decodeRead must still be able to absorb a worst-case 64KiB
// TSO superpacket without dropping the burst. With segmentation deferred
// to encrypt time, decodeRead writes only the kernel-supplied bytes into
// rxBuf, so the size requirement is just "fit one worst-case input."
//
// Regression history: in a prior layout the rx buffer doubled as the
// segmentation output, a near-threshold drain read returned "scratch too
// small", the whole 45-segment TSO burst was dropped, and the remote's TCP
// fast-retransmit collapsed cwnd. Keeping this test in the new layout
// guards against re-introducing a drain headroom shortfall.
func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) {
const ipv6HdrLen = 40
const tcpHdrLen = 20
const headerLen = ipv6HdrLen + tcpHdrLen
// Maximum TUN read body. The tunReadBufSize cap on readv's body iovec
// is what bounds the kernel's superpacket length.
pktLen := tunReadBufSize
payLen := pktLen - headerLen
const targetSegs = 64
gsoSize := (payLen + targetSegs - 1) / targetSegs
pkt := buildTSOv6(payLen, gsoSize)
if len(pkt) != pktLen {
t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen)
}
o := &Offload{
rxBuf: make([]byte, tunRxBufCap),
}
// rxOff at the maximum value the drain headroom check permits before
// it would refuse another read. Any drain-time read up to this
// threshold MUST still process correctly.
o.rxOff = tunRxBufCap - tunRxBufSize
// Stage the body in rxBuf as if readv(2) just placed it there.
copy(o.rxBuf[o.rxOff:], pkt)
// Encode the matching virtio_net_hdr.
hdr := virtio.Hdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
HdrLen: uint16(headerLen),
GSOSize: uint16(gsoSize),
CsumStart: uint16(ipv6HdrLen),
CsumOffset: 16,
}
hdr.Encode(o.readVnetScratch[:])
startRxOff := o.rxOff
if err := o.decodeRead(pktLen); err != nil {
t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+
"tunRxBufSize=%d must hold one worst-case input (%d)",
err, tunRxBufSize, pktLen)
}
if len(o.pending) != 1 {
t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending))
}
got := o.pending[0]
if !got.GSO.IsSuperpacket() {
t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO)
}
if got.GSO.Proto != GSOProtoTCP {
t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto)
}
if got.GSO.Size != uint16(gsoSize) {
t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize)
}
if got.GSO.HdrLen != uint16(headerLen) {
t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen)
}
if got.GSO.CsumStart != uint16(ipv6HdrLen) {
t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen)
}
if len(got.Bytes) != pktLen {
t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen)
}
// rxOff advances exactly by the kernel-supplied body length — no
// segmentation output to account for any more.
if o.rxOff != startRxOff+pktLen {
t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen)
}
if o.rxOff > tunRxBufCap {
t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap)
}
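// Taken together, the two checks above pin the sizing invariant: a
// drain-threshold read of tunReadBufSize bytes fits only if
// tunRxBufSize >= tunReadBufSize.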
// Validate that segmenting the returned superpacket reproduces the
// expected per-segment IPv6 payload length and TCP checksum.
wantSegs := (payLen + gsoSize - 1) / gsoSize
gotSegs := 0
if err := SegmentSuperpacket(got, func(seg []byte) error {
defer func() { gotSegs++ }()
if len(seg) < headerLen+1 {
t.Errorf("seg %d too short: %d", gotSegs, len(seg))
return nil
}
if seg[0]>>4 != 6 {
t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0])
}
segPay := len(seg) - headerLen
gotPL := binary.BigEndian.Uint16(seg[4:6])
if gotPL != uint16(tcpHdrLen+segPay) {
t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay)
}
psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay)
if !verifyChecksum(seg[ipv6HdrLen:], psum) {
t.Errorf("seg %d: bad TCP checksum", gotSegs)
}
return nil
}); err != nil {
t.Fatalf("SegmentSuperpacket: %v", err)
}
if gotSegs != wantSegs {
t.Fatalf("got %d segments, want %d", gotSegs, wantSegs)
}
}

View File

@@ -0,0 +1,43 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package virtio
import "encoding/binary"
// Size is the on-wire length of struct virtio_net_hdr the kernel
// prepends/expects on a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ
// not set).
const Size = 10
// Hdr is the Go view of the legacy virtio_net_hdr.
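// Its on-wire layout (host byte order) is flags(1), gso_type(1), hdr_len(2),
// gso_size(2), csum_start(2), csum_offset(2), filling the 10 bytes given by Size.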
type Hdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// Decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *Hdr) Decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// Encode is the inverse of Decode: writes the virtio_net_hdr fields into b
// (must be at least Size bytes). Used to emit a TSO superpacket on egress.
func (h *Hdr) Encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}

View File

@@ -0,0 +1,401 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
// Package virtio implements the pure validation, header-correction, and
// per-segment slicing logic for kernel-supplied TSO/USO superpackets on
// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte
// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes,
// so it can be unit-tested in isolation from the tio Queue runtime.
package virtio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
const (
ipv4HeaderMinLen = 20 // IHL=5, no options
ipv4HeaderMaxLen = 60 // IHL=15, max options
ipv6FixedLen = 40 // IPv6 base header; extensions would extend this
tcpHeaderMinLen = 20 // data-offset=5, no options
tcpHeaderMaxLen = 60 // data-offset=15, max options
)
// Byte offsets inside an IPv4 header.
const (
ipv4TotalLenOff = 2
ipv4IDOff = 4
ipv4ChecksumOff = 10
ipv4SrcOff = 12
ipv4AddrsEnd = 20 // end of dst address (ipv4SrcOff + 2*4)
)
// Byte offsets inside an IPv6 header.
const (
ipv6PayloadLenOff = 4
ipv6SrcOff = 8
ipv6AddrsEnd = 40 // end of dst address (ipv6SrcOff + 2*16)
)
// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
const (
tcpSeqOff = 4
tcpDataOffOff = 12 // upper nibble is header len in 32-bit words
tcpFlagsOff = 13
tcpChecksumOff = 16
)
// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}.
const (
udpHeaderLen = 8
udpLengthOff = 4
udpChecksumOff = 6
)
// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)
// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168
// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved
// its window) and must appear on the first segment of a TSO burst only.
const tcpCwrFlag = 0x80
// CheckValid rejects packets whose virtio_net_hdr/IP combination would
// cause a downstream miscompute. The TUN should never emit RSC_INFO and
// the GSO type must agree with the IP version nibble.
func CheckValid(pkt []byte, hdr Hdr) error {
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
// carry coalescing info rather than checksum offsets. A TUN writing via
// IFF_VNET_HDR should never emit this, but if it did we would silently
// miscompute the segment checksums — refuse the packet instead.
if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
}
if len(pkt) < ipv4HeaderMinLen {
return fmt.Errorf("packet too short")
}
ipVersion := pkt[0] >> 4
switch hdr.GSOType {
case unix.VIRTIO_NET_HDR_GSO_TCPV4:
if ipVersion != 4 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
case unix.VIRTIO_NET_HDR_GSO_TCPV6:
if ipVersion != 6 {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
case unix.VIRTIO_NET_HDR_GSO_UDP_L4:
// USO carries either v4 or v6; the leading nibble disambiguates.
if !(ipVersion == 4 || ipVersion == 6) {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
default:
if !(ipVersion == 6 || ipVersion == 4) {
return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
}
}
return nil
}
// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header
// length read out of pkt. The kernel's hdr.HdrLen on the FORWARD path can
// be the length of the entire first packet, so we don't trust it.
func CorrectHdrLen(pkt []byte, hdr *Hdr) error {
// Thank you wireguard-go for documenting these edge-cases
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the transport header length and add it onto
// csumStart, which is synonymous with the IP header length.
if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
hdr.HdrLen = hdr.CsumStart + 8
} else {
if len(pkt) <= int(hdr.CsumStart+tcpDataOffOff) {
return errors.New("packet is too short")
}
tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffOff]>>4) * 4
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.HdrLen = hdr.CsumStart + tcpHLen
}
if len(pkt) < int(hdr.HdrLen) {
return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
}
if hdr.HdrLen < hdr.CsumStart {
return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
}
cSumAt := int(hdr.CsumStart + hdr.CsumOffset)
if cSumAt+1 >= len(pkt) {
return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
}
return nil
}
// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a
// slice into pkt itself. Per-segment plaintext is laid out by sliding a
// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize,
// where it sits immediately before that segment's payload chunk in the
// original buffer. The slide is destructive: iter i's header write overwrites
// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time
// the next iteration begins. pkt is consumed by this call and must not be
// inspected by the caller after the final yield.
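//
// For example, with headerLen=40 and gsoSize=1000 over a 2040-byte pkt
// (2000 payload bytes, two segments): seg 0 is pkt[0:1040]; seg 1's patched
// header is copied into pkt[1000:1040], the tail of seg 0's already-yielded
// payload, and seg 1 is pkt[1000:2040] with its own payload untouched at
// pkt[1040:2040].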
func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
if gsoSizeU == 0 {
return fmt.Errorf("gso_size is zero")
}
if csumStartU == 0 {
return fmt.Errorf("csum_start is zero")
}
headerLen := int(hdrLenU)
csumStart := int(csumStartU)
isV4 := pkt[0]>>4 == 4
tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
payLen := len(pkt) - headerLen
gsoSize := int(gsoSizeU)
numSeg := (payLen + gsoSize - 1) / gsoSize
if numSeg == 0 {
numSeg = 1
}
origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
origFlags := pkt[csumStart+tcpFlagsOff]
var tmp [tcpHeaderMaxLen]byte
copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
tmp[tcpFlagsOff] = 0
tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))
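// Pseudo-header partial sum: source and destination addresses plus the
// protocol number. The TCP length term is folded in per segment below.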
var baseProtoSum uint32
if isV4 {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
} else {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
}
baseProtoSum += uint32(unix.IPPROTO_TCP)
var origIPID uint16
var baseIPHdrSum uint32
if isV4 {
origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
ihl := int(pkt[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || ihl > csumStart {
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
}
var ipTmp [ipv4HeaderMaxLen]byte
copy(ipTmp[:ihl], pkt[:ihl])
ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
}
for i := 0; i < numSeg; i++ {
segStart := i * gsoSize
segEnd := segStart + gsoSize
if segEnd > payLen {
segEnd = payLen
}
segPayLen := segEnd - segStart
segLen := headerLen + segPayLen
headerOff := i * gsoSize
// Slide the header into place immediately before this segment's
// payload. Iter 0's header is already at pkt[:headerLen]; for
// i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen]
// survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id
// are touched), and iter 0's stale variable-field values are
// overwritten by the per-segment patches below.
if i > 0 {
copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
}
seg := pkt[headerOff : headerOff+segLen]
segSeq := origSeq + uint32(segStart)
segFlags := origFlags
if i != 0 {
segFlags &^= tcpCwrFlag
}
if i != numSeg-1 {
segFlags &^= tcpFinPshMask
}
totalLen := segLen
if isV4 {
segID := origIPID + uint16(i)
binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
} else {
binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
}
binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
seg[csumStart+tcpFlagsOff] = segFlags
tcpLen := tcpHdrLen + segPayLen
// Payload bytes still live at their original offset in pkt. The
// header slide above only writes into pkt[i*G : i*G+H], which is
// the tail of seg_{i-1}'s payload (already consumed) and never
// overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G].
paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
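// baseTcpHdrSum was computed with seq, flags, and checksum zeroed, so the
// per-segment seq and flags are re-added here along with the pseudo-header's
// TCP length; 32-bit terms fold correctly into the one's-complement sum.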
wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
wide = (wide & 0xffffffff) + (wide >> 32)
wide = (wide & 0xffffffff) + (wide >> 32)
binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))
if err := yield(seg); err != nil {
return err
}
}
return nil
}
// SegmentUDP walks a USO superpacket, sliding a per-segment-patched
// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen]
// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6
// payload_len) plus the UDP length and checksum. pkt is consumed
// destructively; see SegmentTCP for the layout reasoning.
//
// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not
// bump it), which is why the IP-level per-segment work is limited to
// total_len + IPv4 header checksum (v4) or payload_len (v6).
func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
if gsoSizeU == 0 {
return fmt.Errorf("gso_size is zero")
}
if csumStartU == 0 {
return fmt.Errorf("csum_start is zero")
}
isV4 := pkt[0]>>4 == 4
headerLen := int(hdrLenU)
csumStart := int(csumStartU)
if headerLen-csumStart != udpHeaderLen {
return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart)
}
payLen := len(pkt) - headerLen
gsoSize := int(gsoSizeU)
numSeg := (payLen + gsoSize - 1) / gsoSize
if numSeg == 0 {
numSeg = 1
}
var udpTmp [udpHeaderLen]byte
copy(udpTmp[:], pkt[csumStart:headerLen])
udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0
udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0
baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0))
var baseProtoSum uint32
if isV4 {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
} else {
baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
}
baseProtoSum += uint32(unix.IPPROTO_UDP)
var baseIPHdrSum uint32
if isV4 {
ihl := int(pkt[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || ihl > csumStart {
return fmt.Errorf("bad IPv4 IHL: %d", ihl)
}
var ipTmp [ipv4HeaderMaxLen]byte
copy(ipTmp[:ihl], pkt[:ihl])
ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
}
for i := 0; i < numSeg; i++ {
segStart := i * gsoSize
segEnd := segStart + gsoSize
if segEnd > payLen {
segEnd = payLen
}
segPayLen := segEnd - segStart
segLen := headerLen + segPayLen
headerOff := i * gsoSize
if i > 0 {
copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
}
seg := pkt[headerOff : headerOff+segLen]
totalLen := segLen
udpLen := udpHeaderLen + segPayLen
if isV4 {
binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
ipSum := baseIPHdrSum + uint32(totalLen)
binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
} else {
binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
}
binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen))
paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum)
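// udpLen is added twice: once for the length field zeroed out of
// baseUDPHdrSum and once for the pseudo-header's UDP length.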
wide += uint64(udpLen) + uint64(udpLen)
wide = (wide & 0xffffffff) + (wide >> 32)
wide = (wide & 0xffffffff) + (wide >> 32)
csum := foldComplement(uint32(wide))
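// Per RFC 768, an all-zero transmitted checksum means "not computed", so a
// calculated 0 is sent as 0xffff.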
if csum == 0 {
csum = 0xffff
}
binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum)
if err := yield(seg); err != nil {
return err
}
}
return nil
}
// FinishChecksum computes the L4 checksum for a non-GSO packet that the kernel
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
// the pseudo-header partial sum by the kernel), and store the result.
func FinishChecksum(seg []byte, hdr Hdr) error {
cs := int(hdr.CsumStart)
co := int(hdr.CsumOffset)
if cs+co+2 > len(seg) {
return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg))
}
// The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the
// L4 region starting at cs, folding the prior partial in as the seed.
partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2])
seg[cs+co] = 0
seg[cs+co+1] = 0
binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial))
return nil
}
// foldComplement folds a 32-bit one's-complement partial sum to 16 bits and
// complements it, yielding the on-wire Internet checksum value.
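// For example, a partial sum of 0x0001fffe folds to 0xffff and complements
// to 0x0000.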
func foldComplement(sum uint32) uint16 {
sum = (sum & 0xffff) + (sum >> 16)
sum = (sum & 0xffff) + (sum >> 16)
return ^uint16(sum)
}

View File

@@ -27,15 +27,15 @@ type tun struct {
l *slog.Logger
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -37,7 +37,7 @@ type tun struct {
out []byte
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
type ifReq struct {
@@ -516,12 +516,12 @@ func (t *tun) readOne(to []byte) (int, error) {
return n - 4, err
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -23,23 +23,41 @@ type disabledTun struct {
rx metrics.Counter
l *slog.Logger
numReaders int
batchRet [1][]byte
}
func (t *disabledTun) Read() ([][]byte, error) {
r, ok := <-t.read
// disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue
// owns a private batchRet so concurrent Read calls from different reader
// goroutines do not race on the returned slice.
type disabledQueue struct {
parent *disabledTun
batchRet [1]tio.Packet
}
func (q *disabledQueue) Read() ([]tio.Packet, error) {
r, ok := <-q.parent.read
if !ok {
return nil, io.EOF
}
t.tx.Inc(1)
if t.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Write payload", "raw", prettyPacket(r))
q.parent.tx.Inc(1)
if q.parent.l.Enabled(context.Background(), slog.LevelDebug) {
q.parent.l.Debug("Write payload", "raw", prettyPacket(r))
}
t.batchRet[0] = r
return t.batchRet[:], nil
q.batchRet[0] = tio.Packet{Bytes: r}
return q.batchRet[:], nil
}
// Write on a queue forwards to the underlying disabledTun. All queues share
// one ICMP-handling/log path so this is a thin pass-through.
func (q *disabledQueue) Write(b []byte) (int, error) {
return q.parent.Write(b)
}
// Close on a queue is a no-op. The shared channel and metrics are owned by
// the disabledTun; Close on the device tears them down once for everybody.
func (q *disabledQueue) Close() error {
return nil
}
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
@@ -120,7 +138,7 @@ func (t *disabledTun) NewMultiQueueReader() error {
func (t *disabledTun) Readers() []tio.Queue {
out := make([]tio.Queue, t.numReaders)
for i := range t.numReaders {
out[i] = t
out[i] = &disabledQueue{parent: t}
}
return out
}

View File

@@ -104,7 +104,7 @@ type tun struct {
closed atomic.Bool
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
// blockOnRead waits until the tun fd is readable or shutdown has been signaled.
@@ -159,12 +159,12 @@ func (t *tun) blockOnWrite() error {
return nil
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -29,15 +29,15 @@ type tun struct {
l *slog.Logger
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -25,7 +25,7 @@ import (
)
type tun struct {
readers tio.Container
readers tio.QueueSet
closeLock sync.Mutex
Device string
vpnNetworks []netip.Prefix
@@ -34,6 +34,14 @@ type tun struct {
TXQueueLen int
deviceIndex int
ioctlFd uintptr
vnetHdr bool
// routeFeatureECN, when true, sets RTAX_FEATURE_ECN on every route we
// install for the tun. The kernel then actively negotiates ECN for
// connections destined to those prefixes (equivalent to `ip route
// change ... features ecn`) regardless of net.ipv4.tcp_ecn, so flows
// across the nebula mesh use ECN even when the host default is the
// passive setting (=2). Disable via tunnels.ecn=false.
routeFeatureECN bool
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -72,7 +80,9 @@ type ifreqQLEN struct {
}
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks)
// We don't know what flags the caller opened this fd with and can't turn
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks)
if err != nil {
return nil, err
}
@@ -117,6 +127,18 @@ func tunSetIff(fd int, name string, flags uint16) (string, error) {
return strings.Trim(string(req.Name[:]), "\x00"), nil
}
// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a
// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO.
// TSO_ECN tells the kernel we propagate ECN correctly through coalesce and
// segmentation, so it can deliver superpackets whose seed has CWR/ECE set
// or whose IP-level codepoint is CE.
const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 | unix.TUN_F_TSO_ECN
// usoOffloadFlags adds UDP Segmentation Offload to tsoOffloadFlags. Requires
// Linux ≥ 6.2; older kernels reject it and we fall back to TCP-only TSO via
// tsoOffloadFlags.
const usoOffloadFlags = tsoOffloadFlags | unix.TUN_F_USO4 | unix.TUN_F_USO6
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue {
@@ -124,17 +146,51 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
}
nameStr := c.GetString("tun.dev", "")
// First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_*
// offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets.
// We try TSO+USO first, fall back to TSO-only on kernels without USO
// (Linux < 6.2), and finally give up on virtio headers entirely and
// reopen as a plain TUN if neither offload mask is accepted.
fd, err := openTunDev()
if err != nil {
return nil, err
}
name, err := tunSetIff(fd, nameStr, baseFlags)
vnetHdr := true
usoEnabled := false
name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR)
if err != nil {
_ = unix.Close(fd)
return nil, &NameError{Name: nameStr, Underlying: err}
vnetHdr = false
} else {
// Try TSO+USO first. On kernels without USO support (Linux < 6.2)
// the ioctl returns EINVAL; fall back to the TCP-only mask before
// giving up on VNET_HDR entirely.
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(usoOffloadFlags)); err == nil {
usoEnabled = true
} else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
l.Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers", "error", err)
_ = unix.Close(fd)
vnetHdr = false
}
}
t, err := newTunGeneric(c, l, fd, vpnNetworks)
if !vnetHdr {
fd, err = openTunDev()
if err != nil {
return nil, err
}
name, err = tunSetIff(fd, nameStr, baseFlags)
if err != nil {
_ = unix.Close(fd)
return nil, &NameError{Name: nameStr, Underlying: err}
}
}
if vnetHdr {
l.Info("TUN offload enabled", "tso", true, "uso", usoEnabled)
}
t, err := newTunGeneric(c, l, fd, vnetHdr, usoEnabled, vpnNetworks)
if err != nil {
return nil, err
}
@@ -145,25 +201,34 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
}
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) {
container, err := tio.NewPollContainer()
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) {
var qs tio.QueueSet
var err error
if vnetHdr {
qs, err = tio.NewOffloadQueueSet(usoEnabled)
} else {
qs, err = tio.NewPollQueueSet()
}
if err != nil {
_ = unix.Close(fd)
return nil, err
}
err = container.Add(fd)
err = qs.Add(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
t := &tun{
readers: container,
readers: qs,
closeLock: sync.Mutex{},
vnetHdr: vnetHdr,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
routeFeatureECN: c.GetBool("tunnels.ecn", true),
routesFromSystem: map[netip.Prefix]routing.Gateways{},
l: l,
}
@@ -271,11 +336,21 @@ func (t *tun) NewMultiQueueReader() error {
}
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
if t.vnetHdr {
flags |= unix.IFF_VNET_HDR
}
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
_ = unix.Close(fd)
return err
}
if t.vnetHdr {
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
_ = unix.Close(fd)
return fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
}
}
err = t.readers.Add(fd)
if err != nil {
_ = unix.Close(fd)
@@ -450,6 +525,18 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
// Match the metric the kernel uses for its auto-installed connected
// route, so RouteReplace overwrites it in place instead of adding a
// second route at a worse metric. IPv6 connected routes are installed
// at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. Without this, the
// kernel route wins lookups and our MTU / AdvMSS / Features never
// apply on v6.
if cidr.Addr().Is6() {
nr.Priority = 256
}
if t.routeFeatureECN {
nr.Features |= unix.RTAX_FEATURE_ECN
}
err := netlink.RouteReplace(&nr)
if err != nil {
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
@@ -499,6 +586,9 @@ func (t *tun) addRoutes(logErrors bool) error {
if r.Metric > 0 {
nr.Priority = r.Metric
}
if t.routeFeatureECN {
nr.Features |= unix.RTAX_FEATURE_ECN
}
err := netlink.RouteReplace(&nr)
if err != nil {

View File

@@ -68,15 +68,15 @@ type tun struct {
fd int
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -61,15 +61,15 @@ type tun struct {
out []byte
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (t *tun) Read() ([][]byte, error) {
func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -30,7 +30,7 @@ type TestTun struct {
rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula
batchRet [1][]byte
batchRet [1]tio.Packet
}
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -51,7 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T
l: l,
rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10),
batchRet: [1][]byte{make([]byte, udp.MTU)},
batchRet: [1]tio.Packet{
tio.Packet{Bytes: make([]byte, udp.MTU)},
},
}, nil
}
@@ -166,13 +168,13 @@ func (t *TestTun) Close() error {
return nil
}
func (t *TestTun) Read() ([][]byte, error) {
t.batchRet[0] = t.batchRet[0][:udp.MTU]
n, err := t.read(t.batchRet[0])
func (t *TestTun) Read() ([]tio.Packet, error) {
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU]
n, err := t.read(t.batchRet[0].Bytes)
if err != nil {
return nil, err
}
t.batchRet[0] = t.batchRet[0][:n]
t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n]
return t.batchRet[:], nil
}

View File

@@ -47,15 +47,15 @@ type winTun struct {
tun *wintun.NativeTun
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (t *winTun) Read() ([][]byte, error) {
func (t *winTun) Read() ([]tio.Packet, error) {
n, err := t.tun.Read(t.readBuf, 0)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil
}

View File

@@ -39,10 +39,10 @@ type UserDevice struct {
inboundWriter *io.PipeWriter
readBuf []byte
batchRet [1][]byte
batchRet [1]tio.Packet
}
func (d *UserDevice) Read() ([][]byte, error) {
func (d *UserDevice) Read() ([]tio.Packet, error) {
if d.readBuf == nil {
d.readBuf = make([]byte, defaultBatchBufSize)
}
@@ -50,7 +50,7 @@ func (d *UserDevice) Read() ([][]byte, error) {
if err != nil {
return nil, err
}
d.batchRet[0] = d.readBuf[:n]
d.batchRet[0] = tio.Packet{Bytes: d.readBuf[:n]}
return d.batchRet[:], nil
}