JackDoan
2026-04-21 13:31:16 -05:00
parent bf4e37e99d
commit ad6b918e4d
28 changed files with 1039 additions and 698 deletions

View File

@@ -0,0 +1,484 @@
package coalesce
import (
"bytes"
"encoding/binary"
"io"
"github.com/slackhq/nebula/overlay/tio"
)
// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of
// reaching for golang.org/x/sys/unix — that package doesn't define the
// constant on Windows, which would break cross-compiles even though this
// file runs unchanged on every platform.
const ipProtoTCP = 6
// tcpCoalesceBufSize caps total bytes per superpacket. Mirrors the kernel's
// sk_gso_max_size of ~64KiB; anything beyond this would be rejected anyway.
const tcpCoalesceBufSize = 65535
// tcpCoalesceMaxSegs caps how many segments we'll coalesce into a single
// superpacket. Keeping this well below the kernel's TSO ceiling bounds
// latency.
const tcpCoalesceMaxSegs = 64
// tcpCoalesceHdrCap is the scratch space we copy a seed's IP+TCP header
// into. IPv6 (40) + TCP with full options (60) = 100 bytes.
const tcpCoalesceHdrCap = 100
// initialSlots is the starting capacity of the slot pool. One flow per
// packet is the worst case so this matches a typical UDP recvmmsg batch.
const initialSlots = 64
// flowKey identifies a TCP flow by {src, dst, sport, dport, family}.
// Comparable, so linear scans over the slot list stay tight.
type flowKey struct {
src, dst [16]byte
sport, dport uint16
isV6 bool
}
// coalesceSlot is one entry in the coalescer's ordered event queue. When
// passthrough is true the slot holds a single borrowed packet that must be
// emitted verbatim (non-TCP, non-admissible TCP, or oversize seed). When
// passthrough is false the slot is an in-progress coalesced superpacket:
// hdrBuf is a mutable copy of the seed's IP+TCP header (we patch total
// length and pseudo-header partial at flush), and payIovs are *borrowed*
// slices from the caller's plaintext buffers — no payload is ever copied.
// The caller (listenOut) must keep those buffers alive until Flush.
type coalesceSlot struct {
passthrough bool
rawPkt []byte // borrowed when passthrough
fk flowKey
hdrBuf [tcpCoalesceHdrCap]byte
hdrLen int
ipHdrLen int
isV6 bool
gsoSize int
numSeg int
totalPay int
nextSeq uint32
// psh closes the chain: set when the last-accepted segment had PSH or
// was sub-gsoSize. No further appends after that.
psh bool
payIovs [][]byte
}
// TCPCoalescer accumulates adjacent in-flow TCP data segments across
// multiple concurrent flows and emits each flow's run as a single TSO
// superpacket via tio.GSOWriter. All output — coalesced or not — is
// deferred until Flush so arrival order is preserved on the wire. Owns
// no locks; one coalescer per TUN write queue.
type TCPCoalescer struct {
plainW io.Writer
gsoW tio.GSOWriter // nil when the queue doesn't support TSO
// slots is the ordered event queue. Flush walks it once and emits each
// entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough).
slots []*coalesceSlot
// openSlots maps a flow key to its most recent non-sealed slot, so new
// segments can extend an in-progress superpacket in O(1). Slots are
// removed from this map when they close (PSH or short-last-segment),
// when a non-admissible packet for that flow arrives, or in Flush.
openSlots map[flowKey]*coalesceSlot
pool []*coalesceSlot // free list for reuse
}
func NewTCPCoalescer(w io.Writer) *TCPCoalescer {
c := &TCPCoalescer{
plainW: w,
slots: make([]*coalesceSlot, 0, initialSlots),
openSlots: make(map[flowKey]*coalesceSlot, initialSlots),
pool: make([]*coalesceSlot, 0, initialSlots),
}
if gw, ok := w.(tio.GSOWriter); ok && gw.GSOSupported() {
c.gsoW = gw
}
return c
}
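// Example: a minimal sketch of the intended call pattern. A coalescer is
// created once per TUN write queue via NewTCPCoalescer and reused across
// batches; the names c and batch below are illustrative, and batch is assumed
// to hold decrypted packets whose buffers stay valid until Flush returns.
func exampleWriteBatch(c *TCPCoalescer, batch [][]byte) error {
	for _, pkt := range batch {
		// Add never writes; it only queues pkt (borrowed) for the next Flush.
		if err := c.Add(pkt); err != nil {
			return err
		}
	}
	// Flush emits everything in arrival order; only after it returns may the
	// caller recycle the buffers behind batch.
	return c.Flush()
}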
// parsedTCP holds the fields extracted from a single parse so later steps
// (admission, slot lookup, canAppend) don't re-walk the header.
type parsedTCP struct {
fk flowKey
ipHdrLen int
tcpHdrLen int
hdrLen int
payLen int
seq uint32
flags byte
}
// parseTCPBase extracts the flow key and IP/TCP offsets for any TCP packet,
// regardless of whether it's admissible for coalescing. Returns ok=false
// for non-TCP or malformed input. Accepts IPv4 (no options, no fragmentation)
// and IPv6 (no extension headers).
func parseTCPBase(pkt []byte) (parsedTCP, bool) {
var p parsedTCP
if len(pkt) < 20 {
return p, false
}
v := pkt[0] >> 4
switch v {
case 4:
ihl := int(pkt[0]&0x0f) * 4
if ihl != 20 {
return p, false
}
if pkt[9] != ipProtoTCP {
return p, false
}
// Reject actual fragmentation (MF or non-zero frag offset).
if binary.BigEndian.Uint16(pkt[6:8])&0x3fff != 0 {
return p, false
}
totalLen := int(binary.BigEndian.Uint16(pkt[2:4]))
if totalLen > len(pkt) || totalLen < ihl {
return p, false
}
p.ipHdrLen = 20
p.fk.isV6 = false
copy(p.fk.src[:4], pkt[12:16])
copy(p.fk.dst[:4], pkt[16:20])
pkt = pkt[:totalLen]
case 6:
if len(pkt) < 40 {
return p, false
}
if pkt[6] != ipProtoTCP {
return p, false
}
payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
if 40+payloadLen > len(pkt) {
return p, false
}
p.ipHdrLen = 40
p.fk.isV6 = true
copy(p.fk.src[:], pkt[8:24])
copy(p.fk.dst[:], pkt[24:40])
pkt = pkt[:40+payloadLen]
default:
return p, false
}
if len(pkt) < p.ipHdrLen+20 {
return p, false
}
tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4
if tcpOff < 20 || tcpOff > 60 {
return p, false
}
if len(pkt) < p.ipHdrLen+tcpOff {
return p, false
}
p.tcpHdrLen = tcpOff
p.hdrLen = p.ipHdrLen + tcpOff
p.payLen = len(pkt) - p.hdrLen
p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8])
p.flags = pkt[p.ipHdrLen+13]
p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2])
p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4])
return p, true
}
// coalesceable reports whether a parsed TCP segment is eligible for
// coalescing. Accepts only ACK or ACK|PSH with a non-empty payload.
func (p parsedTCP) coalesceable() bool {
const ack = 0x10
const psh = 0x08
if p.flags&^(ack|psh) != 0 || p.flags&ack == 0 {
return false
}
return p.payLen > 0
}
// Add borrows pkt. The caller must keep pkt valid until the next Flush,
// whether or not the packet was coalesced — passthrough (non-admissible)
// packets are queued and written at Flush time, not synchronously.
func (c *TCPCoalescer) Add(pkt []byte) error {
if c.gsoW == nil {
c.addPassthrough(pkt)
return nil
}
info, ok := parseTCPBase(pkt)
if !ok {
// Non-TCP or malformed — can't possibly collide with an open flow.
c.addPassthrough(pkt)
return nil
}
if !info.coalesceable() {
// TCP but not admissible (SYN/FIN/RST/URG/CWR/ECE or zero-payload).
// Seal this flow's open slot so later in-flow packets don't extend
// it and accidentally reorder past this passthrough.
delete(c.openSlots, info.fk)
c.addPassthrough(pkt)
return nil
}
if open := c.openSlots[info.fk]; open != nil {
if c.canAppend(open, pkt, info) {
c.appendPayload(open, pkt, info)
if open.psh {
delete(c.openSlots, info.fk)
}
return nil
}
// Can't extend — seal it and fall through to seed a fresh slot.
delete(c.openSlots, info.fk)
}
c.seed(pkt, info)
return nil
}
// Flush emits every queued event in arrival order. Coalesced slots go out
// via WriteGSO; passthrough slots go out via plainW.Write. Returns the
// first error observed; keeps draining so one bad packet doesn't hold up
// the rest. After Flush returns, borrowed payload slices may be recycled.
func (c *TCPCoalescer) Flush() error {
var first error
for _, s := range c.slots {
var err error
if s.passthrough {
_, err = c.plainW.Write(s.rawPkt)
} else {
err = c.flushSlot(s)
}
if err != nil && first == nil {
first = err
}
c.release(s)
}
for i := range c.slots {
c.slots[i] = nil
}
c.slots = c.slots[:0]
for k := range c.openSlots {
delete(c.openSlots, k)
}
return first
}
func (c *TCPCoalescer) addPassthrough(pkt []byte) {
s := c.take()
s.passthrough = true
s.rawPkt = pkt
c.slots = append(c.slots, s)
}
func (c *TCPCoalescer) seed(pkt []byte, info parsedTCP) {
if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize {
// Pathological shape — can't fit our scratch, emit as-is.
c.addPassthrough(pkt)
return
}
s := c.take()
s.passthrough = false
s.rawPkt = nil
copy(s.hdrBuf[:], pkt[:info.hdrLen])
s.hdrLen = info.hdrLen
s.ipHdrLen = info.ipHdrLen
s.isV6 = info.fk.isV6
s.fk = info.fk
s.gsoSize = info.payLen
s.numSeg = 1
s.totalPay = info.payLen
s.nextSeq = info.seq + uint32(info.payLen)
s.psh = info.flags&0x08 != 0
s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen])
c.slots = append(c.slots, s)
if !s.psh {
c.openSlots[info.fk] = s
}
}
// canAppend reports whether info's packet extends the slot's seed: same
// header shape and stable contents, adjacent seq, not oversized, chain not
// closed.
func (c *TCPCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool {
if s.psh {
return false
}
if info.hdrLen != s.hdrLen {
return false
}
if info.seq != s.nextSeq {
return false
}
if s.numSeg >= tcpCoalesceMaxSegs {
return false
}
if info.payLen > s.gsoSize {
return false
}
if s.hdrLen+s.totalPay+info.payLen > tcpCoalesceBufSize {
return false
}
if !headersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) {
return false
}
return true
}
func (c *TCPCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) {
s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen])
s.numSeg++
s.totalPay += info.payLen
s.nextSeq = info.seq + uint32(info.payLen)
if info.payLen < s.gsoSize || info.flags&0x08 != 0 {
s.psh = true
}
}
func (c *TCPCoalescer) take() *coalesceSlot {
if n := len(c.pool); n > 0 {
s := c.pool[n-1]
c.pool[n-1] = nil
c.pool = c.pool[:n-1]
return s
}
return &coalesceSlot{}
}
func (c *TCPCoalescer) release(s *coalesceSlot) {
s.passthrough = false
s.rawPkt = nil
for i := range s.payIovs {
s.payIovs[i] = nil
}
s.payIovs = s.payIovs[:0]
s.numSeg = 0
s.totalPay = 0
s.psh = false
c.pool = append(c.pool, s)
}
// flushSlot patches the header and calls WriteGSO. Does not remove the
// slot from c.slots.
func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error {
total := s.hdrLen + s.totalPay
l4Len := total - s.ipHdrLen
hdr := s.hdrBuf[:s.hdrLen]
if s.isV6 {
binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len))
} else {
binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
hdr[10] = 0
hdr[11] = 0
binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen]))
}
var psum uint32
if s.isV6 {
psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoTCP, l4Len)
} else {
psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoTCP, l4Len)
}
tcsum := s.ipHdrLen + 16
binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum))
return c.gsoW.WriteGSO(hdr, s.payIovs, uint16(s.gsoSize), s.isV6, uint16(s.ipHdrLen))
}
// headersMatch compares two IP+TCP header prefixes for byte-for-byte
// equality on every field that must be identical across coalesced
// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out.
func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
if len(a) != len(b) {
return false
}
if isV6 {
// IPv6: bytes [0:4] = version/TC/flow-label, [6:8] = next_hdr/hop,
// [8:40] = src+dst. Skip [4:6] payload length.
if !bytes.Equal(a[0:4], b[0:4]) {
return false
}
if !bytes.Equal(a[6:40], b[6:40]) {
return false
}
} else {
// IPv4: [0:2] version/IHL/TOS, [6:10] flags/fragoff/TTL/proto,
// [12:20] src+dst. Skip [2:4] total len, [4:6] id, [10:12] csum.
if !bytes.Equal(a[0:2], b[0:2]) {
return false
}
if !bytes.Equal(a[6:10], b[6:10]) {
return false
}
if !bytes.Equal(a[12:20], b[12:20]) {
return false
}
}
// TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window,
// [18:tcpHdrLen] options (incl. urgent).
tcp := ipHdrLen
if !bytes.Equal(a[tcp:tcp+4], b[tcp:tcp+4]) {
return false
}
if !bytes.Equal(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) {
return false
}
if !bytes.Equal(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) {
return false
}
if !bytes.Equal(a[tcp+18:], b[tcp+18:]) {
return false
}
return true
}
// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must
// already have its checksum field zeroed) and returns the folded/inverted
// 16-bit value to store.
func ipv4HdrChecksum(hdr []byte) uint16 {
var sum uint32
for i := 0; i+1 < len(hdr); i += 2 {
sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2]))
}
if len(hdr)%2 == 1 {
sum += uint32(hdr[len(hdr)-1]) << 8
}
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return ^uint16(sum)
}
// pseudoSumIPv4 / pseudoSumIPv6 build the TCP pseudo-header partial sum
// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator
// before folding.
func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 {
var sum uint32
sum += uint32(binary.BigEndian.Uint16(src[0:2]))
sum += uint32(binary.BigEndian.Uint16(src[2:4]))
sum += uint32(binary.BigEndian.Uint16(dst[0:2]))
sum += uint32(binary.BigEndian.Uint16(dst[2:4]))
sum += uint32(proto)
sum += uint32(l4Len)
return sum
}
func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 {
var sum uint32
for i := 0; i < 16; i += 2 {
sum += uint32(binary.BigEndian.Uint16(src[i : i+2]))
sum += uint32(binary.BigEndian.Uint16(dst[i : i+2]))
}
sum += uint32(l4Len >> 16)
sum += uint32(l4Len & 0xffff)
sum += uint32(proto)
return sum
}
// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it
// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in
// the L4 checksum field — the kernel will add the payload sum and invert.
func foldOnceNoInvert(sum uint32) uint16 {
for sum>>16 != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return uint16(sum)
}
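// Example: a small, self-contained illustration of the NEEDS_CSUM contract
// described above. onesSum and exampleFullTCPChecksum are illustrative helpers,
// not used elsewhere in this package. flushSlot stores foldOnceNoInvert(psum)
// in the TCP checksum field; adding the one's-complement sum of the TCP header
// (checksum zeroed) and payload, folding, and inverting yields the full TCP
// checksum, which is exactly the work the kernel finishes.
func onesSum(b []byte) uint32 {
	var sum uint32
	for i := 0; i+1 < len(b); i += 2 {
		sum += uint32(binary.BigEndian.Uint16(b[i : i+2]))
	}
	if len(b)%2 == 1 {
		sum += uint32(b[len(b)-1]) << 8
	}
	return sum
}

// exampleFullTCPChecksum shows the identity for IPv4: the stored partial plus
// the remaining L4 bytes, folded and inverted, equals the checksum computed
// from scratch over pseudo-header + TCP header + payload.
func exampleFullTCPChecksum(srcIP, dstIP, tcpHdrZeroCsum, payload []byte) uint16 {
	psum := pseudoSumIPv4(srcIP, dstIP, ipProtoTCP, len(tcpHdrZeroCsum)+len(payload))
	partial := foldOnceNoInvert(psum) // what flushSlot writes before WriteGSO
	return ^foldOnceNoInvert(uint32(partial) + onesSum(tcpHdrZeroCsum) + onesSum(payload))
}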

View File

@@ -0,0 +1,576 @@
package coalesce
import (
"encoding/binary"
"testing"
)
// fakeTunWriter records plain Writes and WriteGSO calls without touching a
// real TUN fd. WriteGSO preserves the split between hdr and borrowed pays
// so tests can inspect each independently.
type fakeTunWriter struct {
gsoEnabled bool
writes [][]byte
gsoWrites []fakeGSOWrite
}
type fakeGSOWrite struct {
hdr []byte
pays [][]byte
gsoSize uint16
isV6 bool
csumStart uint16
}
// total returns hdrLen + sum of pay lens.
func (g fakeGSOWrite) total() int {
n := len(g.hdr)
for _, p := range g.pays {
n += len(p)
}
return n
}
// payLen sums the pays.
func (g fakeGSOWrite) payLen() int {
var n int
for _, p := range g.pays {
n += len(p)
}
return n
}
func (w *fakeTunWriter) Write(p []byte) (int, error) {
buf := make([]byte, len(p))
copy(buf, p)
w.writes = append(w.writes, buf)
return len(p), nil
}
func (w *fakeTunWriter) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
hcopy := make([]byte, len(hdr))
copy(hcopy, hdr)
paysCopy := make([][]byte, len(pays))
for i, p := range pays {
pc := make([]byte, len(p))
copy(pc, p)
paysCopy[i] = pc
}
w.gsoWrites = append(w.gsoWrites, fakeGSOWrite{
hdr: hcopy,
pays: paysCopy,
gsoSize: gsoSize,
isV6: isV6,
csumStart: csumStart,
})
return nil
}
func (w *fakeTunWriter) GSOSupported() bool { return w.gsoEnabled }
// buildTCPv4 constructs a minimal IPv4+TCP packet with the given payload,
// seq, and flags. Assumes no IP options and a 20-byte TCP header.
func buildTCPv4(seq uint32, flags byte, payload []byte) []byte {
return buildTCPv4Ports(1000, 2000, seq, flags, payload)
}
// buildTCPv4Ports is buildTCPv4 with caller-specified ports so tests can
// build distinct flows.
func buildTCPv4Ports(sport, dport uint16, seq uint32, flags byte, payload []byte) []byte {
const ipHdrLen = 20
const tcpHdrLen = 20
total := ipHdrLen + tcpHdrLen + len(payload)
pkt := make([]byte, total)
pkt[0] = 0x45
pkt[1] = 0x00
binary.BigEndian.PutUint16(pkt[2:4], uint16(total))
binary.BigEndian.PutUint16(pkt[4:6], 0)
binary.BigEndian.PutUint16(pkt[6:8], 0x4000)
pkt[8] = 64
pkt[9] = ipProtoTCP
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
binary.BigEndian.PutUint16(pkt[20:22], sport)
binary.BigEndian.PutUint16(pkt[22:24], dport)
binary.BigEndian.PutUint32(pkt[24:28], seq)
binary.BigEndian.PutUint32(pkt[28:32], 12345)
pkt[32] = 0x50
pkt[33] = flags
binary.BigEndian.PutUint16(pkt[34:36], 0xffff)
copy(pkt[40:], payload)
return pkt
}
const (
tcpAck = 0x10
tcpPsh = 0x08
tcpSyn = 0x02
tcpFin = 0x01
tcpAckPsh = tcpAck | tcpPsh
)
func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: false}
c := NewTCPCoalescer(w)
pkt := buildTCPv4(1000, tcpAck, []byte("hello"))
if err := c.Add(pkt); err != nil {
t.Fatal(err)
}
// No sync write — passthrough is deferred to Flush.
if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
t.Fatalf("no Add-time writes: got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerNonTCPPassthrough(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pkt := make([]byte, 28)
pkt[0] = 0x45
binary.BigEndian.PutUint16(pkt[2:4], 28)
pkt[9] = 1
copy(pkt[12:16], []byte{10, 0, 0, 1})
copy(pkt[16:20], []byte{10, 0, 0, 2})
if err := c.Add(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("ICMP should pass through unchanged")
}
}
func TestCoalescerSeedThenFlushAlone(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000))
if err := c.Add(pkt); err != nil {
t.Fatal(err)
}
if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
t.Fatalf("unexpected output before flush")
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Single-segment flush now goes through WriteGSO with GSO_NONE
// (virtio NEEDS_CSUM lets the kernel fill in the L4 csum).
if len(w.gsoWrites) != 1 || len(w.writes) != 0 {
t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
g := w.gsoWrites[0]
if g.total() != 40+1000 {
t.Errorf("super total=%d want %d", g.total(), 40+1000)
}
if g.payLen() != 1000 {
t.Errorf("payLen=%d want 1000", g.payLen())
}
}
func TestCoalescerCoalescesAdjacentACKs(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(2200, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 1 {
t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes))
}
g := w.gsoWrites[0]
if g.gsoSize != 1200 {
t.Errorf("gsoSize=%d want 1200", g.gsoSize)
}
if len(g.hdr) != 40 {
t.Errorf("hdrLen=%d want 40", len(g.hdr))
}
if g.csumStart != 20 {
t.Errorf("csumStart=%d want 20", g.csumStart)
}
if len(g.pays) != 3 {
t.Errorf("pay count=%d want 3", len(g.pays))
}
if g.total() != 40+3*1200 {
t.Errorf("superpacket len=%d want %d", g.total(), 40+3*1200)
}
if tot := binary.BigEndian.Uint16(g.hdr[2:4]); int(tot) != g.total() {
t.Errorf("ip total_length=%d want %d", tot, g.total())
}
}
func TestCoalescerRejectsSeqGap(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(3000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Each packet flushes as its own single-segment WriteGSO now.
if len(w.gsoWrites) != 2 || len(w.writes) != 0 {
t.Fatalf("seq gap: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerRejectsFlagMismatch(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
// SYN|ACK is non-admissible: it seals the matching flow's open slot (still
// emitted as gso at Flush) and queues the SYN itself for a plain write.
syn := buildTCPv4(2200, tcpSyn|tcpAck, pay)
if err := c.Add(syn); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.writes) != 1 || len(w.gsoWrites) != 1 {
t.Fatalf("flag mismatch: want 1 plain + 1 gso, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerRejectsFIN(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x"))
if err := c.Add(fin); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// FIN isn't admissible — passthrough as plain, no slot, no gso.
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("FIN should be passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerShortLastSegmentClosesChain(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
full := make([]byte, 1200)
half := make([]byte, 500)
if err := c.Add(buildTCPv4(1000, tcpAck, full)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(2200, tcpAck, half)); err != nil {
t.Fatal(err)
}
// Chain now closed; next packet seeds a new slot on the same flow
// after flushing the old one.
if err := c.Add(buildTCPv4(2700, tcpAck, full)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Expect two gso writes: the first two packets coalesced, then the
// third flushed alone (single-seg via GSO_NONE).
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites))
}
if len(w.writes) != 0 {
t.Fatalf("want 0 plain writes got %d", len(w.writes))
}
if w.gsoWrites[0].gsoSize != 1200 {
t.Errorf("gsoSize=%d want 1200", w.gsoWrites[0].gsoSize)
}
if got, want := w.gsoWrites[0].total(), 40+1200+500; got != want {
t.Errorf("super len=%d want %d", got, want)
}
}
func TestCoalescerPSHFinalizesChain(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
if err := c.Add(buildTCPv4(1000, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(2200, tcpAckPsh, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4(3400, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// First two coalesce; the third seeds a fresh slot that flushes alone.
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites))
}
if len(w.writes) != 0 {
t.Fatalf("want 0 plain writes got %d", len(w.writes))
}
}
func TestCoalescerRejectsDifferentFlow(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
p1 := buildTCPv4(1000, tcpAck, pay)
p2 := buildTCPv4(2200, tcpAck, pay)
binary.BigEndian.PutUint16(p2[20:22], 9999)
if err := c.Add(p1); err != nil {
t.Fatal(err)
}
if err := c.Add(p2); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Two independent flows, each flushes its own single-segment WriteGSO.
if len(w.gsoWrites) != 2 || len(w.writes) != 0 {
t.Fatalf("diff flow: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerRejectsIPOptions(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 500)
pkt := buildTCPv4(1000, tcpAck, pay)
// Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add
// bytes — parser should bail before it matters.
pkt[0] = 0x46
if err := c.Add(pkt); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Non-admissible parse → passthrough as plain.
if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
t.Fatalf("IP options should passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
}
}
func TestCoalescerCapBySegments(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 512)
seq := uint32(1000)
for i := 0; i < tcpCoalesceMaxSegs+5; i++ {
if err := c.Add(buildTCPv4(seq, tcpAck, pay)); err != nil {
t.Fatal(err)
}
seq += uint32(len(pay))
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
for _, g := range w.gsoWrites {
segs := len(g.pays)
if segs > tcpCoalesceMaxSegs {
t.Fatalf("super exceeded seg cap: %d > %d", segs, tcpCoalesceMaxSegs)
}
}
}
// TestCoalescerMultipleFlowsInSameBatch proves two interleaved bulk TCP
// flows coalesce independently in a single Flush.
func TestCoalescerMultipleFlowsInSameBatch(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
// Flow A: sport 1000. Flow B: sport 3000.
if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(1000, 2000, 2500, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes (one per flow), got %d", len(w.gsoWrites))
}
if len(w.writes) != 0 {
t.Fatalf("want no plain writes, got %d", len(w.writes))
}
// Each superpacket should carry 3 segments.
for i, g := range w.gsoWrites {
if len(g.pays) != 3 {
t.Errorf("gso[%d]: segs=%d want 3", i, len(g.pays))
}
if g.gsoSize != 1200 {
t.Errorf("gso[%d]: gsoSize=%d want 1200", i, g.gsoSize)
}
}
// Verify each superpacket carries the source port it was seeded with.
seenSports := map[uint16]bool{}
for _, g := range w.gsoWrites {
sp := binary.BigEndian.Uint16(g.hdr[20:22])
seenSports[sp] = true
}
if !seenSports[1000] || !seenSports[3000] {
t.Errorf("expected superpackets for sports 1000 and 3000, got %v", seenSports)
}
}
// TestCoalescerPreservesArrivalOrder confirms that with passthrough and
// coalesced events both queued, Flush emits them in Add order rather than
// writing passthrough packets synchronously.
func TestCoalescerPreservesArrivalOrder(t *testing.T) {
w := &orderedFakeWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
// Sequence: coalesceable TCP, ICMP (passthrough), coalesceable TCP on
// a different flow. Expected emit order: gso(X), plain(ICMP), gso(Y).
pay := make([]byte, 1200)
if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil {
t.Fatal(err)
}
icmp := make([]byte, 28)
icmp[0] = 0x45
binary.BigEndian.PutUint16(icmp[2:4], 28)
icmp[9] = 1
copy(icmp[12:16], []byte{10, 0, 0, 1})
copy(icmp[16:20], []byte{10, 0, 0, 3})
if err := c.Add(icmp); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil {
t.Fatal(err)
}
// Nothing should have hit the writer synchronously.
if len(w.events) != 0 {
t.Fatalf("Add emitted events synchronously: %v", w.events)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
if got, want := w.events, []string{"gso", "plain", "gso"}; !stringSliceEq(got, want) {
t.Fatalf("flush order=%v want %v", got, want)
}
}
// orderedFakeWriter records only the sequence of call types so tests can
// assert arrival order without inspecting bytes.
type orderedFakeWriter struct {
gsoEnabled bool
events []string
}
func (w *orderedFakeWriter) Write(p []byte) (int, error) {
w.events = append(w.events, "plain")
return len(p), nil
}
func (w *orderedFakeWriter) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
w.events = append(w.events, "gso")
return nil
}
func (w *orderedFakeWriter) GSOSupported() bool { return w.gsoEnabled }
func stringSliceEq(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
// TestCoalescerInterleavedFlowsPreserveOrdering checks that a non-admissible
// packet (SYN) mid-flow only flushes its own flow, not others.
func TestCoalescerInterleavedFlowsPreserveOrdering(t *testing.T) {
w := &fakeTunWriter{gsoEnabled: true}
c := NewTCPCoalescer(w)
pay := make([]byte, 1200)
// Flow A two segments.
if err := c.Add(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil {
t.Fatal(err)
}
// Flow B two segments.
if err := c.Add(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Add(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil {
t.Fatal(err)
}
// Flow A SYN (non-admissible): seals only flow A's open slot, so flow B keeps coalescing.
syn := buildTCPv4Ports(1000, 2000, 9999, tcpSyn|tcpAck, pay)
if err := c.Add(syn); err != nil {
t.Fatal(err)
}
// Flow B continues — should still be coalesced with its seed.
if err := c.Add(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil {
t.Fatal(err)
}
if err := c.Flush(); err != nil {
t.Fatal(err)
}
// Expected:
// - 1 gso for flow A (first 2 segments)
// - 1 plain for flow A SYN
// - 1 gso for flow B (3 segments)
if len(w.gsoWrites) != 2 {
t.Fatalf("want 2 gso writes, got %d", len(w.gsoWrites))
}
if len(w.writes) != 1 {
t.Fatalf("want 1 plain write (SYN), got %d", len(w.writes))
}
// Find the 3-segment gso (flow B) and the 2-segment gso (flow A).
var segCounts []int
for _, g := range w.gsoWrites {
segCounts = append(segCounts, len(g.pays))
}
if !(segCounts[0] == 2 && segCounts[1] == 3) && !(segCounts[0] == 3 && segCounts[1] == 2) {
t.Errorf("unexpected segment counts: %v (want 2 and 3)", segCounts)
}
}

View File

@@ -4,6 +4,7 @@ import (
"io"
"net/netip"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
)
@@ -11,59 +12,13 @@ import (
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
// Queue is a readable/writable tun queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below).
type Queue interface {
io.Closer
// Read returns one or more packets. The returned slices are borrowed
// from the Queue's internal buffer and are only valid until the next
// Read or Close on this Queue — callers must encrypt or copy each
// slice before the next call. Not safe for concurrent Reads; exactly
// one goroutine per Queue reads.
Read() ([][]byte, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. May run concurrently with WriteReject on the same
// Queue, but not with itself.
Write(p []byte) (int, error)
// WriteReject writes a single packet that originated from the inside
// path (reject replies or self-forward) using scratch state distinct
// from Write, so it can run concurrently with Write on the same Queue
// without a data race. On backends without a shared-scratch Write, a
// trivial delegation to Write is acceptable.
WriteReject(p []byte) (int, error)
}
type Device interface {
Queue
io.Closer
Activate() error
Networks() []netip.Prefix
Name() string
RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool
NewMultiQueueReader() (Queue, error)
}
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will
// have filled in total length and pseudo-header partial). pays are
// non-overlapping payload fragments whose concatenation is the full
// superpacket payload; they are read-only from the writer's perspective
// and must remain valid until the call returns. gsoSize is the MSS:
// every segment except possibly the last is exactly that many bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
//
// hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
GSOSupported() bool
SupportsMultiqueue() bool //todo remove?
NewMultiQueueReader() error
Readers() []tio.Queue
}

View File

@@ -4,6 +4,7 @@ import (
"errors"
"net/netip"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
)
@@ -41,7 +42,7 @@ func (NoopTun) SupportsMultiqueue() bool {
return false
}
func (NoopTun) NewMultiQueueReader() (Queue, error) {
func (NoopTun) NewMultiQueueReader() (tio.Queue, error) {
return nil, errors.New("unsupported")
}

View File

@@ -0,0 +1,70 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type gsoContainer struct {
pq []*tunFile
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewGSOContainer() (Container, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &gsoContainer{
pq: []*tunFile{},
pqi: []Queue{},
shutdownFd: shutdownFd,
}
return out, nil
}
func (c *gsoContainer) Queues() []Queue {
return c.pqi
}
func (c *gsoContainer) Add(fd int) error {
x, err := newTunFd(fd, c.shutdownFd)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *gsoContainer) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
return err
}
func (c *gsoContainer) Close() error {
errs := []error{}
// Signal all readers blocked in poll to wake up and exit
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}
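// Example: a minimal sketch of how queues might be wired up. The fds slice is
// assumed to hold TUN queue descriptors obtained elsewhere (e.g. multiqueue
// opens); setupGSOQueues is an illustrative name, not an API of this package.
func setupGSOQueues(fds []int) (Container, []Queue, error) {
	c, err := NewGSOContainer()
	if err != nil {
		return nil, nil, err
	}
	for _, fd := range fds {
		// Add flips the fd to non-blocking and registers it against the
		// container's shared shutdown eventfd.
		if err := c.Add(fd); err != nil {
			_ = c.Close()
			return nil, nil, err
		}
	}
	// One read goroutine per returned Queue; a later Close wakes them all.
	return c, c.Queues(), nil
}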

View File

@@ -0,0 +1,69 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
type pollContainer struct {
pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type
pqi []Queue
shutdownFd int
}
func NewPollContainer() (Container, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &pollContainer{
pq: []*Poll{},
pqi: []Queue{},
shutdownFd: shutdownFd,
}
return out, nil
}
func (c *pollContainer) Queues() []Queue {
return c.pqi
}
func (c *pollContainer) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd)
if err != nil {
return err
}
c.pq = append(c.pq, x)
c.pqi = append(c.pqi, x)
return nil
}
func (c *pollContainer) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:])
return err
}
func (c *pollContainer) Close() error {
errs := []error{}
if err := c.wakeForShutdown(); err != nil {
errs = append(errs, err)
}
for _, x := range c.pq {
if err := x.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
}

overlay/tio/tio.go (new file, 63 lines)
View File

@@ -0,0 +1,63 @@
package tio
import "io"
// defaultBatchBufSize is the per-Queue scratch size for Read on backends
// that don't do TSO segmentation. 65535 covers any single IP packet.
const defaultBatchBufSize = 65535
type Container interface {
Queues() []Queue
Add(fd int) error
io.Closer
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below).
type Queue interface {
io.Closer
// Read returns one or more packets. The returned slices are borrowed
// from the Queue's internal buffer and are only valid until the next
// Read or Close on this Queue — callers must encrypt or copy each
// slice before the next call. Not safe for concurrent Reads; exactly
// one goroutine per Queue reads.
Read() ([][]byte, error)
// Write emits a single packet on the plaintext (outside→inside)
// delivery path. May run concurrently with WriteReject on the same
// Queue, but not with itself.
Write(p []byte) (int, error)
// WriteReject writes a single packet that originated from the inside
// path (reject replies or self-forward) using scratch state distinct
// from Write, so it can run concurrently with Write on the same Queue
// without a data race. On backends without a shared-scratch Write, a
// trivial delegation to Write is acceptable.
WriteReject(p []byte) (int, error)
}
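// Example: a minimal sketch of a reader honoring the borrowing contract above.
// The handle callback is illustrative; the key constraint is that each
// returned slice is consumed (copied or encrypted) before the next Read.
func exampleReadLoop(q Queue, handle func([]byte)) error {
	for {
		pkts, err := q.Read()
		if err != nil {
			return err // typically os.ErrClosed once the owning container shuts down
		}
		for _, p := range pkts {
			handle(p) // p is only valid until the next Read or Close
		}
	}
}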
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket
// assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped.
//
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable — callers will
// have filled in total length and pseudo-header partial). pays are
// non-overlapping payload fragments whose concatenation is the full
// superpacket payload; they are read-only from the writer's perspective
// and must remain valid until the call returns. gsoSize is the MSS:
// every segment except possibly the last is exactly that many bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
//
// # TODO fold into Queue
//
// hdr's TCP checksum field must already hold the pseudo-header partial
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics.
type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error
GSOSupported() bool
}
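// Example: an illustrative wrapper showing the argument shapes for WriteGSO.
// For a coalesced IPv4 run of three 1200-byte segments the call would be
// roughly WriteGSO(hdr, pays, 1200, false, 20): hdr is the 40-byte IPv4+TCP
// prefix with total length and checksums already patched, pays holds the three
// borrowed payload slices, and csumStart (20) is the IP header length. The
// function name and the fallback behavior below are assumptions, not part of
// the interface contract.
func exampleEmitSuperpacket(q GSOWriter, hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, ipHdrLen uint16) error {
	if !q.GSOSupported() {
		// A real caller would fall back to per-segment Queue.Write here;
		// skipping coalescing on non-GSO backends is the documented behavior.
		return nil
	}
	// csumStart is the offset of the TCP header within hdr, i.e. the IP
	// header length (20 for IPv4 without options, 40 for IPv6).
	return q.WriteGSO(hdr, pays, gsoSize, isV6, ipHdrLen)
}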

View File

@@ -0,0 +1,405 @@
package tio
import (
"fmt"
"io"
"os"
"runtime"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// Space for segmented output. Worst case is many small segments, each paying
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
const tunSegBufSize = 131072
// tunSegBufCap is the total size we allocate for the per-reader segment
// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus
// the same again as drain headroom so a Read wake can accumulate
// additional packets after an initial big read without overflowing.
const tunSegBufCap = tunSegBufSize * 2
// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations.
const gsoInitialPayIovs = 66
// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write / WriteReject paths
// already carry a valid L4 checksum (either supplied by a remote peer whose
// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO
// segmentation, or built locally by CreateRejectPacket), so trusting them is
// safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct { //todo rename GSO
fd int
shutdownFd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed atomic.Bool
readBuf []byte // scratch for a single raw read (virtio hdr + superpacket)
segBuf []byte // backing store for segmented output
segOff int // cursor into segBuf for the current Read drain
pending [][]byte // segments returned from the most recent Read
writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr
// rejectIovs is a second preallocated iovec scratch used exclusively by
// WriteReject (reject + self-forward from the inside path). It mirrors
// writeIovs but lets listenIn goroutines emit reject packets without
// racing with the listenOut coalescer that owns writeIovs.
rejectIovs [2]unix.Iovec
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
// by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on
// another queue never observes a half-written header.
gsoHdrBuf [virtioNetHdrLen]byte
// gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the
// virtio header + IP/TCP header + up to gsoInitialPayIovs payload
// fragments; grown on demand if a coalescer pushes more.
gsoIovs []unix.Iovec
}
func newTunFd(fd int, shutdownFd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &tunFile{
fd: fd,
shutdownFd: shutdownFd,
closed: atomic.Bool{},
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
segBuf: make([]byte, tunSegBufSize),
gsoIovs: make([]unix.Iovec, 2, 2+gsoInitialPayIovs),
}
out.writeIovs[0].Base = &validVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen)
out.rejectIovs[0].Base = &validVnetHdr[0]
out.rejectIovs[0].SetLen(virtioNetHdrLen)
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
out.gsoIovs[0].SetLen(virtioNetHdrLen)
return out, nil
}
func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) readRaw(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnRead(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
// Read reads one or more superpackets from the tun and returns the
// resulting packets. The first read blocks via poll; once the fd is known
// readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
// and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) {
r.pending = r.pending[:0]
r.segOff = 0
// Initial (blocking) read. Retry on decode errors so a single bad
// packet does not stall the reader.
for {
n, err := r.readRaw(r.readBuf)
if err != nil {
return nil, err
}
if err := r.decodeRead(n); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
break
}
// Drain: non-blocking reads until the kernel queue is empty, the drain
// cap is reached, or segBuf no longer has room for another worst-case
// superpacket.
for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize {
n, err := unix.Read(r.fd, r.readBuf)
if err != nil {
// EAGAIN / EINTR / anything else: stop draining. We already
// have a valid batch from the first read.
break
}
if n <= 0 {
break
}
if err := r.decodeRead(n); err != nil {
// Drop this packet and stop the drain; we'd rather hand off
// what we have than keep spinning here.
break
}
}
return r.pending, nil
}
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// The fd is opened with IFF_VNET_HDR, so every raw read begins with a virtio_net_hdr.
func (r *tunFile) decodeRead(n int) error {
if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
var hdr VirtioNetHdr
hdr.decode(r.readBuf[:virtioNetHdrLen])
before := len(r.pending)
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil {
return err
}
for k := before; k < len(r.pending); k++ {
r.segOff += len(r.pending[k])
}
return nil
}
func (r *tunFile) Write(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.writeIovs)
}
// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs)
// distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile.
func (r *tunFile) WriteReject(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.rejectIovs)
}
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if len(buf) == 0 {
return 0, nil
}
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to validVnetHdr during tunFile construction so we don't rebuild it here.
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
// either completes promptly or returns EAGAIN — it cannot park the
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
// packet; we only pay that cost when we fall through to blockOnWrite.
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
if errno == 0 {
runtime.KeepAlive(buf)
if int(n) < virtioNetHdrLen {
return 0, io.ErrShortWrite
}
return int(n) - virtioNetHdrLen, nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(buf)
if err := r.blockOnWrite(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return 0, os.ErrClosed
}
runtime.KeepAlive(buf)
return 0, errno
}
}
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls.
func (r *tunFile) GSOSupported() bool { return true }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
// and TCP pseudo-header partial set by the caller). pays are payload
// fragments whose concatenation forms the full coalesced payload; each
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
if len(hdr) == 0 || len(pays) == 0 {
return nil
}
// Build the virtio_net_hdr. When pays total to <= gsoSize the kernel
// would produce a single segment; keep NEEDS_CSUM semantics but skip
// the GSO type so the kernel doesn't spuriously mark this as TSO.
vhdr := VirtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
HdrLen: uint16(len(hdr)),
GSOSize: gsoSize,
CsumStart: csumStart,
CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
}
var totalPay int
for _, p := range pays {
totalPay += len(p)
}
if totalPay > int(gsoSize) {
if isV6 {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
}
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
vhdr.encode(r.gsoHdrBuf[:])
// Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is
// wired to gsoHdrBuf at construction and never changes.
need := 2 + len(pays)
if cap(r.gsoIovs) < need {
grown := make([]unix.Iovec, need)
grown[0] = r.gsoIovs[0]
r.gsoIovs = grown
} else {
r.gsoIovs = r.gsoIovs[:need]
}
r.gsoIovs[1].Base = &hdr[0]
r.gsoIovs[1].SetLen(len(hdr))
for i, p := range pays {
r.gsoIovs[2+i].Base = &p[0]
r.gsoIovs[2+i].SetLen(len(p))
}
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
iovCnt := uintptr(len(r.gsoIovs))
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
if errno == 0 {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if int(n) < virtioNetHdrLen {
return io.ErrShortWrite
}
return nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if err := r.blockOnWrite(); err != nil {
return err
}
continue
}
if errno == unix.EINTR {
continue
}
if errno == unix.EBADF {
return os.ErrClosed
}
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
return errno
}
}
func (r *tunFile) Close() error {
if r.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if r.fd >= 0 {
err = unix.Close(r.fd)
r.fd = -1
}
return err
}

View File

@@ -0,0 +1,205 @@
package tio
import (
"fmt"
"os"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
// prefix plus the virtio header.
const tunReadBufSize = 65535
type Poll struct {
fd int
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed atomic.Bool
readBuf []byte
batchRet [1][]byte
}
func newPoll(fd int, shutdownFd int) (*Poll, error) {
if err := unix.SetNonblock(fd, true); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to set Poll device as nonblocking: %w", err)
}
out := &Poll{
fd: fd,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
return out, nil
}
// blockOnRead waits until the Poll fd is readable or shutdown has been signaled.
// Returns os.ErrClosed if Close was called.
func (t *Poll) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.readPoll[0].Revents
shutdownEvents := t.readPoll[1].Revents
t.readPoll[0].Revents = 0
t.readPoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(t.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
tunEvents := t.writePoll[0].Revents
shutdownEvents := t.writePoll[1].Revents
t.writePoll[0].Revents = 0
t.writePoll[1].Revents = 0
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
}
if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (t *Poll) Read() ([][]byte, error) {
if t.readBuf == nil {
t.readBuf = make([]byte, defaultBatchBufSize)
}
n, err := t.readOne(t.readBuf)
if err != nil {
return nil, err
}
t.batchRet[0] = t.readBuf[:n]
return t.batchRet[:], nil
}
func (t *Poll) readOne(to []byte) (int, error) {
// first 4 bytes is protocol family, in network byte order
var head [4]byte
iovecs := [2]syscall.Iovec{ //todo plat-specific
{&head[0], 4},
{&to[0], uint64(len(to))},
}
for {
n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
if errno == 0 {
bytesRead := int(n)
if bytesRead < 4 {
return 0, nil
}
return bytesRead - 4, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnRead(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
// Write is only valid for single-threaded use
func (t *Poll) Write(from []byte) (int, error) {
if len(from) <= 1 {
return 0, syscall.EIO
}
ipVer := from[0] >> 4
var head [4]byte
// first 4 bytes is protocol family, in network byte order
switch ipVer {
case 4:
head[3] = syscall.AF_INET
case 6:
head[3] = syscall.AF_INET6
default:
return 0, fmt.Errorf("unable to determine IP version from packet")
}
iovecs := [2]syscall.Iovec{ //todo plat specific
{&head[0], 4},
{&from[0], uint64(len(from))},
}
for {
n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(t.fd), uintptr(unsafe.Pointer(&iovecs[0])), 2)
if errno == 0 {
return int(n) - 4, nil
}
switch errno {
case unix.EAGAIN:
if err := t.blockOnWrite(); err != nil {
return 0, err
}
case unix.EINTR:
// retry
case unix.EBADF:
return 0, os.ErrClosed
default:
return 0, errno
}
}
}
func (t *Poll) Close() error {
if t.closed.Swap(true) {
return nil
}
//shutdownFd is owned by the container, so we should not close it
var err error
if t.fd >= 0 {
err = unix.Close(t.fd)
t.fd = -1
}
return err
}
func (t *Poll) WriteReject(p []byte) (int, error) {
return t.Write(p)
}

View File

@@ -1,7 +1,7 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
package tio
import (
"errors"

View File

@@ -1,7 +1,7 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
package tio
import (
"encoding/binary"
@@ -10,66 +10,10 @@ import (
"golang.org/x/sys/unix"
)
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
const virtioNetHdrLen = 10
// Maximum size we accept for a single read from a TUN with IFF_VNET_HDR. A
// TSO superpacket can be up to 64KiB of payload plus a single L2/L3/L4 header
// prefix plus the virtio header.
const tunReadBufSize = 65535
// Space for segmented output. Worst case is many small segments, each paying
// an IP+TCP header. 128KiB comfortably covers the 64KiB payload ceiling.
const tunSegBufSize = 131072
// tunSegBufCap is the total size we allocate for the per-reader segment
// buffer. It is sized as one worst-case TSO superpacket (tunSegBufSize) plus
// the same again as drain headroom so a Read wake can accumulate
// additional packets after an initial big read without overflowing.
const tunSegBufCap = tunSegBufSize * 2
// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64
type virtioNetHdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *virtioNetHdr) decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
// on egress.
func (h *virtioNetHdr) encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}
// segmentInto splits a TUN-side packet described by hdr into one or more
// IP packets, each appended to *out as a slice of scratch. scratch must be
// sized to hold every segment (including replicated headers).
func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
func segmentInto(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
// When RSC_INFO is set the csum_start/csum_offset fields are repurposed to
// carry coalescing info rather than checksum offsets. A TUN writing via
// IFF_VNET_HDR should never emit this, but if it did we would silently
@@ -105,7 +49,7 @@ func segmentInto(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) er
// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit
// checksum field; we zero it, fold a full sum (the field was pre-loaded with
// the pseudo-header partial sum by the kernel), and store the result.
func finishChecksum(seg []byte, hdr virtioNetHdr) error {
func finishChecksum(seg []byte, hdr VirtioNetHdr) error {
cs := int(hdr.CsumStart)
co := int(hdr.CsumOffset)
if cs+co+2 > len(seg) {
@@ -129,7 +73,7 @@ func finishChecksum(seg []byte, hdr virtioNetHdr) error {
// are each summed once up front — every segment reuses those three pre-folded
// uint32 values and combines them with small per-segment deltas (seq, flags,
// tcpLen, ip_id, total_len) that are cheap to fold in.
func segmentTCP(pkt []byte, hdr virtioNetHdr, out *[][]byte, scratch []byte) error {
func segmentTCP(pkt []byte, hdr VirtioNetHdr, out *[][]byte, scratch []byte) error {
if hdr.GSOSize == 0 {
return fmt.Errorf("gso_size is zero")
}
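For reference, the arithmetic finishChecksum and segmentTCP lean on is the standard RFC 1071 ones'-complement sum. A minimal sketch of that fold (checksumFold is a hypothetical helper, not a function in this diff):
// checksumFold sums b as big-endian 16-bit words on top of an initial
// partial sum (for example a pre-folded pseudo-header), folds the carries
// back into the low 16 bits, and returns the ones'-complement result.
func checksumFold(b []byte, initial uint32) uint16 {
	sum := initial
	for ; len(b) >= 2; b = b[2:] {
		sum += uint32(b[0])<<8 | uint32(b[1])
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8
	}
	for sum>>16 != 0 {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return ^uint16(sum)
}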

View File

@@ -1,7 +1,7 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package overlay
package tio
import (
"encoding/binary"
@@ -23,7 +23,7 @@ func verifyChecksum(b []byte, pseudo uint32) bool {
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`.
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) {
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, VirtioNetHdr) {
t.Helper()
const ipLen = 20
const tcpLen = 20
@@ -53,7 +53,7 @@ func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtioNetHdr) {
pkt[ipLen+tcpLen+i] = byte(i & 0xff)
}
return pkt, virtioNetHdr{
return pkt, VirtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
@@ -174,7 +174,7 @@ func TestSegmentTCPv6(t *testing.T) {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtioNetHdr{
hdr := VirtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
HdrLen: uint16(ipLen + tcpLen),
@@ -240,7 +240,7 @@ func TestSegmentGSONonePassesThrough(t *testing.T) {
}
func TestSegmentRejectsUDP(t *testing.T) {
hdr := virtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
hdr := VirtioNetHdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
var out [][]byte
if err := segmentInto(nil, hdr, &out, nil); err == nil {
t.Fatalf("expected rejection for UDP GSO")
@@ -279,7 +279,7 @@ func BenchmarkSegmentTCPv4(b *testing.B) {
for i := 0; i < sz.payLen; i++ {
pkt[ipLen+tcpLen+i] = byte(i)
}
hdr := virtioNetHdr{
hdr := VirtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
HdrLen: uint16(ipLen + tcpLen),
@@ -312,7 +312,7 @@ func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
}
t.Cleanup(func() { _ = unix.Close(fd) })
tf := &tunFile{fd: fd, vnetHdr: true}
tf := &tunFile{fd: fd}
tf.writeIovs[0].Base = &validVnetHdr[0]
tf.writeIovs[0].SetLen(virtioNetHdrLen)

View File

@@ -0,0 +1,38 @@
//go:build !e2e_testing
// +build !e2e_testing
package tio
import (
"testing"
"github.com/slackhq/nebula/overlay"
)
var runAdvMSSTests = []struct {
name string
tun *overlay.tun
r overlay.Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{}, 0},
{"default-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1440}, 0},
{"default-low", &overlay.tun{DefaultMTU: 1440, MaxMTU: 1440}, overlay.Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{}, 1400},
{"route-min", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 1440}, 1400},
{"route-high", &overlay.tun{DefaultMTU: 1440, MaxMTU: 8941}, overlay.Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}
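The expected values in the table above follow the usual advertised-MSS arithmetic. A hypothetical reconstruction of that rule, for illustration only (advMSSSketch is not the overlay implementation):
// advMSSSketch: if the effective MTU for a route (route MTU, falling back to
// the device default) is below the device's max MTU, advertise that MTU minus
// the IPv4 (20) and TCP (20) headers; otherwise return 0 so the kernel derives
// the MSS from the device MTU itself.
func advMSSSketch(defaultMTU, maxMTU, routeMTU int) int {
	mtu := routeMTU
	if mtu == 0 {
		mtu = defaultMTU
	}
	if mtu < maxMTU {
		return mtu - 20 - 20
	}
	return 0
}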

View File

@@ -0,0 +1,39 @@
package tio
import "encoding/binary"
// Size of the legacy struct virtio_net_hdr that the kernel prepends/expects on
// a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ not set).
const virtioNetHdrLen = 10
type VirtioNetHdr struct {
Flags uint8
GSOType uint8
HdrLen uint16
GSOSize uint16
CsumStart uint16
CsumOffset uint16
}
// decode reads a virtio_net_hdr in host byte order (TUN default; we never
// call TUNSETVNETLE so the kernel matches our endianness).
func (h *VirtioNetHdr) decode(b []byte) {
h.Flags = b[0]
h.GSOType = b[1]
h.HdrLen = binary.NativeEndian.Uint16(b[2:4])
h.GSOSize = binary.NativeEndian.Uint16(b[4:6])
h.CsumStart = binary.NativeEndian.Uint16(b[6:8])
h.CsumOffset = binary.NativeEndian.Uint16(b[8:10])
}
// encode is the inverse of decode: writes the virtio_net_hdr fields into b
// (must be at least virtioNetHdrLen bytes). Used to emit a TSO superpacket
// on egress.
func (h *VirtioNetHdr) encode(b []byte) {
b[0] = h.Flags
b[1] = h.GSOType
binary.NativeEndian.PutUint16(b[2:4], h.HdrLen)
binary.NativeEndian.PutUint16(b[4:6], h.GSOSize)
binary.NativeEndian.PutUint16(b[6:8], h.CsumStart)
binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset)
}
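A minimal round-trip of the codec as it might look from inside the tio package (decode and encode are unexported); the field values are illustrative and the unix constants are the same x/sys/unix ones used elsewhere in this diff:
var buf [virtioNetHdrLen]byte
in := VirtioNetHdr{
	Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
	GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
	HdrLen:     40,   // 20-byte IPv4 + 20-byte TCP prefix
	GSOSize:    1400, // MSS for each emitted segment
	CsumStart:  20,   // TCP header begins after the IPv4 header
	CsumOffset: 16,   // checksum field offset within the TCP header
}
in.encode(buf[:])
var out VirtioNetHdr
out.decode(buf[:])
// out now holds the same values as in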

View File

@@ -13,6 +13,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -126,6 +127,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for android")
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -572,6 +573,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin")
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
)
@@ -17,9 +18,10 @@ type disabledTun struct {
vpnNetworks []netip.Prefix
// Track these metrics since we don't have the tun device to do it for us
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
tx metrics.Counter
rx metrics.Counter
l *logrus.Logger
numReaders int
batchRet [1][]byte
}
@@ -44,6 +46,7 @@ func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled boo
vpnNetworks: vpnNetworks,
read: make(chan []byte, queueLen),
l: l,
numReaders: 1,
}
if metricsEnabled {
@@ -112,8 +115,17 @@ func (t *disabledTun) SupportsMultiqueue() bool {
return true
}
func (t *disabledTun) NewMultiQueueReader() (Queue, error) {
return t, nil
func (t *disabledTun) NewMultiQueueReader() error {
t.numReaders++
return nil
}
func (t *disabledTun) Readers() []tio.Queue {
out := make([]tio.Queue, t.numReaders)
for i := range t.numReaders {
out[i] = t
}
return out
}
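For context, a hypothetical caller-side sketch of how the Readers() slice is meant to be consumed: one goroutine per tio.Queue, each pulling batches, assuming the batched Read signature tunFile uses (device and handlePacket are stand-ins, not identifiers from this diff):
for i, q := range device.Readers() {
	go func(idx int, q tio.Queue) {
		for {
			pkts, err := q.Read()
			if err != nil {
				return // queue closed
			}
			for _, p := range pkts {
				handlePacket(idx, p) // hypothetical consumer
			}
		}
	}(i, q)
}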
func (t *disabledTun) Close() error {

View File

@@ -18,6 +18,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -581,7 +582,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for freebsd")
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@@ -182,6 +183,6 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for ios")
}

View File

@@ -4,478 +4,28 @@
package overlay
import (
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// tunFile wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type tunFile struct {
fd int
shutdownFd int
lastOne bool
readPoll [2]unix.PollFd
writePoll [2]unix.PollFd
closed bool
// vnetHdr is true when this fd was opened with IFF_VNET_HDR and the
// kernel successfully accepted TUNSETOFFLOAD. Reads include a leading
// virtio_net_hdr and may carry a TSO superpacket we must segment;
// writes must prepend a virtio_net_hdr (validVnetHdr).
vnetHdr bool
readBuf []byte // scratch for a single raw read (virtio hdr + superpacket)
segBuf []byte // backing store for segmented output
segOff int // cursor into segBuf for the current Read drain
pending [][]byte // segments returned from the most recent Read
writeIovs [2]unix.Iovec // preallocated iovecs for Write (coalescer passthrough); iovs[0] is fixed to validVnetHdr
// rejectIovs is a second preallocated iovec scratch used exclusively by
// WriteReject (reject + self-forward from the inside path). It mirrors
// writeIovs but lets listenIn goroutines emit reject packets without
// racing with the listenOut coalescer that owns writeIovs.
rejectIovs [2]unix.Iovec
// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
// by WriteGSO. Separate from validVnetHdr so a concurrent non-GSO Write on
// another queue never observes a half-written header.
gsoHdrBuf [virtioNetHdrLen]byte
// gsoIovs is the writev iovec scratch for WriteGSO. Sized to hold the
// virtio header + IP/TCP header + up to gsoInitialPayIovs payload
// fragments; grown on demand if a coalescer pushes more.
gsoIovs []unix.Iovec
}
// gsoInitialPayIovs is the starting capacity (in payload fragments) of
// tunFile.gsoIovs. Sized to cover the default coalesce segment cap without
// any reallocations.
const gsoInitialPayIovs = 66
// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write / WriteReject paths
// already carry a valid L4 checksum (either supplied by a remote peer whose
// ciphertext we AEAD-authenticated, or produced by finishChecksum during TSO
// segmentation, or built locally by CreateRejectPacket), so trusting them is
// safe.
var validVnetHdr = [virtioNetHdrLen]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// newFriend makes a tunFile for a MultiQueueReader that copies the shutdown eventfd from the parent tun
func (r *tunFile) newFriend(fd int) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
out := &tunFile{
fd: fd,
shutdownFd: r.shutdownFd,
vnetHdr: r.vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(r.shutdownFd), Events: unix.POLLIN},
},
}
if r.vnetHdr {
out.segBuf = make([]byte, tunSegBufCap)
out.writeIovs[0].Base = &validVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen)
out.rejectIovs[0].Base = &validVnetHdr[0]
out.rejectIovs[0].SetLen(virtioNetHdrLen)
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
out.gsoIovs[0].SetLen(virtioNetHdrLen)
}
return out, nil
}
func newTunFd(fd int, vnetHdr bool) (*tunFile, error) {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
}
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err)
}
out := &tunFile{
fd: fd,
shutdownFd: shutdownFd,
lastOne: true,
vnetHdr: vnetHdr,
readBuf: make([]byte, tunReadBufSize),
readPoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
writePoll: [2]unix.PollFd{
{Fd: int32(fd), Events: unix.POLLOUT},
{Fd: int32(shutdownFd), Events: unix.POLLIN},
},
}
if vnetHdr {
out.segBuf = make([]byte, tunSegBufCap)
out.writeIovs[0].Base = &validVnetHdr[0]
out.writeIovs[0].SetLen(virtioNetHdrLen)
out.rejectIovs[0].Base = &validVnetHdr[0]
out.rejectIovs[0].SetLen(virtioNetHdrLen)
out.gsoIovs = make([]unix.Iovec, 2, 2+gsoInitialPayIovs)
out.gsoIovs[0].Base = &out.gsoHdrBuf[0]
out.gsoIovs[0].SetLen(virtioNetHdrLen)
}
return out, nil
}
func (r *tunFile) blockOnRead() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.readPoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.readPoll[0].Revents
shutdownEvents := r.readPoll[1].Revents
r.readPoll[0].Revents = 0
r.readPoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) blockOnWrite() error {
const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
var err error
for {
_, err = unix.Poll(r.writePoll[:], -1)
if err != unix.EINTR {
break
}
}
//always reset these!
tunEvents := r.writePoll[0].Revents
shutdownEvents := r.writePoll[1].Revents
r.writePoll[0].Revents = 0
r.writePoll[1].Revents = 0
//do the err check before trusting the potentially bogus bits we just got
if err != nil {
return err
}
if shutdownEvents&(unix.POLLIN|problemFlags) != 0 {
return os.ErrClosed
} else if tunEvents&problemFlags != 0 {
return os.ErrClosed
}
return nil
}
func (r *tunFile) readRaw(buf []byte) (int, error) {
for {
if n, err := unix.Read(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnRead(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
// Read reads one or more superpackets from the tun and returns the
// resulting packets. The first read blocks via poll; once the fd is known
// readable we drain additional packets non-blocking until the kernel queue
// is empty (EAGAIN), we've collected tunDrainCap packets, or we're out of
// segBuf headroom. This amortizes the poll wake over bursts of small
// packets (e.g. TCP ACKs). Slices point into the tunFile's internal buffers
// and are only valid until the next Read or Close on this Queue.
func (r *tunFile) Read() ([][]byte, error) {
r.pending = r.pending[:0]
r.segOff = 0
// Initial (blocking) read. Retry on decode errors so a single bad
// packet does not stall the reader.
for {
n, err := r.readRaw(r.readBuf)
if err != nil {
return nil, err
}
if !r.vnetHdr {
r.pending = append(r.pending, r.readBuf[:n])
// Non-vnetHdr mode shares one readBuf so we can't drain safely
// without copying; return the single packet as before.
return r.pending, nil
}
if err := r.decodeRead(n); err != nil {
// Drop and read again — a bad packet should not kill the reader.
continue
}
break
}
// Drain: non-blocking reads until the kernel queue is empty, the drain
// cap is reached, or segBuf no longer has room for another worst-case
// superpacket.
for len(r.pending) < tunDrainCap && tunSegBufCap-r.segOff >= tunSegBufSize {
n, err := unix.Read(r.fd, r.readBuf)
if err != nil {
// EAGAIN / EINTR / anything else: stop draining. We already
// have a valid batch from the first read.
break
}
if n <= 0 {
break
}
if err := r.decodeRead(n); err != nil {
// Drop this packet and stop the drain; we'd rather hand off
// what we have than keep spinning here.
break
}
}
return r.pending, nil
}
// decodeRead decodes the virtio header plus payload in r.readBuf[:n], appends
// the segments to r.pending, and advances r.segOff by the total scratch used.
// Caller must have already ensured r.vnetHdr is true.
func (r *tunFile) decodeRead(n int) error {
if n < virtioNetHdrLen {
return fmt.Errorf("short tun read: %d < %d", n, virtioNetHdrLen)
}
var hdr virtioNetHdr
hdr.decode(r.readBuf[:virtioNetHdrLen])
before := len(r.pending)
if err := segmentInto(r.readBuf[virtioNetHdrLen:n], hdr, &r.pending, r.segBuf[r.segOff:]); err != nil {
return err
}
for k := before; k < len(r.pending); k++ {
r.segOff += len(r.pending[k])
}
return nil
}
func (r *tunFile) Write(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.writeIovs)
}
// WriteReject emits a packet using a dedicated iovec scratch (rejectIovs)
// distinct from the one used by the coalescer's Write path. This avoids a
// data race between the inside (listenIn) goroutine emitting reject or
// self-forward packets and the outside (listenOut) goroutine flushing TCP
// coalescer passthroughs on the same tunFile.
func (r *tunFile) WriteReject(buf []byte) (int, error) {
return r.writeWithScratch(buf, &r.rejectIovs)
}
func (r *tunFile) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
if !r.vnetHdr {
for {
if n, err := unix.Write(r.fd, buf); err == nil {
return n, nil
} else if err == unix.EAGAIN {
if err = r.blockOnWrite(); err != nil {
return 0, err
}
continue
} else if err == unix.EINTR {
continue
} else if err == unix.EBADF {
return 0, os.ErrClosed
} else {
return 0, err
}
}
}
if len(buf) == 0 {
return 0, nil
}
// Point the payload iovec at the caller's buffer. iovs[0] is pre-wired
// to validVnetHdr during tunFile construction so we don't rebuild it here.
iovs[1].Base = &buf[0]
iovs[1].SetLen(len(buf))
iovPtr := uintptr(unsafe.Pointer(&iovs[0]))
// The TUN fd is non-blocking (set in newTunFd / newFriend), so writev
// either completes promptly or returns EAGAIN — it cannot park the
// goroutine inside the kernel. That lets us use syscall.RawSyscall and
// skip the runtime.entersyscall / exitsyscall bookkeeping on every
// packet; we only pay that cost when we fall through to blockOnWrite.
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, 2)
if errno == 0 {
runtime.KeepAlive(buf)
if int(n) < virtioNetHdrLen {
return 0, io.ErrShortWrite
}
return int(n) - virtioNetHdrLen, nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(buf)
if err := r.blockOnWrite(); err != nil {
return 0, err
}
continue
}
if errno == unix.EINTR {
continue
}
runtime.KeepAlive(buf)
return 0, errno
}
}
// GSOSupported reports whether this queue was opened with IFF_VNET_HDR and
// can accept WriteGSO. When false, callers should fall back to per-segment
// Write calls.
func (r *tunFile) GSOSupported() bool { return r.vnetHdr }
// WriteGSO emits a TCP TSO superpacket in a single writev. hdr is the
// IPv4/IPv6 + TCP header prefix (already finalized — total length, IP csum,
// and TCP pseudo-header partial set by the caller). pays are payload
// fragments whose concatenation forms the full coalesced payload; each
// slice is read-only and must stay valid until return. gsoSize is the MSS;
// every segment except possibly the last is exactly gsoSize bytes.
// csumStart is the byte offset where the TCP header begins within hdr.
func (r *tunFile) WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error {
if !r.vnetHdr {
return fmt.Errorf("WriteGSO called on tun without IFF_VNET_HDR")
}
if len(hdr) == 0 || len(pays) == 0 {
return nil
}
// Build the virtio_net_hdr. When the payload fragments total <= gsoSize the
// kernel would produce a single segment anyway; keep NEEDS_CSUM semantics
// but set GSO_NONE so the kernel doesn't spuriously treat this as TSO.
vhdr := virtioNetHdr{
Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
HdrLen: uint16(len(hdr)),
GSOSize: gsoSize,
CsumStart: csumStart,
CsumOffset: 16, // TCP checksum field lives 16 bytes into the TCP header
}
var totalPay int
for _, p := range pays {
totalPay += len(p)
}
if totalPay > int(gsoSize) {
if isV6 {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
}
} else {
vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
vhdr.GSOSize = 0
}
vhdr.encode(r.gsoHdrBuf[:])
// Build the iovec array: [virtio_hdr, hdr, pays...]. r.gsoIovs[0] is
// wired to gsoHdrBuf at construction and never changes.
need := 2 + len(pays)
if cap(r.gsoIovs) < need {
grown := make([]unix.Iovec, need)
grown[0] = r.gsoIovs[0]
r.gsoIovs = grown
} else {
r.gsoIovs = r.gsoIovs[:need]
}
r.gsoIovs[1].Base = &hdr[0]
r.gsoIovs[1].SetLen(len(hdr))
for i, p := range pays {
r.gsoIovs[2+i].Base = &p[0]
r.gsoIovs[2+i].SetLen(len(p))
}
iovPtr := uintptr(unsafe.Pointer(&r.gsoIovs[0]))
iovCnt := uintptr(len(r.gsoIovs))
for {
n, _, errno := syscall.RawSyscall(unix.SYS_WRITEV, uintptr(r.fd), iovPtr, iovCnt)
if errno == 0 {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if int(n) < virtioNetHdrLen {
return io.ErrShortWrite
}
return nil
}
if errno == unix.EAGAIN {
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
if err := r.blockOnWrite(); err != nil {
return err
}
continue
}
if errno == unix.EINTR {
continue
}
runtime.KeepAlive(hdr)
runtime.KeepAlive(pays)
return errno
}
}
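A sketch of the call shape a coalescer flush would use for an IPv4 flow; hdr, pays, mss and q are hypothetical variables, not identifiers from this diff:
// hdr: the seed segment's IP+TCP header, with total length and the TCP
// pseudo-header partial already patched for the coalesced payload.
// pays: borrowed payload fragments in sequence order. mss: the flow's
// segment size. csumStart is the IP header length (20 for IPv4).
if err := q.WriteGSO(hdr, pays, mss, false /* isV6 */, 20 /* csumStart */); err != nil {
	// fall back to writing each segment individually via q.Write (sketch only)
}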
func (r *tunFile) wakeForShutdown() error {
var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(r.readPoll[1].Fd), buf[:])
return err
}
func (r *tunFile) Close() error {
if r.closed { // avoid closing more than once. Technically a fd could get re-used, which would be a problem
return nil
}
r.closed = true
if r.lastOne {
_ = unix.Close(r.shutdownFd)
}
return unix.Close(r.fd)
}
type tun struct {
*tunFile
readers []*tunFile
readers tio.Container
closeLock sync.Mutex
Device string
vpnNetworks []netip.Prefix
@@ -484,6 +34,7 @@ type tun struct {
TXQueueLen int
deviceIndex int
ioctlFd uintptr
vnetHdr bool
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -622,15 +173,28 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *logrus.Logger, fd int, vnetHdr bool, vpnNetworks []netip.Prefix) (*tun, error) {
tfd, err := newTunFd(fd, vnetHdr)
var container tio.Container
var err error
if vnetHdr {
container, err = tio.NewGSOContainer()
} else {
container, err = tio.NewPollContainer()
}
if err != nil {
_ = unix.Close(fd)
return nil, err
}
err = container.Add(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
}
t := &tun{
tunFile: tfd,
readers: []*tunFile{tfd},
readers: container,
closeLock: sync.Mutex{},
vnetHdr: vnetHdr,
vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
@@ -732,13 +296,13 @@ func (t *tun) SupportsMultiqueue() bool {
return true
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
return nil, err
return err
}
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
@@ -747,25 +311,23 @@ func (t *tun) NewMultiQueueReader() (Queue, error) {
}
if _, err = tunSetIff(fd, t.Device, flags); err != nil {
_ = unix.Close(fd)
return nil, err
return err
}
if t.vnetHdr {
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
_ = unix.Close(fd)
return nil, fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
return fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
}
}
out, err := t.tunFile.newFriend(fd)
err = t.readers.Add(fd)
if err != nil {
_ = unix.Close(fd)
return nil, err
return err
}
t.readers = append(t.readers, out)
return out, nil
return nil
}
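With the Queue return value gone, multiqueue setup becomes two phases: add the extra kernel queues, then fan out over whatever the device reports. A hypothetical sketch (routines and device are assumed caller-side variables):
// Open one extra queue per additional worker, then collect the full set.
for i := 1; i < routines; i++ {
	if err := device.NewMultiQueueReader(); err != nil {
		return err
	}
}
queues := device.Readers() // typically len(queues) == routines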
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
@@ -1195,6 +757,10 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
t.routeTree.Store(newTree)
}
func (t *tun) Readers() []tio.Queue {
return t.readers.Queues()
}
func (t *tun) Close() error {
t.closeLock.Lock()
defer t.closeLock.Unlock()
@@ -1204,32 +770,10 @@ func (t *tun) Close() error {
t.routeChan = nil
}
// Signal all readers blocked in poll to wake up and exit
_ = t.tunFile.wakeForShutdown()
if t.ioctlFd > 0 {
_ = unix.Close(int(t.ioctlFd))
t.ioctlFd = 0
}
for i := range t.readers {
if i == 0 {
continue //we want to close the zeroth reader last
}
err := t.readers[i].Close()
if err != nil {
t.l.WithField("reader", i).WithError(err).Error("error closing tun reader")
} else {
t.l.WithField("reader", i).Info("closed tun reader")
}
}
//this is t.readers[0] too
err := t.tunFile.Close()
if err != nil {
t.l.WithField("reader", 0).WithError(err).Error("error closing tun reader")
} else {
t.l.WithField("reader", 0).Info("closed tun reader")
}
return err
return t.readers.Close()
}

View File

@@ -1,34 +0,0 @@
//go:build !e2e_testing
// +build !e2e_testing
package overlay
import "testing"
var runAdvMSSTests = []struct {
name string
tun *tun
r Route
expected int
}{
// Standard case, default MTU is the device max MTU
{"default", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{}, 0},
{"default-min", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1440}, 0},
{"default-low", &tun{DefaultMTU: 1440, MaxMTU: 1440}, Route{MTU: 1200}, 1160},
// Case where we have a route MTU set higher than the default
{"route", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{}, 1400},
{"route-min", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 1440}, 1400},
{"route-high", &tun{DefaultMTU: 1440, MaxMTU: 8941}, Route{MTU: 8941}, 0},
}
func TestTunAdvMSS(t *testing.T) {
for _, tt := range runAdvMSSTests {
t.Run(tt.name, func(t *testing.T) {
o := tt.tun.advMSS(tt.r)
if o != tt.expected {
t.Errorf("got %d, want %d", o, tt.expected)
}
})
}
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -412,7 +413,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for netbsd")
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
@@ -332,7 +333,7 @@ func (t *tun) SupportsMultiqueue() bool {
return false
}
func (t *tun) NewMultiQueueReader() (Queue, error) {
func (t *tun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for openbsd")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
)
@@ -142,6 +143,6 @@ func (t *TestTun) SupportsMultiqueue() bool {
return false
}
func (t *TestTun) NewMultiQueueReader() (Queue, error) {
func (t *TestTun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented")
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
@@ -255,7 +256,7 @@ func (t *winTun) SupportsMultiqueue() bool {
return false
}
func (t *winTun) NewMultiQueueReader() (Queue, error) {
func (t *winTun) NewMultiQueueReader() (tio.Queue, error) {
return nil, fmt.Errorf("TODO: multiqueue not implemented for windows")
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing"
)
@@ -28,6 +29,7 @@ func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) {
type UserDevice struct {
vpnNetworks []netip.Prefix
numReaders int
outboundReader *io.PipeReader
outboundWriter *io.PipeWriter
@@ -65,8 +67,17 @@ func (d *UserDevice) SupportsMultiqueue() bool {
return true
}
func (d *UserDevice) NewMultiQueueReader() (Queue, error) {
return d, nil
func (d *UserDevice) NewMultiQueueReader() error {
d.numReaders++
return nil
}
func (d *UserDevice) Readers() []tio.Queue {
out := make([]tio.Queue, d.numReaders)
for i := range d.numReaders {
out[i] = d
}
return out
}
func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) {