GSO/GRO offloads, with TCP+ECN and UDP support

This commit is contained in:
JackDoan
2026-04-17 10:25:05 -05:00
parent 4b4331ba42
commit 6a46a2913a
60 changed files with 6915 additions and 283 deletions

View File

@@ -10,7 +10,7 @@ import (
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
) )
const ReplayWindow = 1024 const ReplayWindow = 8192
type ConnectionState struct { type ConnectionState struct {
eKey noiseutil.CipherState eKey noiseutil.CipherState

23
cpupin_linux.go Normal file
View File

@@ -0,0 +1,23 @@
//go:build linux && !android && !e2e_testing
package nebula
import (
"runtime"
"golang.org/x/sys/unix"
)
// pinThreadToCPU locks the calling goroutine to its current OS thread and
// then confines that thread to a single CPU with sched_setaffinity(2).
// Together with runtime.LockOSThread this stops the kernel migrating the
// thread between CPUs, so every sendmmsg issued from this goroutine keeps
// hitting the same XPS-selected TX ring instead of spraying one nebula
// flow's packets across several rings and reordering them on the wire.
func pinThreadToCPU(cpu int) error {
	runtime.LockOSThread()
	var mask unix.CPUSet
	mask.Zero()
	mask.Set(cpu)
	return unix.SchedSetaffinity(0, &mask)
}

11
cpupin_other.go Normal file
View File

@@ -0,0 +1,11 @@
//go:build !linux || android || e2e_testing
package nebula
// pinThreadToCPU is a stub everywhere except Linux (and on android /
// e2e_testing builds): no other supported platform offers a stable
// per-thread CPU affinity syscall, and none has the XPS-driven TX-ring
// selection this pinning exists to tame, so success is reported without
// doing anything.
func pinThreadToCPU(_ int) error {
	return nil
}

125
ecn_inner_test.go Normal file
View File

@@ -0,0 +1,125 @@
package nebula
import (
"io"
"log/slog"
"testing"
)
// TestInnerECN drives innerECN across empty input, IPv4 and IPv6 headers
// with and without DSCP bits set, and an unrecognized IP version, checking
// that only the 2-bit ECN codepoint comes back in every case.
func TestInnerECN(t *testing.T) {
	for _, tt := range []struct {
		name string
		pkt  []byte
		want byte
	}{
		{"empty", nil, 0},
		{"v4_NotECT", v4WithToS(0x00), 0x00},
		{"v4_ECT0", v4WithToS(0x02), 0x02},
		{"v4_ECT1", v4WithToS(0x01), 0x01},
		{"v4_CE", v4WithToS(0x03), 0x03},
		{"v4_DSCP_then_NotECT", v4WithToS(0x88 | 0x00), 0x00},
		{"v4_DSCP_then_CE", v4WithToS(0x88 | 0x03), 0x03},
		{"v6_NotECT", v6WithTC(0x00), 0x00},
		{"v6_ECT0", v6WithTC(0x02), 0x02},
		{"v6_CE", v6WithTC(0x03), 0x03},
		{"v6_DSCP_then_CE", v6WithTC(0x88 | 0x03), 0x03},
		{"unknown_version", []byte{0xa5, 0xff}, 0},
	} {
		t.Run(tt.name, func(t *testing.T) {
			if got := innerECN(tt.pkt); got != tt.want {
				t.Errorf("innerECN=0x%02x want 0x%02x", got, tt.want)
			}
		})
	}
}
// v4WithToS fabricates the minimal 2-byte slice innerECN needs for the IPv4
// path: byte 0 is 0x45 (version 4 in the high nibble), byte 1 carries the
// complete ToS octet so a caller can exercise the DSCP and ECN portions of
// the byte-1 mask together.
func v4WithToS(tos byte) []byte {
	hdr := make([]byte, 2)
	hdr[0] = 0x45
	hdr[1] = tos
	return hdr
}
// v6WithTC fabricates a 2-byte IPv6-flavored slice whose traffic class is
// split across the version/TC boundary: the high nibble of tc lands in the
// low nibble of byte 0 (alongside version 6) and the low nibble of tc in
// the high nibble of byte 1. innerECN then reads ECN as (b[1]>>4)&0x03,
// i.e. TC bits [1:0].
func v6WithTC(tc byte) []byte {
	hiNibble := (tc >> 4) & 0x0f
	loNibble := tc & 0x0f
	return []byte{0x60 | hiNibble, loNibble << 4}
}
// TestApplyOuterECN walks the full 4x4 outer/inner ECN matrix through
// applyOuterECN for both an IPv4 and an IPv6 inner packet, asserting the
// resulting inner ECN codepoint matches the RFC 6040 normal-mode decap
// table encoded in `table` below.
func TestApplyOuterECN(t *testing.T) {
	// Discard logger: the (outer=CE, inner=Not-ECT) legacy row is expected
	// to log rather than mutate, and we don't want that noise in test output.
	silent := slog.New(slog.NewTextHandler(io.Discard, nil))
	hi := &HostInfo{}
	// Build a v4 packet helper with a given inner ECN field.
	v4 := func(innerECN byte) []byte {
		// 20-byte minimal IPv4 header with ToS = innerECN (DSCP zeroed).
		return []byte{
			0x45, innerECN, 0, 28,
			0, 0, 0x40, 0,
			64, 6, 0, 0,
			10, 0, 0, 1,
			10, 0, 0, 2,
		}
	}
	// Build a v6 packet helper with a given inner ECN field. ECN occupies
	// TC[1:0] which sit at byte 1 mask 0x30.
	v6 := func(innerECN byte) []byte {
		// 40-byte minimal IPv6 header with TC[1:0] = innerECN.
		pkt := make([]byte, 40)
		pkt[0] = 0x60                   // version=6, TC[7:4]=0
		pkt[1] = (innerECN & 0x03) << 4 // TC[3:0]: low 2 bits = ECN, top 2 = DSCP-low (0)
		return pkt
	}
	// cell is one row of the outer-by-inner decap matrix.
	type cell struct {
		outer    byte
		inner    byte
		wantECN  byte
		wantSame bool // expect inner unchanged (true => verify the byte didn't move)
	}
	// NOTE(review): wantSame is populated in every row but never read by the
	// assertions below — confirm whether an "inner ToS/TC byte untouched"
	// check was meant to accompany the wantECN comparison.
	// RFC 6040 normal-mode combine table. Only outer==CE causes mutation.
	table := []cell{
		{ecnNotECT, ecnNotECT, ecnNotECT, true},
		{ecnNotECT, ecnECT0, ecnECT0, true},
		{ecnNotECT, ecnECT1, ecnECT1, true},
		{ecnNotECT, ecnCE, ecnCE, true},
		{ecnECT0, ecnNotECT, ecnNotECT, true},
		{ecnECT0, ecnECT0, ecnECT0, true},
		{ecnECT0, ecnECT1, ecnECT1, true},
		{ecnECT0, ecnCE, ecnCE, true},
		{ecnECT1, ecnNotECT, ecnNotECT, true},
		{ecnECT1, ecnECT0, ecnECT0, true},
		{ecnECT1, ecnECT1, ecnECT1, true},
		{ecnECT1, ecnCE, ecnCE, true},
		{ecnCE, ecnNotECT, ecnNotECT, true}, // legacy: log, leave alone
		{ecnCE, ecnECT0, ecnCE, false},      // CE folded in
		{ecnCE, ecnECT1, ecnCE, false},
		{ecnCE, ecnCE, ecnCE, true},
	}
	// Every row reuses the subtest names "v4"/"v6"; `go test` disambiguates
	// duplicates by appending #NN suffixes, so failures remain attributable.
	for _, c := range table {
		t.Run("v4", func(t *testing.T) {
			pkt := v4(c.inner)
			applyOuterECN(pkt, c.outer, hi, silent)
			// IPv4: ECN is the low 2 bits of the ToS byte (byte 1).
			got := pkt[1] & 0x03
			if got != c.wantECN {
				t.Errorf("v4 outer=0x%02x inner=0x%02x: got 0x%02x want 0x%02x", c.outer, c.inner, got, c.wantECN)
			}
		})
		t.Run("v6", func(t *testing.T) {
			pkt := v6(c.inner)
			applyOuterECN(pkt, c.outer, hi, silent)
			// IPv6: ECN is TC[1:0], which live at byte 1 bits [5:4].
			got := (pkt[1] >> 4) & 0x03
			if got != c.wantECN {
				t.Errorf("v6 outer=0x%02x inner=0x%02x: got 0x%02x want 0x%02x", c.outer, c.inner, got, c.wantECN)
			}
		})
	}
}

View File

@@ -5,6 +5,8 @@ import (
"log/slog" "log/slog"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/slackhq/nebula/logging"
) )
// ConntrackCache is used as a local routine cache to know if a given flow // ConntrackCache is used as a local routine cache to know if a given flow
@@ -56,8 +58,8 @@ func (c *ConntrackCacheTicker) Get() ConntrackCache {
if tick := c.cacheTick.Load(); tick != c.cacheV { if tick := c.cacheTick.Load(); tick != c.cacheV {
c.cacheV = tick c.cacheV = tick
if ll := len(c.cache); ll > 0 { if ll := len(c.cache); ll > 0 {
if c.l.Enabled(context.Background(), slog.LevelDebug) { if c.l.Enabled(context.Background(), logging.LevelTrace) {
c.l.Debug("resetting conntrack cache", "len", ll) c.l.Log(context.Background(), logging.LevelTrace, "resetting conntrack cache", "len", ll)
} }
c.cache = make(ConntrackCache, ll) c.cache = make(ConntrackCache, ll)
} }

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/slackhq/nebula/logging"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -30,27 +31,27 @@ func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheT
func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) l := test.NewLoggerWithOutputAndLevel(buf, logging.LevelTrace)
c := newFixedTicker(t, l, 3) c := newFixedTicker(t, l, 3)
c.Get() c.Get()
assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) assert.Equal(t, "level=DEBUG-4 msg=\"resetting conntrack cache\" len=3\n", buf.String())
} }
func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { func TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) l := test.NewJSONLoggerWithOutput(buf, logging.LevelTrace)
c := newFixedTicker(t, l, 2) c := newFixedTicker(t, l, 2)
c.Get() c.Get()
assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) assert.JSONEq(t, `{"level":"DEBUG-4","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String()))
} }
func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { func TestConntrackCacheTicker_Get_QuietBelowTrace(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug)
c := newFixedTicker(t, l, 5) c := newFixedTicker(t, l, 5)
c.Get() c.Get()
@@ -60,7 +61,7 @@ func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) {
func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) l := test.NewLoggerWithOutputAndLevel(buf, logging.LevelTrace)
c := newFixedTicker(t, l, 0) c := newFixedTicker(t, l, 0)
c.Get() c.Get()

1
go.mod
View File

@@ -43,6 +43,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/guptarohit/asciigraph v0.9.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/client_model v0.6.2 // indirect

2
go.sum
View File

@@ -60,6 +60,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/guptarohit/asciigraph v0.9.0 h1:MvCSRRVkT2XvU1IO6n92o7l7zqx1DiFaoszOUZQztbY=
github.com/guptarohit/asciigraph v0.9.0/go.mod h1:dYl5wwK4gNsnFf9Zp+l06rFiDZ5YtXM6x7SRWZ3KGag=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=

143
inside.go
View File

@@ -10,10 +10,23 @@ import (
"github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/overlay/batch" "github.com/slackhq/nebula/overlay/batch"
"github.com/slackhq/nebula/overlay/tio"
"github.com/slackhq/nebula/routing" "github.com/slackhq/nebula/routing"
) )
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) {
// borrowed: pkt.Bytes is owned by the originating tio.Queue and is
// only valid until the next Read on that queue. Every consumer below
// (parse, self-forward, handshake cache, sendInsideMessage) reads it
// synchronously; do not retain pkt outside this call. If a future
// caller needs to keep the packet, use pkt.Clone() to detach it from
// the borrow.
//
// pkt.Bytes is either one IP datagram (GSO zero) or a TSO/USO
// superpacket. In both cases the L3+L4 headers at the start describe
// the same 5-tuple every segment will share, so a single newPacket /
// firewall check covers the whole superpacket.
packet := pkt.Bytes
err := newPacket(packet, false, fwPacket) err := newPacket(packet, false, fwPacket)
if err != nil { if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
@@ -38,7 +51,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
// routes packets from the Nebula addr to the Nebula addr through the Nebula // routes packets from the Nebula addr to the Nebula addr through the Nebula
// TUN device. // TUN device.
if immediatelyForwardToSelf { if immediatelyForwardToSelf {
_, err := f.readers[q].Write(packet) // Write copies into the kernel queue synchronously, so seg's lifetime ends at return.
// A self-forwarded superpacket would be re-handed to the
// kernel as one giant blob; segment first so the loopback
// path sees one IP datagram per Write.
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
_, werr := f.readers[q].Write(seg)
return werr
})
if err != nil { if err != nil {
f.l.Error("Failed to forward to tun", "error", err) f.l.Error("Failed to forward to tun", "error", err)
} }
@@ -54,7 +74,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) // borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt
// bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket),
// so retaining segments past the loop is safe.
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics)
return nil
})
if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Failed to segment superpacket for handshake cache",
"error", err,
"vpnAddr", fwPacket.RemoteAddr,
)
}
}) })
if hostinfo == nil { if hostinfo == nil {
@@ -74,7 +106,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendInsideMessage(hostinfo, packet, nb, sendBatch, rejectBuf, q) f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
} else { } else {
f.rejectInside(packet, rejectBuf, q) f.rejectInside(packet, rejectBuf, q)
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
@@ -86,11 +118,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
} }
} }
// sendInsideMessage encrypts a firewall-approved inside packet into the // sendInsideMessage encrypts a firewall-approved inside packet (or every
// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not // segment of a TSO/USO superpacket) into the caller's batch slot for
// valid we fall through to the relay slow path via the unbatched sendNoMetrics // later sendmmsg flush. Segmentation is fused with encryption here so the
// so relay behavior is unchanged. // kernel-supplied superpacket bytes never get written into a separate
func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) { // scratch arena: SegmentSuperpacket builds each segment's plaintext in
// segScratch[:segLen] in turn, and we encrypt directly into a fresh
// SendBatch slot.
//
// When hostinfo.remote is not valid we fall through to the relay slow
// path via the unbatched sendNoMetrics so relay behavior is unchanged;
// each segment of a superpacket goes through that path independently.
// sendInsideMessage takes a borrowed pkt: pkt.Bytes is only valid until the
// next Read on the originating tio.Queue. Each segment is encrypted into a
// fresh sendBatch slot (Reserve returns owned scratch), so the borrow ends
// inside the SegmentSuperpacket callback below. Do not retain pkt or any
// seg slice past the callback's return.
func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) {
ci := hostinfo.ConnectionState ci := hostinfo.ConnectionState
if ci.eKey == nil { if ci.eKey == nil {
return return
@@ -99,26 +143,20 @@ func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatc
if !hostinfo.remote.IsValid() { if !hostinfo.remote.IsValid() {
// Slow path: relay fallback. Reuse rejectBuf as the ciphertext // Slow path: relay fallback. Reuse rejectBuf as the ciphertext
// scratch; sendNoMetrics arranges header space for SendVia. // scratch; sendNoMetrics arranges header space for SendVia.
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) // Segment any superpacket so each segment is sized to fit a
// single relay encap.
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, seg, nb, rejectBuf, q)
return nil
})
if err != nil {
hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send",
"error", err,
)
}
return return
} }
scratch := sendBatch.Next()
if scratch == nil {
// Batch full: bypass batching and send this packet directly so we
// never drop traffic on over-subscribed iterations.
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q)
return
}
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
if hostinfo.lastRebindCount != f.rebindCount { if hostinfo.lastRebindCount != f.rebindCount {
//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
@@ -131,20 +169,63 @@ func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatc
} }
} }
out, err := ci.eKey.EncryptDanger(out, out, p, c, nb) ecnEnabled := f.ecnEnabled.Load()
err := tio.SegmentSuperpacket(pkt, func(seg []byte) error {
// header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305)
scratch := sendBatch.Reserve(header.Len + len(seg) + 16)
if noiseutil.EncryptLockNeeded {
ci.writeLock.Lock()
}
c := ci.messageCounter.Add(1)
out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c)
f.connectionManager.Out(hostinfo)
out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb)
if noiseutil.EncryptLockNeeded { if noiseutil.EncryptLockNeeded {
ci.writeLock.Unlock() ci.writeLock.Unlock()
} }
if err != nil { if encErr != nil {
hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet",
"error", err, "error", encErr,
"udpAddr", hostinfo.remote, "udpAddr", hostinfo.remote,
"counter", c, "counter", c,
) )
return // Skip this segment; the rest of the superpacket can still
// go out — TCP will retransmit anything we drop here.
return nil
} }
sendBatch.Commit(len(out), hostinfo.remote) var ecn byte
if ecnEnabled {
ecn = innerECN(seg)
}
sendBatch.Commit(out, hostinfo.remote, ecn)
return nil
})
if err != nil {
hostinfo.logger(f.l).Error("Failed to segment superpacket for send",
"error", err,
)
}
}
// innerECN returns the 2-bit IP-level ECN codepoint of an inner IPv4 or IPv6
// packet, or 0 if pkt is too short or its IP version is unrecognized. Used at
// encap to copy the inner codepoint onto the outer carrier per RFC 6040.
func innerECN(pkt []byte) byte {
if len(pkt) < 2 {
return 0
}
switch pkt[0] >> 4 {
case 4:
return pkt[1] & 0x03
case 6:
return (pkt[1] >> 4) & 0x03
}
return 0
} }
func (f *Interface) rejectInside(packet []byte, out []byte, q int) { func (f *Interface) rejectInside(packet []byte, out []byte, q int) {

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -48,6 +49,13 @@ type InterfaceConfig struct {
reQueryWait time.Duration reQueryWait time.Duration
ConntrackCacheTimeout time.Duration ConntrackCacheTimeout time.Duration
// CpuAffinity, when non-empty, names the CPUs each TUN reader goroutine
// should pin to. Queue i pins to CpuAffinity[i % len(CpuAffinity)] —
// shorter lists than `routines` cycle. Empty list keeps the default
// pin-to-(i % NumCPU) behavior.
CpuAffinity []int
l *slog.Logger l *slog.Logger
} }
@@ -72,6 +80,15 @@ type Interface struct {
routines int routines int
disconnectInvalid atomic.Bool disconnectInvalid atomic.Bool
closed atomic.Bool closed atomic.Bool
// cpuAffinity, when non-empty, names the CPUs each TUN reader goroutine
// should pin to. Queue i pins to cpuAffinity[i % len(cpuAffinity)].
// Empty falls back to the default pin-to-(i % NumCPU) behavior.
cpuAffinity []int
// ecnEnabled gates RFC 6040 underlay ECN propagation. When true,
// inside.go copies the inner ECN onto the outer carrier on encap and
// decryptToTun folds outer CE into the inner header on decap. Toggle
// via tunnels.ecn (default true).
ecnEnabled atomic.Bool
relayManager *relayManager relayManager *relayManager
tryPromoteEvery atomic.Uint32 tryPromoteEvery atomic.Uint32
@@ -202,6 +219,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
relayManager: c.relayManager, relayManager: c.relayManager,
connectionManager: c.connectionManager, connectionManager: c.connectionManager,
conntrackCacheTimeout: c.ConntrackCacheTimeout, conntrackCacheTimeout: c.ConntrackCacheTimeout,
cpuAffinity: c.CpuAffinity,
metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
messageMetrics: c.MessageMetrics, messageMetrics: c.MessageMetrics,
@@ -260,8 +278,17 @@ func (f *Interface) activate() error {
} }
f.readers = f.inside.Readers() f.readers = f.inside.Readers()
for i := range f.readers { for i := range f.readers {
caps := tio.QueueCapabilities(f.readers[i])
if caps.TSO || caps.USO {
// Multi-lane: TCP gets coalesced when TSO is on, UDP when USO
// is on, everything else (and either lane disabled) falls
// through to passthrough so non-IP / non-TCP-UDP traffic still
// reaches the TUN.
f.batchers[i] = batch.NewMultiCoalescer(f.readers[i], caps.TSO, caps.USO)
} else {
f.batchers[i] = batch.NewPassthrough(f.readers[i]) f.batchers[i] = batch.NewPassthrough(f.readers[i])
} }
}
f.wg.Add(1) // for us to wait on Close() to return f.wg.Add(1) // for us to wait on Close() to return
if err = f.inside.Activate(); err != nil { if err = f.inside.Activate(); err != nil {
@@ -322,15 +349,13 @@ func (f *Interface) listenOut(i int) {
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
coalescer := f.batchers[i] listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) {
listener := func(fromUdpAddr netip.AddrPort, payload []byte) {
plaintext := f.batchers[i].Reserve(len(payload)) plaintext := f.batchers[i].Reserve(len(payload))
f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta)
} }
flusher := func() { flusher := func() {
if err := coalescer.Flush(); err != nil { if err := f.batchers[i].Flush(); err != nil {
f.l.Error("Failed to flush tun coalescer", "error", err) f.l.Error("Failed to flush tun coalescer", "error", err)
} }
} }
@@ -346,8 +371,27 @@ func (f *Interface) listenOut(i int) {
} }
func (f *Interface) listenIn(reader tio.Queue, i int) { func (f *Interface) listenIn(reader tio.Queue, i int) {
// Pin this goroutine to one CPU. LockOSThread alone keeps the goroutine
// on a single OS thread but the kernel can still migrate that thread
// across CPUs — XPS reads smp_processor_id() at sendmmsg time and picks
// the TX ring from the current CPU's xps_cpus map, so an unpinned
// thread bouncing between CPUs spreads one nebula flow's packets across
// multiple TX rings, which the rings then drain at independent rates
// and the wire delivers reordered.
//
// Pinning keeps every sendmmsg from this goroutine going through the
// same TX ring, so the wire sees per-flow order. Cost: less scheduler
// flexibility — if i % NumCPU collides between two TUN reader
// goroutines they share a CPU.
cpu := i % runtime.NumCPU()
if n := len(f.cpuAffinity); n > 0 {
cpu = f.cpuAffinity[i%n]
}
if err := pinThreadToCPU(cpu); err != nil {
f.l.Warn("failed to pin tun reader to CPU", "queue", i, "cpu", cpu, "err", err)
}
rejectBuf := make([]byte, mtu) rejectBuf := make([]byte, mtu)
sb := batch.NewSendBatch(batch.SendBatchCap, udp.MTU+32) sb := batch.NewSendBatch(f.writers[i], batch.SendBatchCap, udp.MTU+32)
fwPacket := &firewall.Packet{} fwPacket := &firewall.Packet{}
nb := make([]byte, 12, 12) nb := make([]byte, 12, 12)
@@ -363,35 +407,24 @@ func (f *Interface) listenIn(reader tio.Queue, i int) {
break break
} }
sb.Reset()
for _, pkt := range pkts { for _, pkt := range pkts {
if sb.Len() >= sb.Cap() {
f.flushBatch(sb, i)
sb.Reset()
}
f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get()) f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get())
} }
if sb.Len() > 0 { if err := sb.Flush(); err != nil {
f.flushBatch(sb, i) f.l.Error("Failed to write outgoing batch", "error", err, "writer", i)
} }
} }
f.l.Debug("overlay reader is done", "reader", i) f.l.Debug("overlay reader is done", "reader", i)
} }
func (f *Interface) flushBatch(sb batch.TxBatcher, q int) {
bufs, dsts := sb.Get()
if err := f.writers[q].WriteBatch(bufs, dsts); err != nil {
f.l.Error("Failed to write outgoing batch", "error", err, "writer", q)
}
}
func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadFirewall)
c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadSendRecvError)
c.RegisterReloadCallback(f.reloadAcceptRecvError) c.RegisterReloadCallback(f.reloadAcceptRecvError)
c.RegisterReloadCallback(f.reloadDisconnectInvalid) c.RegisterReloadCallback(f.reloadDisconnectInvalid)
c.RegisterReloadCallback(f.reloadMisc) c.RegisterReloadCallback(f.reloadMisc)
c.RegisterReloadCallback(f.reloadEcn)
for _, udpConn := range f.writers { for _, udpConn := range f.writers {
c.RegisterReloadCallback(udpConn.ReloadConfig) c.RegisterReloadCallback(udpConn.ReloadConfig)
@@ -515,6 +548,20 @@ func (f *Interface) reloadMisc(c *config.C) {
} }
} }
// reloadEcn syncs Interface.ecnEnabled with the tunnels.ecn config knob.
// Default is enabled (RFC 6040 normal mode); set false on the rare path
// where an underlay middlebox rewrites or drops ECN bits unpredictably.
func (f *Interface) reloadEcn(c *config.C) {
initial := c.InitialLoad()
if initial || c.HasChanged("tunnels.ecn") {
v := c.GetBool("tunnels.ecn", true)
f.ecnEnabled.Store(v)
if !initial {
f.l.Info("tunnels.ecn changed", "enabled", v)
}
}
}
func (f *Interface) emitStats(ctx context.Context, i time.Duration) { func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
ticker := time.NewTicker(i) ticker := time.NewTicker(i)
defer ticker.Stop() defer ticker.Stop()

49
main.go
View File

@@ -220,6 +220,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
relayManager: NewRelayManager(ctx, l, hostMap, c), relayManager: NewRelayManager(ctx, l, hostMap, c),
punchy: punchy, punchy: punchy,
ConntrackCacheTimeout: conntrackCacheTimeout, ConntrackCacheTimeout: conntrackCacheTimeout,
CpuAffinity: parseCpuAffinity(c, l, routines),
l: l, l: l,
} }
@@ -237,6 +238,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
ifce.reloadDisconnectInvalid(c) ifce.reloadDisconnectInvalid(c)
ifce.reloadSendRecvError(c) ifce.reloadSendRecvError(c)
ifce.reloadAcceptRecvError(c) ifce.reloadAcceptRecvError(c)
ifce.reloadEcn(c)
handshakeManager.f = ifce handshakeManager.f = ifce
go handshakeManager.Run(ctx) go handshakeManager.Run(ctx)
@@ -271,6 +273,53 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
}, nil }, nil
} }
// parseCpuAffinity reads `tun.cpu_affinity` from the config — a list of
// integer CPU IDs, one per TUN reader goroutine. Empty / unset returns nil
// (listenIn falls back to its default `i % NumCPU` pinning). Length
// mismatch with `routines` is a warning, not an error: shorter lists are
// modulo-cycled across queues, longer lists' tail is ignored. Invalid
// entries (non-integer, out of range) are also a warning and disable the
// override entirely so we don't silently pin to the wrong CPU.
func parseCpuAffinity(c *config.C, l *slog.Logger, routines int) []int {
raw := c.Get("tun.cpu_affinity")
if raw == nil {
return nil
}
rv, ok := raw.([]any)
if !ok {
l.Warn("tun.cpu_affinity must be a list of integers; ignoring", "value", raw)
return nil
}
nCPU := runtime.NumCPU()
cpus := make([]int, 0, len(rv))
for i, e := range rv {
var cpu int
switch v := e.(type) {
case int:
cpu = v
case int64:
cpu = int(v)
case float64:
cpu = int(v)
default:
l.Warn("tun.cpu_affinity entry not an integer; ignoring affinity",
"index", i, "value", e)
return nil
}
if cpu < 0 || cpu >= nCPU {
l.Warn("tun.cpu_affinity entry out of range; ignoring affinity",
"index", i, "cpu", cpu, "num_cpu", nCPU)
return nil
}
cpus = append(cpus, cpu)
}
if len(cpus) != routines {
l.Warn("tun.cpu_affinity length doesn't match routines; queues will modulo-cycle through the list",
"affinity_len", len(cpus), "routines", routines)
}
return cpus
}
func moduleVersion() string { func moduleVersion() string {
info, ok := debug.ReadBuildInfo() info, ok := debug.ReadBuildInfo()
if !ok { if !ok {

View File

@@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@@ -22,7 +23,7 @@ const (
var ErrOutOfWindow = errors.New("out of window packet") var ErrOutOfWindow = errors.New("out of window packet")
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
err := h.Parse(packet) err := h.Parse(packet)
if err != nil { if err != nil {
// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -135,7 +136,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
case header.Message: case header.Message:
switch h.Subtype { switch h.Subtype {
case header.MessageNone: case header.MessageNone:
f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta)
default: default:
hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h)
return return
@@ -168,7 +169,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte,
} }
} }
func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
// The entire body is sent as AD, not encrypted. // The entire body is sent as AD, not encrypted.
// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
@@ -211,7 +212,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender,
relay: relay, relay: relay,
IsRelayed: true, IsRelayed: true,
} }
f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta)
case ForwardingType: case ForwardingType:
// Find the target HostInfo relay object // Find the target HostInfo relay object
targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr)
@@ -512,7 +513,61 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
return out, nil return out, nil
} }
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { // 2-bit IP-level ECN codepoints (lower bits of IPv4 ToS / IPv6 TC).
// IP-level ECN codepoints — the 2 low-order bits of the IPv4 ToS byte,
// or bits [1:0] of the IPv6 traffic class.
const (
	ecnNotECT = 0x00
	ecnECT1   = 0x01
	ecnECT0   = 0x02
	ecnCE     = 0x03
)

// applyOuterECN folds an outer CE mark from the underlay into the inner
// IP header per RFC 6040 normal mode, mutating pkt[1] in place. Any
// outer codepoint other than CE is advisory only and leaves the inner
// header untouched.
//
// Merge cases (outer × inner → action):
//
//	outer != CE              : no-op (inner is authoritative)
//	outer == CE, inner Not-ECT : log; cannot propagate to a non-ECN host
//	outer == CE, inner ECT/CE  : rewrite inner ECN to CE
func applyOuterECN(pkt []byte, outerECN byte, hostinfo *HostInfo, l *slog.Logger) {
	if len(pkt) < 2 || outerECN&ecnCE != ecnCE {
		return
	}
	// Locate the inner ECN bits: IPv4 keeps them in the low 2 bits of
	// byte 1; IPv6's traffic class straddles bytes 0-1, putting ECN at
	// bits 5:4 of byte 1.
	var mask byte
	var shift uint
	switch pkt[0] >> 4 {
	case 4:
		mask, shift = 0x03, 0
	case 6:
		mask, shift = 0x30, 4
	default:
		return
	}
	switch (pkt[1] & mask) >> shift {
	case ecnNotECT:
		if l.Enabled(context.Background(), slog.LevelDebug) {
			hostinfo.logger(l).Debug("RFC 6040: outer CE on inner Not-ECT, leaving inner unchanged")
		}
	case ecnCE:
		// Inner already carries CE; nothing to do.
	default:
		// ECT(0) or ECT(1): the host is ECN-capable, mark congestion.
		pkt[1] = (pkt[1] &^ mask) | (ecnCE << shift)
	}
}
func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) {
// RFC 6040 normal-mode combine: fold any outer CE mark stamped by the
// underlay into the inner header before firewall + TUN write. Other
// outer codepoints are advisory only — we keep the inner unchanged.
if f.ecnEnabled.Load() {
applyOuterECN(out, meta.OuterECN, hostinfo, f.l)
}
err := newPacket(out, true, fwPacket) err := newPacket(out, true, fwPacket)
if err != nil { if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet", hostinfo.logger(f.l).Warn("Error while validating inbound packet",

View File

@@ -14,20 +14,15 @@ type RxBatcher interface {
} }
type TxBatcher interface { type TxBatcher interface {
// Next returns a zero-length slice with slotCap capacity over the next unused // Reserve creates a pkt to borrow
// slot's backing bytes. The caller writes into the returned slice and then Reserve(sz int) []byte
// calls Commit with the final length and destination. Next returns nil when // Commit borrows pkt and records its destination plus the 2-bit
// the batch is full. // IP-level ECN codepoint to set on the outer (carrier) header. The
Next() []byte // caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT)
// Commit records the slot just returned by Next as a packet of length n // to leave the outer ECN field unset.
// destined for dst. Commit(pkt []byte, dst netip.AddrPort, outerECN byte)
Commit(n int, dst netip.AddrPort) // Flush emits every queued packet via the underlying batch writer in
// Reset clears committed slots; backing storage is retained for reuse. // arrival order. Returns the first error observed. After Flush returns,
Reset() // borrowed payload slices may be recycled.
// Len returns the number of committed packets. Flush() error
Len() int
// Cap returns the maximum number of slots in the batch.
Cap() int
// Get returns the buffers needed to send the batch
Get() ([][]byte, []netip.AddrPort)
} }

View File

@@ -0,0 +1,163 @@
package batch
import (
"bytes"
"encoding/binary"
)
// flowKey identifies a transport flow by {src, dst, sport, dport, family}.
// Comparable, so map lookups and linear scans over the slot list stay tight.
// Shared by the TCP and UDP coalescers; each coalescer keeps its own
// openSlots map, so a TCP and UDP flow on the same 5-tuple-without-proto
// never alias.
// flowKey identifies a transport flow by {src, dst, sport, dport, family}.
// It is comparable, so it can key maps and be compared cheaply in linear
// scans. The TCP and UDP coalescers each keep their own openSlots map, so
// a TCP and a UDP flow sharing the same addresses and ports never alias.
type flowKey struct {
	src, dst     [16]byte
	sport, dport uint16
	isV6         bool
}

// initialSlots is the starting capacity of the slot pool. One flow per
// packet is the worst case, which matches a typical carrier-side
// recvmmsg batch on the encrypted UDP socket.
const initialSlots = 64

// parsedIP is the IP-level result of parseIPPrologue; the transport
// parsers (TCP / UDP) layer their own parsing on top of it.
type parsedIP struct {
	fk       flowKey
	ipHdrLen int
	// pkt is the input trimmed to the IP-declared total length. Transport
	// parsers should slice into pkt, never into the untrimmed original.
	pkt []byte
}

// parseIPPrologue pulls out the IP-level fields both coalescers need —
// header length, version, src/dst addresses — and validates the L4
// protocol byte against wantProto (6 = TCP, 17 = UDP; anything else
// yields ok=false). It also reports ok=false for malformed input, IPv4
// with options or fragmentation, and IPv6 with extension headers. On
// success p.pkt is already trimmed to the IP-declared length.
func parseIPPrologue(pkt []byte, wantProto byte) (parsedIP, bool) {
	var out parsedIP
	if len(pkt) < 20 {
		return out, false
	}
	switch pkt[0] >> 4 {
	case 4:
		if int(pkt[0]&0x0f)*4 != 20 {
			// IPv4 options are not supported.
			return out, false
		}
		if pkt[9] != wantProto {
			return out, false
		}
		if binary.BigEndian.Uint16(pkt[6:8])&0x3fff != 0 {
			// MF set or a non-zero fragment offset: real fragmentation.
			// (DF, bit 0x4000, is deliberately outside the mask.)
			return out, false
		}
		total := int(binary.BigEndian.Uint16(pkt[2:4]))
		if total < 20 || total > len(pkt) {
			return out, false
		}
		out.ipHdrLen = 20
		copy(out.fk.src[:4], pkt[12:16])
		copy(out.fk.dst[:4], pkt[16:20])
		out.pkt = pkt[:total]
		return out, true
	case 6:
		if len(pkt) < 40 {
			return out, false
		}
		if pkt[6] != wantProto {
			// Any next-header other than wantProto — including every
			// extension header — is rejected.
			return out, false
		}
		end := 40 + int(binary.BigEndian.Uint16(pkt[4:6]))
		if end > len(pkt) {
			return out, false
		}
		out.ipHdrLen = 40
		out.fk.isV6 = true
		copy(out.fk.src[:], pkt[8:24])
		copy(out.fk.dst[:], pkt[24:40])
		out.pkt = pkt[:end]
		return out, true
	default:
		return out, false
	}
}
// ipHeadersMatch reports whether the IP portions of two header prefixes
// agree on every field that must be identical across coalesced segments.
// Fields that legitimately vary per segment — total/payload length, IPv4
// ID and checksum, and the 2-bit ECN codepoint — are masked out; the
// appendPayload step is responsible for folding CE into the seed.
//
// The transport (L4) portion of the header is compared separately by
// each protocol's matcher.
func ipHeadersMatch(a, b []byte, isV6 bool) bool {
	if isV6 {
		// Byte 0: version + TC[7:4]. Byte 1: TC[3:0] + flow[19:16], with
		// ECN in TC[1:0] (mask 0x30 of byte 1). Bytes [2:4] are the rest
		// of the flow label; [4:6] (payload length) is skipped; [6:40]
		// covers next header, hop limit, and both addresses.
		return a[0] == b[0] &&
			a[1]&^0x30 == b[1]&^0x30 &&
			bytes.Equal(a[2:4], b[2:4]) &&
			bytes.Equal(a[6:40], b[6:40])
	}
	// IPv4: byte 0 = version/IHL; byte 1 = DSCP|ECN with ECN masked;
	// [6:10] = flags/fragoff/TTL/proto; [12:20] = both addresses.
	// Skipped: [2:4] total length, [4:6] ID, [10:12] checksum.
	return a[0] == b[0] &&
		a[1]&^0x03 == b[1]&^0x03 &&
		bytes.Equal(a[6:10], b[6:10]) &&
		bytes.Equal(a[12:20], b[12:20])
}
// mergeECNIntoSeed propagates an IP-level CE (Congestion Experienced,
// 0b11) mark from a coalesced segment onto the seed's IP header, so a
// CE on any segment survives into the final superpacket. ipHeadersMatch
// masks the ECN bits when admitting segments, so this is the one place
// the congestion signal is preserved. Shared by the TCP and UDP
// coalescers so the invariant lives in one place.
//
// Only an exact CE codepoint is folded in. A plain OR of the two ECN
// fields would fabricate CE (0b11) out of one ECT(0) (0b10) segment and
// one ECT(1) (0b01) segment — a false congestion signal the inner
// receiver would echo back — and would stamp ECT onto a Not-ECT seed.
// Non-CE codepoints on later segments leave the seed untouched.
func mergeECNIntoSeed(seedHdr, pktHdr []byte, isV6 bool) {
	// ECN sits in the low 2 bits of the IPv4 ToS byte, or in bits 5:4
	// of the second byte for IPv6 (traffic class straddles bytes 0-1).
	mask := byte(0x03)
	if isV6 {
		mask = 0x30
	}
	if pktHdr[1]&mask == mask {
		seedHdr[1] |= mask
	}
}
// reserveFromBacking is the shared Reserve implementation for the TCP and
// UDP coalescers: it hands out an exactly-sized slice carved from
// *backing, growing the backing array on demand. Slices handed out
// before a growth keep referencing the old array and so remain valid
// until the owner resets backing at Flush.
func reserveFromBacking(backing *[]byte, sz int) []byte {
	buf := *backing
	if len(buf)+sz > cap(buf) {
		// Grow by doubling (at least sz). Previously returned slices
		// keep the old array alive; we only ever append into the new one.
		buf = make([]byte, 0, max(cap(buf)*2, sz))
	}
	off := len(buf)
	buf = buf[:off+sz]
	*backing = buf
	// Three-index slice: cap == sz, so the caller can't append past its
	// slot into a neighbor's bytes.
	return buf[off : off+sz : off+sz]
}

View File

@@ -0,0 +1,133 @@
package batch
import (
"io"
)
// MultiCoalescer fans plaintext packets out to lane-specific batchers based
// on the IP/L4 protocol of the packet, sharing a single Reserve arena
// across lanes so the caller's allocation pattern is unchanged.
//
// Lanes are processed independently: the TCP coalescer only sees TCP, the
// UDP coalescer only sees UDP, and the passthrough lane handles everything
// else. Per-flow arrival order is preserved because a single 5-tuple only
// ever lands in one lane and each lane preserves its own slot order.
//
// Cross-lane order is NOT preserved across the TCP/UDP/passthrough split.
// This is acceptable because the carrier-side recvmmsg path already
// stable-sorts by (peer, message counter) before delivering plaintext
// here, so replay-window invariants are unaffected, and apps observe
// correct per-flow ordering — which is all the IP layer guarantees anyway.
// Do not "fix" this by interleaving lane outputs at flush time; that
// negates the entire point of coalescing (each lane needs to see runs of
// adjacent same-flow packets to coalesce them).
type MultiCoalescer struct {
	// tcp and udp are nil when their lane is disabled at construction;
	// Commit then routes that protocol through pt instead.
	tcp *TCPCoalescer
	udp *UDPCoalescer
	pt  *Passthrough
	// arena shared across all lanes so a single Reserve grows one backing
	// slice; lane Commit calls borrow into this same arena.
	backing []byte
}
// NewMultiCoalescer builds a multi-lane batcher around w. tcpEnabled and
// udpEnabled gate the TCP (TSO) and UDP (USO) coalescing lanes; a
// disabled lane simply routes its protocol through the passthrough lane,
// so nothing is ever dropped.
func NewMultiCoalescer(w io.Writer, tcpEnabled, udpEnabled bool) *MultiCoalescer {
	var tcpLane *TCPCoalescer
	if tcpEnabled {
		tcpLane = NewTCPCoalescer(w)
	}
	var udpLane *UDPCoalescer
	if udpEnabled {
		udpLane = NewUDPCoalescer(w)
	}
	return &MultiCoalescer{
		tcp:     tcpLane,
		udp:     udpLane,
		pt:      NewPassthrough(w),
		backing: make([]byte, 0, initialSlots*65535),
	}
}
// Reserve carves an sz-byte slot out of the shared arena for the caller
// to fill before Commit. It delegates to the same reserveFromBacking
// helper the single-lane coalescers use instead of duplicating the
// grow-and-slice logic inline, so the growth policy stays defined in one
// place. Slices handed out before a growth remain valid until Flush
// resets the arena.
func (m *MultiCoalescer) Reserve(sz int) []byte {
	return reserveFromBacking(&m.backing, sz)
}
// Commit dispatches pkt to the right lane based on the IP version and L4
// protocol byte. The borrowed-slice contract is identical to the
// single-lane batchers — pkt must remain valid until the next Flush.
//
// On the happy path the IP/TCP-or-UDP parse happens here once and the
// parsed result is handed to the lane via commitParsed, so the lane
// never re-walks the header. When that parse fails we fall back to the
// lane's public Commit, which re-parses before passthrough — that only
// fires for malformed/unsupported packets, so the duplicate parse stays
// off the hot path. Direct callers can still use each lane's public
// Commit.
func (m *MultiCoalescer) Commit(pkt []byte) error {
	// proto stays 0 (HOPOPT — matches neither lane) for anything we
	// can't classify: truncated packets or unknown IP versions.
	proto := byte(0)
	switch {
	case len(pkt) < 20:
		// Too short to even be an IPv4 header.
	case pkt[0]>>4 == 4:
		proto = pkt[9]
	case pkt[0]>>4 == 6 && len(pkt) >= 40:
		proto = pkt[6]
	}
	switch {
	case proto == ipProtoTCP && m.tcp != nil:
		if info, ok := parseTCPBase(pkt); ok {
			return m.tcp.commitParsed(pkt, info)
		}
		// Malformed/unsupported TCP shape (IP options, fragments, ...)
		// — the TCP lane queues it as passthrough.
		return m.tcp.Commit(pkt)
	case proto == ipProtoUDP && m.udp != nil:
		if info, ok := parseUDP(pkt); ok {
			return m.udp.commitParsed(pkt, info)
		}
		return m.udp.Commit(pkt)
	default:
		return m.pt.Commit(pkt)
	}
}
// Flush drains the lanes in a fixed order: TCP, then UDP, then
// passthrough. A failing lane does not stop the later ones — every lane
// is drained and the first error observed is returned, so a single bad
// packet can't strand the rest. The shared arena is recycled afterwards.
func (m *MultiCoalescer) Flush() error {
	var firstErr error
	if m.tcp != nil {
		if err := m.tcp.Flush(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	if m.udp != nil {
		if err := m.udp.Flush(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	if err := m.pt.Flush(); err != nil && firstErr == nil {
		firstErr = err
	}
	m.backing = m.backing[:0]
	return firstErr
}

View File

@@ -0,0 +1,94 @@
package batch
import (
"testing"
)
// TestMultiCoalescerRoutesByProto confirms TCP/UDP/other land in the right
// lane: TCP and UDP get coalesced when their lanes are enabled, anything
// else (ICMP here) falls through to plain Write.
func TestMultiCoalescerRoutesByProto(t *testing.T) {
	w := &fakeTunWriter{gsoEnabled: true}
	m := NewMultiCoalescer(w, true, true)
	tcpPay := make([]byte, 1200)
	udpPay := make([]byte, 1200)
	// Minimal 28-byte IPv4 packet: 20-byte header + 8 payload bytes,
	// proto 1 = ICMP, so it must take the passthrough lane.
	icmp := make([]byte, 28)
	icmp[0] = 0x45
	icmp[2] = 0
	icmp[3] = 28
	icmp[9] = 1
	// Two seq-adjacent TCP segments (seq 1000 then 1000+1200=2200) form
	// one superpacket.
	if err := m.Commit(buildTCPv4(1000, tcpAck, tcpPay)); err != nil {
		t.Fatal(err)
	}
	if err := m.Commit(buildTCPv4(2200, tcpAck, tcpPay)); err != nil {
		t.Fatal(err)
	}
	// Two same-flow UDP datagrams form the second superpacket.
	if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil {
		t.Fatal(err)
	}
	if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil {
		t.Fatal(err)
	}
	if err := m.Commit(icmp); err != nil {
		t.Fatal(err)
	}
	if err := m.Flush(); err != nil {
		t.Fatal(err)
	}
	// 1 TCP super (2 segments) + 1 UDP super (2 segments) = 2 gso writes.
	if len(w.gsoWrites) != 2 {
		t.Fatalf("want 2 gso writes (one TCP + one UDP), got %d", len(w.gsoWrites))
	}
	if len(w.writes) != 1 {
		t.Fatalf("want 1 plain write (ICMP), got %d", len(w.writes))
	}
}

// TestMultiCoalescerDisabledUDPFallsThrough verifies that when the UDP lane
// is disabled (e.g. kernel doesn't support USO), UDP packets still reach
// the kernel via the passthrough lane rather than being lost.
func TestMultiCoalescerDisabledUDPFallsThrough(t *testing.T) {
	w := &fakeTunWriter{gsoEnabled: true}
	m := NewMultiCoalescer(w, true, false) // TSO on, USO off
	if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
		t.Fatal(err)
	}
	if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
		t.Fatal(err)
	}
	if err := m.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(w.gsoWrites) != 0 {
		t.Errorf("UDP must NOT be coalesced when USO disabled, got %d gso writes", len(w.gsoWrites))
	}
	if len(w.writes) != 2 {
		t.Errorf("UDP must pass through as 2 plain writes, got %d", len(w.writes))
	}
}

// TestMultiCoalescerDisabledTCPFallsThrough mirrors the TSO=off case:
// otherwise-coalesceable TCP segments must emerge as individual plain
// writes instead of a superpacket.
func TestMultiCoalescerDisabledTCPFallsThrough(t *testing.T) {
	w := &fakeTunWriter{gsoEnabled: true}
	m := NewMultiCoalescer(w, false, true) // TSO off, USO on
	pay := make([]byte, 1200)
	if err := m.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil {
		t.Fatal(err)
	}
	if err := m.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil {
		t.Fatal(err)
	}
	if err := m.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(w.gsoWrites) != 0 {
		t.Errorf("TCP must NOT be coalesced when TSO disabled, got %d gso writes", len(w.gsoWrites))
	}
	if len(w.writes) != 2 {
		t.Errorf("TCP must pass through as 2 plain writes, got %d", len(w.writes))
	}
}

View File

@@ -0,0 +1,722 @@
package batch
import (
"bytes"
"encoding/binary"
"io"
"log/slog"
"net/netip"
"sort"
"github.com/slackhq/nebula/overlay/tio"
)
// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of
// reaching for golang.org/x/sys/unix — that package doesn't define the
// constant on Windows, which would break cross-compiles even though this
// file runs unchanged on every platform.
const ipProtoTCP = 6

// tcpCoalesceBufSize caps total bytes per superpacket. Mirrors the kernel's
// sk_gso_max_size of ~64KiB; anything beyond this would be rejected anyway.
const tcpCoalesceBufSize = 65535

// tcpCoalesceMaxSegs caps how many segments we'll coalesce into a single
// superpacket. Keeping this well below the kernel's TSO ceiling bounds
// latency.
const tcpCoalesceMaxSegs = 64

// tcpCoalesceHdrCap is the scratch space we copy a seed's IP+TCP header
// into. IPv6 (40) + TCP with full options (60) = 100 bytes.
const tcpCoalesceHdrCap = 100

// coalesceSlot is one entry in the coalescer's ordered event queue. When
// passthrough is true the slot holds a single borrowed packet that must be
// emitted verbatim (non-TCP, non-admissible TCP, or oversize seed). When
// passthrough is false the slot is an in-progress coalesced superpacket:
// hdrBuf is a mutable copy of the seed's IP+TCP header (we patch total
// length and pseudo-header partial at flush), and payIovs are *borrowed*
// slices from the caller's plaintext buffers — no payload is ever copied.
// The caller (listenOut) must keep those buffers alive until Flush.
type coalesceSlot struct {
	passthrough bool
	rawPkt      []byte // borrowed when passthrough
	fk          flowKey
	hdrBuf      [tcpCoalesceHdrCap]byte
	hdrLen      int // bytes of hdrBuf in use: IP header + TCP header
	ipHdrLen    int // IP header portion of hdrLen (20 v4 / 40 v6)
	isV6        bool
	gsoSize     int // seed segment's payload length; kernel split size
	numSeg      int
	totalPay    int // sum of all payIovs lengths
	nextSeq     uint32 // seq expected from the next in-order segment
	// psh closes the chain: set when the last-accepted segment had PSH or
	// was sub-gsoSize. No further appends after that.
	psh     bool
	payIovs [][]byte
}

// TCPCoalescer accumulates adjacent in-flow TCP data segments across
// multiple concurrent flows and emits each flow's run as a single TSO
// superpacket via tio.GSOWriter. All output — coalesced or not — is
// deferred until Flush so arrival order is preserved on the wire. Owns
// no locks; one coalescer per TUN write queue.
type TCPCoalescer struct {
	plainW io.Writer
	gsoW   tio.GSOWriter // nil when the queue doesn't support TSO
	// slots is the ordered event queue. Flush walks it once and emits each
	// entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough).
	slots []*coalesceSlot
	// openSlots maps a flow key to its most recent non-sealed slot, so new
	// segments can extend an in-progress superpacket in O(1). Slots are
	// removed from this map when they close (PSH or short-last-segment),
	// when a non-admissible packet for that flow arrives, or in Flush.
	openSlots map[flowKey]*coalesceSlot
	// lastSlot caches the most recently touched open slot. Steady-state
	// bulk traffic is dominated by a single flow, so comparing the
	// incoming key against the cached slot's own fk lets the hot path
	// skip the map lookup (and the aeshash of a 38-byte key) entirely.
	// Kept in lockstep with openSlots: nil whenever the slot it pointed
	// at is removed/sealed.
	lastSlot *coalesceSlot
	pool     []*coalesceSlot // free list for reuse
	// backing is the Reserve arena; reset at Flush, grown on demand.
	backing []byte
}
// NewTCPCoalescer builds a coalescer over w. When w does not support TCP
// GSO, gsoW stays nil and every Commit takes the passthrough path.
func NewTCPCoalescer(w io.Writer) *TCPCoalescer {
	tc := &TCPCoalescer{
		plainW:    w,
		slots:     make([]*coalesceSlot, 0, initialSlots),
		openSlots: make(map[flowKey]*coalesceSlot, initialSlots),
		pool:      make([]*coalesceSlot, 0, initialSlots),
		backing:   make([]byte, 0, initialSlots*65535),
	}
	if gsoWriter, ok := tio.SupportsGSO(w, tio.GSOProtoTCP); ok {
		tc.gsoW = gsoWriter
	}
	return tc
}
// parsedTCP carries everything a single header walk extracts, so
// admission, slot lookup, and canAppend never re-walk the header.
type parsedTCP struct {
	fk        flowKey
	ipHdrLen  int
	tcpHdrLen int
	hdrLen    int
	payLen    int
	seq       uint32
	flags     byte
}

// parseTCPBase extracts the flow key and IP/TCP offsets for any TCP
// packet, admissible for coalescing or not. It accepts IPv4 without
// options or fragmentation and IPv6 without extension headers, and
// reports ok=false for anything malformed or non-TCP.
func parseTCPBase(pkt []byte) (parsedTCP, bool) {
	var out parsedTCP
	ip, ok := parseIPPrologue(pkt, ipProtoTCP)
	if !ok {
		return out, false
	}
	trimmed := ip.pkt
	out.fk = ip.fk
	out.ipHdrLen = ip.ipHdrLen
	l4 := out.ipHdrLen
	if len(trimmed) < l4+20 {
		return out, false
	}
	// Data offset is in 32-bit words in the high nibble of byte 12.
	dataOff := int(trimmed[l4+12]>>4) * 4
	if dataOff < 20 || dataOff > 60 || len(trimmed) < l4+dataOff {
		return out, false
	}
	out.tcpHdrLen = dataOff
	out.hdrLen = l4 + dataOff
	out.payLen = len(trimmed) - out.hdrLen
	out.seq = binary.BigEndian.Uint32(trimmed[l4+4 : l4+8])
	out.flags = trimmed[l4+13]
	out.fk.sport = binary.BigEndian.Uint16(trimmed[l4 : l4+2])
	out.fk.dport = binary.BigEndian.Uint16(trimmed[l4+2 : l4+4])
	return out, true
}
// TCP flag bits (byte 13 of the TCP header). Only the flags the
// coalescer inspects by name are listed; FIN/SYN/RST/URG/CWR are
// excluded via the negative mask in coalesceable rather than by name.
const (
	tcpFlagPsh = 0x08
	tcpFlagAck = 0x10
	tcpFlagEce = 0x40
)

// coalesceable reports whether this parsed segment may join a
// superpacket: it must carry a payload and its flags must be ACK plus,
// optionally, PSH and/or ECE. CWR stays excluded because it marks a
// one-shot congestion-window-reduced transition the receiver must
// observe at a segment boundary.
func (p parsedTCP) coalesceable() bool {
	const allowed = tcpFlagAck | tcpFlagPsh | tcpFlagEce
	hasAck := p.flags&tcpFlagAck != 0
	onlyAllowed := p.flags&^allowed == 0
	return hasAck && onlyAllowed && p.payLen > 0
}
// Reserve hands out an sz-byte slot from the coalescer's arena; see
// reserveFromBacking for the growth and slice-validity contract.
func (c *TCPCoalescer) Reserve(sz int) []byte {
	return reserveFromBacking(&c.backing, sz)
}
// Commit borrows pkt; the caller must keep it valid until the next
// Flush, whether or not the packet was coalesced — even passthrough
// (non-admissible) packets are only queued here and written at Flush
// time, never synchronously.
func (c *TCPCoalescer) Commit(pkt []byte) error {
	if c.gsoW != nil {
		if info, ok := parseTCPBase(pkt); ok {
			return c.commitParsed(pkt, info)
		}
	}
	// No TSO on this queue, or not a plain TCP packet: queue verbatim.
	c.addPassthrough(pkt)
	return nil
}
// commitParsed is the post-parse half of Commit. The caller must have
// already verified parseTCPBase succeeded (info is a valid TCP parse).
// Used by MultiCoalescer.Commit to avoid re-walking the IP/TCP header
// after the dispatcher has already done so.
func (c *TCPCoalescer) commitParsed(pkt []byte, info parsedTCP) error {
	if c.gsoW == nil {
		c.addPassthrough(pkt)
		return nil
	}
	if !info.coalesceable() {
		// TCP but not admissible (SYN/FIN/RST/URG/CWR or zero-payload).
		// Seal this flow's open slot so later in-flow packets don't extend
		// it and accidentally reorder past this passthrough.
		if last := c.lastSlot; last != nil && last.fk == info.fk {
			c.lastSlot = nil
		}
		delete(c.openSlots, info.fk)
		c.addPassthrough(pkt)
		return nil
	}
	// Single-flow fast path: with only one open flow the cache hits every
	// packet, and len(openSlots)==1 lets us skip the 38-byte fk compare
	// when there are multiple flows in flight (where the hit rate would
	// be ~0 and the compare is pure overhead).
	var open *coalesceSlot
	if last := c.lastSlot; last != nil && len(c.openSlots) == 1 && last.fk == info.fk {
		open = last
	} else {
		open = c.openSlots[info.fk]
	}
	if open != nil {
		if c.canAppend(open, pkt, info) {
			c.appendPayload(open, pkt, info)
			if open.psh {
				// appendPayload just closed the chain (PSH or short
				// segment): drop it from the open set and the cache.
				delete(c.openSlots, info.fk)
				c.lastSlot = nil
			} else {
				c.lastSlot = open
			}
			return nil
		}
		// Can't extend — seal it and fall through to seed a fresh slot.
		delete(c.openSlots, info.fk)
		if c.lastSlot == open {
			c.lastSlot = nil
		}
	}
	c.seed(pkt, info)
	return nil
}
// Flush emits every queued event in (per-flow) seq order. Coalesced slots
// go out via WriteGSO; passthrough slots go out via plainW.Write.
// reorderForFlush first sorts each flow's slots into TCP-seq order within
// passthrough-bounded segments and merges contiguous adjacent slots, so
// any wire-side reorder that crossed an rxOrder batch boundary doesn't
// get amplified into kernel-visible reorder by the slot machinery.
// Returns the first error observed; keeps draining so one bad packet
// doesn't hold up the rest. After Flush returns, borrowed payload slices
// may be recycled.
func (c *TCPCoalescer) Flush() error {
	c.reorderForFlush()
	var first error
	for _, s := range c.slots {
		var err error
		if s.passthrough {
			_, err = c.plainW.Write(s.rawPkt)
		} else {
			err = c.flushSlot(s)
		}
		if err != nil && first == nil {
			first = err
		}
		c.release(s)
	}
	// clear zeroes the slot pointers so released slots aren't pinned by
	// the retained slice capacity; clearing the map drops all entries
	// while keeping its buckets warm for the next batch (cheaper than a
	// fresh make, and tidier than the hand-rolled loops it replaces).
	clear(c.slots)
	c.slots = c.slots[:0]
	clear(c.openSlots)
	c.lastSlot = nil
	c.backing = c.backing[:0]
	return first
}
// addPassthrough queues pkt to be written verbatim at Flush, preserving
// its position relative to every other queued slot.
func (c *TCPCoalescer) addPassthrough(pkt []byte) {
	slot := c.take()
	slot.passthrough = true
	slot.rawPkt = pkt
	c.slots = append(c.slots, slot)
}
// seed starts a fresh coalesced slot from an admissible segment: copies
// the IP+TCP header into the slot's scratch, borrows the payload as the
// first iov, and registers the slot as open unless PSH closes it on
// arrival. The seed's payload length becomes the slot's gsoSize.
func (c *TCPCoalescer) seed(pkt []byte, info parsedTCP) {
	if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize {
		// Pathological shape — can't fit our scratch, emit as-is.
		c.addPassthrough(pkt)
		return
	}
	s := c.take()
	s.passthrough = false
	s.rawPkt = nil
	copy(s.hdrBuf[:], pkt[:info.hdrLen])
	s.hdrLen = info.hdrLen
	s.ipHdrLen = info.ipHdrLen
	s.isV6 = info.fk.isV6
	s.fk = info.fk
	s.gsoSize = info.payLen
	s.numSeg = 1
	s.totalPay = info.payLen
	s.nextSeq = info.seq + uint32(info.payLen)
	s.psh = info.flags&tcpFlagPsh != 0
	s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen])
	c.slots = append(c.slots, s)
	if !s.psh {
		c.openSlots[info.fk] = s
		c.lastSlot = s
	} else if last := c.lastSlot; last != nil && last.fk == info.fk {
		// PSH-on-seed seals the slot immediately. Any prior cached open
		// slot for this flow has just been sealed-and-replaced by this
		// passthrough-shaped seed, so drop the cache too.
		c.lastSlot = nil
	}
}
// canAppend reports whether info's packet can extend slot s: the chain
// must still be open, the header must have the same length and stable
// contents, the seq must be exactly adjacent, and neither the segment
// count nor the byte budget may be exceeded.
func (c *TCPCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool {
	switch {
	case s.psh:
		// Chain already closed by PSH or a short segment.
		return false
	case info.hdrLen != s.hdrLen:
		return false
	case info.seq != s.nextSeq:
		return false
	case s.numSeg >= tcpCoalesceMaxSegs:
		return false
	case info.payLen > s.gsoSize:
		return false
	case s.hdrLen+s.totalPay+info.payLen > tcpCoalesceBufSize:
		return false
	}
	// ECE state must be stable across a burst — receivers expect the
	// flag set on every segment of a CE-echoing window or none.
	if (s.hdrBuf[s.ipHdrLen+13]^info.flags)&tcpFlagEce != 0 {
		return false
	}
	return headersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen)
}
// appendPayload extends slot s with an already-admitted (canAppend-true)
// segment: borrows its payload as another iov, advances the expected
// seq, and closes the chain when PSH or a sub-gsoSize segment arrives.
func (c *TCPCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) {
	s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen])
	s.numSeg++
	s.totalPay += info.payLen
	s.nextSeq = info.seq + uint32(info.payLen)
	if info.flags&tcpFlagPsh != 0 {
		// Propagate PSH into the seed header so kernel TSO sets it on the
		// last segment. Without this the sender's push signal is dropped.
		s.hdrBuf[s.ipHdrLen+13] |= tcpFlagPsh
	}
	// Merge IP-level CE marks into the seed: headersMatch ignores ECN, so
	// this is the one place the signal is preserved.
	mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6)
	if info.payLen < s.gsoSize || info.flags&tcpFlagPsh != 0 {
		s.psh = true
	}
}
// take pops a slot off the free list, or allocates a fresh one when the
// list is empty.
func (c *TCPCoalescer) take() *coalesceSlot {
	n := len(c.pool)
	if n == 0 {
		return &coalesceSlot{}
	}
	s := c.pool[n-1]
	c.pool[n-1] = nil // don't pin the slot through the pool's capacity
	c.pool = c.pool[:n-1]
	return s
}
// release scrubs a slot's per-batch state and returns it to the free
// list. Borrowed payload references are nilled out so the pool never
// pins caller buffers past Flush.
func (c *TCPCoalescer) release(s *coalesceSlot) {
	clear(s.payIovs)
	s.payIovs = s.payIovs[:0]
	s.passthrough = false
	s.rawPkt = nil
	s.numSeg = 0
	s.totalPay = 0
	s.psh = false
	c.pool = append(c.pool, s)
}
// flushSlot patches the header and calls WriteGSO. Does not remove the
// slot from c.slots. Patches: IP total/payload length, IPv4 header
// checksum, and the TCP checksum field, which is seeded with the folded
// (un-inverted) pseudo-header sum as TSO offload expects.
func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error {
	total := s.hdrLen + s.totalPay
	l4Len := total - s.ipHdrLen
	hdr := s.hdrBuf[:s.hdrLen]
	if s.isV6 {
		binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len))
	} else {
		binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
		// Zero the checksum field BEFORE computing the header checksum:
		// ipv4HdrChecksum is evaluated (over the zeroed field) before
		// PutUint16 stores its result back into the same bytes.
		hdr[10] = 0
		hdr[11] = 0
		binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen]))
	}
	var psum uint32
	if s.isV6 {
		psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoTCP, l4Len)
	} else {
		psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoTCP, l4Len)
	}
	// TCP checksum lives 16 bytes into the TCP header.
	tcsum := s.ipHdrLen + 16
	binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum))
	return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoTCP)
}
// headersMatch compares two IP+TCP header prefixes for byte-for-byte
// equality on every field that must be identical across coalesced
// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out, as is the
// 2-bit IP-level ECN field — appendPayload merges CE into the seed.
func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
	if len(a) != len(b) {
		return false
	}
	if !ipHeadersMatch(a, b, isV6) {
		return false
	}
	// TCP portion: [0:4] ports, [8:13] ack number + data offset,
	// [14:16] window, [18:] options (urgent pointer included). Skipped:
	// [4:8] seq, byte 13 flags, [16:18] checksum.
	ta := a[ipHdrLen:]
	tb := b[ipHdrLen:]
	return bytes.Equal(ta[:4], tb[:4]) &&
		bytes.Equal(ta[8:13], tb[8:13]) &&
		bytes.Equal(ta[14:16], tb[14:16]) &&
		bytes.Equal(ta[18:], tb[18:])
}
// reorderForFlush neutralizes wire-side reorder that the rxOrder buffer
// couldn't catch (anything crossing a recvmmsg batch boundary). Without
// this pass a small wire reorder — counter 250 arriving in batch K when
// 200..249 are coming in batch K+1 — would seed an out-of-seq slot first
// and emit it ahead of the lower-seq slot, manifesting at the inner TCP
// receiver as a much larger reorder than the wire actually had.
//
// Two phases:
//  1. Sort each passthrough-bounded segment of c.slots by (flow, seq).
//     Cross-flow ordering inside a segment isn't preserved (it never was
//     and doesn't matter for any single flow's TCP correctness).
//  2. Sweep once and merge adjacent same-flow slots whose ranges are now
//     contiguous AND whose tail is gsoSize-aligned. The tail constraint
//     matters because the kernel TSO splitter chops at gsoSize from the
//     start of the merged payload — a short segment in the middle would
//     desynchronize every later segment.
//
// Passthrough slots act as barriers: the merge check skips them on either
// side, so a SYN/FIN/RST/CWR is never reordered relative to its flow's
// data.
func (c *TCPCoalescer) reorderForFlush() {
	if len(c.slots) <= 1 {
		return
	}
	runStart := 0
	for i := 0; i <= len(c.slots); i++ {
		if i < len(c.slots) && !c.slots[i].passthrough {
			continue
		}
		c.sortRun(c.slots[runStart:i])
		runStart = i + 1
	}
	out := c.slots[:0]
	for _, s := range c.slots {
		if n := len(out); n > 0 {
			prev := out[n-1]
			if !prev.passthrough && !s.passthrough && prev.fk == s.fk {
				// Same-flow neighbors after sort. If they aren't seq-
				// contiguous it's a real gap — packets the wire reordered
				// across batches, or actual loss before nebula. Log it so
				// the operator can quantify how often it happens; the data
				// itself still emits in seq order, kernel TCP handles the
				// gap via its OOO queue.
				if prev.nextSeq != slotSeedSeq(s) {
					gap := int64(slotSeedSeq(s)) - int64(prev.nextSeq)
					slog.Default().Warn("tcp coalesce: cross-slot seq gap",
						"src", flowKeyAddr(s.fk, false),
						"dst", flowKeyAddr(s.fk, true),
						"sport", s.fk.sport,
						"dport", s.fk.dport,
						"prev_seed_seq", slotSeedSeq(prev),
						"prev_next_seq", prev.nextSeq,
						"this_seed_seq", slotSeedSeq(s),
						"gap_bytes", gap,
						"prev_seg_count", prev.numSeg,
						"prev_total_pay", prev.totalPay,
					)
				}
				if canMergeSlots(prev, s) {
					mergeSlots(prev, s)
					c.release(s)
					continue
				}
			}
		}
		out = append(out, s)
	}
	c.slots = out
}
// flowKeyAddr converts one endpoint of fk (src by default, dst when the
// dst flag is set) into a netip.Addr for structured logging. This only
// runs on the cold gap-warning path, so the allocation cost is fine.
func flowKeyAddr(fk flowKey, dst bool) netip.Addr {
	addr := fk.src
	if dst {
		addr = fk.dst
	}
	if !fk.isV6 {
		// v4 flows store the address in the first 4 bytes of the 16-byte field.
		return netip.AddrFrom4([4]byte(addr[:4]))
	}
	return netip.AddrFrom16(addr)
}
// sortRun stable-sorts run by (flowKey, seedSeq), clustering each flow's
// slots together in sequence order so the merge sweep can fold contiguous
// neighbors. Stability keeps equal-key slots in arrival order (defensive:
// duplicate seed seqs would already indicate an upstream problem).
func (c *TCPCoalescer) sortRun(run []*coalesceSlot) {
	if len(run) < 2 {
		return
	}
	less := func(i, j int) bool {
		si, sj := run[i], run[j]
		if d := flowKeyCompare(si.fk, sj.fk); d != 0 {
			return d < 0
		}
		return tcpSeqLess(slotSeedSeq(si), slotSeedSeq(sj))
	}
	sort.SliceStable(run, less)
}
// slotSeedSeq recovers the TCP sequence number of the slot's first (seed)
// segment. nextSeq points just past the last appended byte, so stepping
// back by totalPay lands on the seed; the uint32 subtraction wraps exactly
// as TCP sequence arithmetic requires.
func slotSeedSeq(s *coalesceSlot) uint32 {
	seed := s.nextSeq
	seed -= uint32(s.totalPay)
	return seed
}
// tcpSeqLess reports whether a comes before b under TCP serial-number
// arithmetic (RFC 1323 §2.3): a precedes b when the modular difference
// a-b lands in the upper half of the 32-bit space.
func tcpSeqLess(a, b uint32) bool {
	diff := a - b
	return diff&0x80000000 != 0
}
// flowKeyCompare imposes a deterministic total order on flowKeys. Which
// order is irrelevant — the only requirement is that slots of the same
// flow sort adjacent so the post-sort sweep can merge contiguous pairs.
func flowKeyCompare(a, b flowKey) int {
	if d := bytes.Compare(a.src[:], b.src[:]); d != 0 {
		return d
	}
	if d := bytes.Compare(a.dst[:], b.dst[:]); d != 0 {
		return d
	}
	switch {
	case a.sport != b.sport:
		if a.sport < b.sport {
			return -1
		}
		return 1
	case a.dport != b.dport:
		if a.dport < b.dport {
			return -1
		}
		return 1
	case a.isV6 != b.isV6:
		// v4 sorts before v6, arbitrarily.
		if a.isV6 {
			return 1
		}
		return -1
	}
	return 0
}
// canMergeSlots reports whether s can be folded onto prev to form one
// merged TSO superpacket. Requirements: same flow, same gsoSize, a TCP
// byte range that continues exactly where prev ended, and combined
// segment/byte totals within the kernel TSO limits. prev's final segment
// must be a full gsoSize — the kernel splits the merged skb at gsoSize
// boundaries from the start, so a short segment anywhere but the end
// would desynchronize every segment after it. PSH on prev blocks the
// merge (the sender's push boundary is semantic), and the ECE flag must
// agree across both slots, mirroring the per-window rule canAppend
// applies to in-flow appends.
//
// A slot that was sealed by reorder (canAppend refused on a seq mismatch)
// keeps psh=false, so the PSH restriction never blocks the reorder-fix
// merge — only genuine PSH-terminated slots.
func canMergeSlots(prev, s *coalesceSlot) bool {
	switch {
	case prev.psh,
		prev.fk != s.fk,
		prev.gsoSize != s.gsoSize,
		prev.nextSeq != slotSeedSeq(s),
		prev.numSeg+s.numSeg > tcpCoalesceMaxSegs,
		prev.hdrLen+prev.totalPay+s.totalPay > tcpCoalesceBufSize:
		return false
	}
	// prev must end on a full-size segment (see tail constraint above).
	if tail := prev.payIovs[len(prev.payIovs)-1]; len(tail) != prev.gsoSize {
		return false
	}
	// ECE must be uniform across the merged window.
	prevFlags := prev.hdrBuf[prev.ipHdrLen+13]
	sFlags := s.hdrBuf[s.ipHdrLen+13]
	if (prevFlags^sFlags)&tcpFlagEce != 0 {
		return false
	}
	return headersMatch(prev.hdrBuf[:prev.hdrLen], s.hdrBuf[:s.hdrLen], prev.isV6, prev.ipHdrLen)
}
// mergeSlots folds src into dst in place: the payload iovec lists are
// concatenated, the segment/byte counters and nextSeq advance, and both
// the PSH flag and any IP-level CE mark are OR'd into dst's seed header
// so neither signal is lost. dst keeps its own seed seq, gsoSize, and
// flow key. The caller must release src afterwards — it is no longer
// referenced from c.slots.
func mergeSlots(dst, src *coalesceSlot) {
	dst.payIovs = append(dst.payIovs, src.payIovs...)
	dst.nextSeq = src.nextSeq
	dst.numSeg += src.numSeg
	dst.totalPay += src.totalPay
	if src.psh {
		// Propagate the push boundary into the merged header.
		dst.hdrBuf[dst.ipHdrLen+13] |= tcpFlagPsh
		dst.psh = true
	}
	mergeECNIntoSeed(dst.hdrBuf[:dst.ipHdrLen], src.hdrBuf[:src.ipHdrLen], dst.isV6)
}
// ipv4HdrChecksum computes the RFC 1071 Internet checksum over hdr (the
// checksum field must already be zeroed) and returns the folded, inverted
// 16-bit value ready to store in the header.
func ipv4HdrChecksum(hdr []byte) uint16 {
	var acc uint32
	i := 0
	for ; i+1 < len(hdr); i += 2 {
		// Big-endian 16-bit words, summed with carries accumulating above bit 15.
		acc += uint32(hdr[i])<<8 | uint32(hdr[i+1])
	}
	if i < len(hdr) {
		// Odd trailing byte is treated as the high half of a final word.
		acc += uint32(hdr[i]) << 8
	}
	// Fold carries back in until the sum fits in 16 bits.
	for acc > 0xffff {
		acc = (acc & 0xffff) + (acc >> 16)
	}
	return ^uint16(acc)
}
// pseudoSumIPv4 builds the IPv4 L4 pseudo-header partial sum — the 32-bit
// accumulator before folding — as expected by the virtio NEEDS_CSUM kernel
// path. proto selects the L4 protocol (TCP or UDP; the UDP coalescer
// reuses this helper) and l4Len is the full L4 length including the L4
// header.
func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 {
	sum := uint32(proto) + uint32(l4Len)
	for i := 0; i < 4; i += 2 {
		sum += uint32(binary.BigEndian.Uint16(src[i : i+2]))
		sum += uint32(binary.BigEndian.Uint16(dst[i : i+2]))
	}
	return sum
}
// pseudoSumIPv6 is the IPv6 counterpart of pseudoSumIPv4: the 32-bit
// pre-fold pseudo-header accumulator over both 16-byte addresses, the
// 32-bit L4 length (split into two 16-bit words), and the next-header
// protocol number.
func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 {
	sum := uint32(l4Len>>16) + uint32(l4Len&0xffff) + uint32(proto)
	for i := 0; i+1 < 16; i += 2 {
		sum += uint32(binary.BigEndian.Uint16(src[i:]))
		sum += uint32(binary.BigEndian.Uint16(dst[i:]))
	}
	return sum
}
// foldOnceNoInvert collapses the 32-bit checksum accumulator into 16 bits
// and returns it WITHOUT one's-complementing. This is the seed value the
// virtio NEEDS_CSUM contract wants in the L4 checksum field — the kernel
// adds the payload sum and performs the final inversion.
func foldOnceNoInvert(sum uint32) uint16 {
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return uint16(sum)
}

View File

@@ -0,0 +1,173 @@
package batch
import (
"encoding/binary"
"testing"
"github.com/slackhq/nebula/overlay/tio"
)
// nopTunWriter is a zero-alloc tio.GSOWriter for benchmarks. Discards
// everything but satisfies the interface the coalescer detects.
type nopTunWriter struct{}

// Write pretends the whole packet was written.
func (nopTunWriter) Write(p []byte) (int, error) { return len(p), nil }

// WriteGSO discards the superpacket; the benchmark only measures the
// coalescer's own work, not the TUN write.
func (nopTunWriter) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, _ tio.GSOProto) error {
	return nil
}

// Capabilities advertises both TSO and USO so the coalescers take their
// GSO fast paths instead of falling back to plain writes.
func (nopTunWriter) Capabilities() tio.Capabilities {
	return tio.Capabilities{TSO: true, USO: true}
}
// buildTCPv4BulkFlow returns a slice of N adjacent ACK-only TCP segments
// on a single 5-tuple, each carrying payloadLen bytes. Seq numbers are
// contiguous so every packet is coalesceable onto the previous one.
func buildTCPv4BulkFlow(n, payloadLen int) [][]byte {
pkts := make([][]byte, n)
pay := make([]byte, payloadLen)
seq := uint32(1000)
for i := range n {
pkts[i] = buildTCPv4(seq, tcpAck, pay)
seq += uint32(payloadLen)
}
return pkts
}
// buildTCPv4Interleaved returns nFlows * perFlow packets with per-flow
// seq continuity but round-robin across flows — worst case for any
// "last-slot" cache.
func buildTCPv4Interleaved(nFlows, perFlow, payloadLen int) [][]byte {
pay := make([]byte, payloadLen)
seqs := make([]uint32, nFlows)
for i := range seqs {
seqs[i] = uint32(1000 + i*1000000)
}
pkts := make([][]byte, 0, nFlows*perFlow)
for range perFlow {
for f := range nFlows {
sport := uint16(10000 + f)
pkts = append(pkts, buildTCPv4Ports(sport, 2000, seqs[f], tcpAck, pay))
seqs[f] += uint32(payloadLen)
}
}
return pkts
}
// buildICMPv4 builds a minimal non-TCP (ICMP) IPv4 packet that exercises
// the passthrough branch in Commit.
func buildICMPv4() []byte {
	const totalLen = 28
	pkt := make([]byte, totalLen)
	pkt[0] = 0x45 // v4, 20-byte header
	binary.BigEndian.PutUint16(pkt[2:4], totalLen)
	pkt[9] = 1 // protocol: ICMP
	copy(pkt[12:16], []byte{10, 0, 0, 1})
	copy(pkt[16:20], []byte{10, 0, 0, 2})
	return pkt
}
// runCommitBench feeds pkts through Commit batchSize at a time, flushing
// after each full batch, and reports per-packet cost and allocations.
func runCommitBench(b *testing.B, pkts [][]byte, batchSize int) {
	b.Helper()
	coal := NewTCPCoalescer(nopTunWriter{})
	b.ReportAllocs()
	b.SetBytes(int64(len(pkts[0])))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := coal.Commit(pkts[i%len(pkts)]); err != nil {
			b.Fatal(err)
		}
		if (i+1)%batchSize != 0 {
			continue
		}
		if err := coal.Flush(); err != nil {
			b.Fatal(err)
		}
	}
	// Flush any trailing partial batch so slot state doesn't leak across runs.
	_ = coal.Flush()
}
// BenchmarkCommitSingleFlow is the bulk-TCP steady state: one flow,
// contiguous seq, 1200-byte payloads. Every packet past the seed should
// append onto the open slot. This is the case we most care about.
func BenchmarkCommitSingleFlow(b *testing.B) {
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
runCommitBench(b, pkts, tcpCoalesceMaxSegs)
}
// BenchmarkCommitInterleaved4 round-robins 4 concurrent bulk flows. A
// single-entry fast-path cache misses on every packet here; an N-way
// cache or a map lookup carries the load.
func BenchmarkCommitInterleaved4(b *testing.B) {
	const flows = 4
	pkts := buildTCPv4Interleaved(flows, tcpCoalesceMaxSegs, 1200)
	runCommitBench(b, pkts, len(pkts))
}
// BenchmarkCommitInterleaved16 pushes the flow map harder with 16
// round-robined bulk flows.
func BenchmarkCommitInterleaved16(b *testing.B) {
	const flows = 16
	pkts := buildTCPv4Interleaved(flows, tcpCoalesceMaxSegs, 1200)
	runCommitBench(b, pkts, len(pkts))
}
// BenchmarkCommitPassthrough measures the non-TCP path: parseTCPBase
// rejects the packet early and addPassthrough is the only work left.
func BenchmarkCommitPassthrough(b *testing.B) {
	icmp := buildICMPv4()
	pkts := make([][]byte, 64)
	for i := range pkts {
		pkts[i] = icmp
	}
	runCommitBench(b, pkts, len(pkts))
}
// BenchmarkCommitNonCoalesceableTCP feeds SYN|ACK packets on one flow so
// every packet takes the "TCP but not admissible" branch (map delete plus
// passthrough) — the seal-without-slot cost.
func BenchmarkCommitNonCoalesceableTCP(b *testing.B) {
	empty := make([]byte, 0)
	pkts := make([][]byte, 64)
	for i := range pkts {
		pkts[i] = buildTCPv4(uint32(1000+i), tcpSyn|tcpAck, empty)
	}
	runCommitBench(b, pkts, len(pkts))
}
// runMultiCommitBench is runCommitBench routed through MultiCoalescer:
// the dispatcher parses IP/L4 once and hands the parsed struct to the
// lane, so this bench shows the savings of skipping the lane's re-parse.
func runMultiCommitBench(b *testing.B, pkts [][]byte, batchSize int) {
	b.Helper()
	mc := NewMultiCoalescer(nopTunWriter{}, true, true)
	b.ReportAllocs()
	b.SetBytes(int64(len(pkts[0])))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if err := mc.Commit(pkts[i%len(pkts)]); err != nil {
			b.Fatal(err)
		}
		if (i+1)%batchSize != 0 {
			continue
		}
		if err := mc.Flush(); err != nil {
			b.Fatal(err)
		}
	}
	// Drain whatever remains of a partial batch.
	_ = mc.Flush()
}
// BenchmarkMultiCommitSingleFlow is the multi-lane analogue of
// BenchmarkCommitSingleFlow — same workload but routed through the
// dispatcher. The delta vs the single-lane bench measures dispatcher
// overhead.
func BenchmarkMultiCommitSingleFlow(b *testing.B) {
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
runMultiCommitBench(b, pkts, tcpCoalesceMaxSegs)
}
// BenchmarkMultiCommitInterleaved4 mirrors BenchmarkCommitInterleaved4
// through the dispatcher.
func BenchmarkMultiCommitInterleaved4(b *testing.B) {
	const flows = 4
	pkts := buildTCPv4Interleaved(flows, tcpCoalesceMaxSegs, 1200)
	runMultiCommitBench(b, pkts, len(pkts))
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,58 +4,63 @@ import "net/netip"
const SendBatchCap = 128 const SendBatchCap = 128
// SendBatch accumulates encrypted UDP packets for potential TX offloading. // batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush.
type batchWriter interface {
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
}
// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch.
// One SendBatch is owned by each listenIn goroutine; no locking is needed. // One SendBatch is owned by each listenIn goroutine; no locking is needed.
// The backing storage holds up to batchCap packets of slotCap bytes each; // The backing arena grows on demand: when there isn't room for the next slot
// bufs and dsts are parallel slices of committed slots. // we allocate a fresh backing array. Already-committed slices keep referencing
// the old array and remain valid until Flush drops them.
type SendBatch struct { type SendBatch struct {
out batchWriter
bufs [][]byte bufs [][]byte
dsts []netip.AddrPort dsts []netip.AddrPort
ecns []byte
backing []byte backing []byte
slotCap int
batchCap int
nextSlot int
} }
func NewSendBatch(batchCap, slotCap int) *SendBatch { func NewSendBatch(out batchWriter, batchCap, slotCap int) *SendBatch {
return &SendBatch{ return &SendBatch{
out: out,
bufs: make([][]byte, 0, batchCap), bufs: make([][]byte, 0, batchCap),
dsts: make([]netip.AddrPort, 0, batchCap), dsts: make([]netip.AddrPort, 0, batchCap),
backing: make([]byte, batchCap*slotCap), ecns: make([]byte, 0, batchCap),
slotCap: slotCap, backing: make([]byte, 0, batchCap*slotCap),
batchCap: batchCap,
} }
} }
func (b *SendBatch) Next() []byte { func (b *SendBatch) Reserve(sz int) []byte {
if b.nextSlot >= b.batchCap { if len(b.backing)+sz > cap(b.backing) {
return nil // Grow: allocate a fresh backing. Already-committed slices still
// reference the old array and remain valid until Flush drops them.
newCap := max(cap(b.backing)*2, sz)
b.backing = make([]byte, 0, newCap)
} }
start := b.nextSlot * b.slotCap start := len(b.backing)
return b.backing[start : start : start+b.slotCap] //set len to 0 but cap to slotCap b.backing = b.backing[:start+sz]
return b.backing[start : start+sz : start+sz]
} }
func (b *SendBatch) Commit(n int, dst netip.AddrPort) { func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) {
start := b.nextSlot * b.slotCap b.bufs = append(b.bufs, pkt)
b.bufs = append(b.bufs, b.backing[start:start+n])
b.dsts = append(b.dsts, dst) b.dsts = append(b.dsts, dst)
b.nextSlot++ b.ecns = append(b.ecns, outerECN)
} }
func (b *SendBatch) Reset() { func (b *SendBatch) Flush() error {
var err error
if len(b.bufs) > 0 {
err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns)
}
for i := range b.bufs {
b.bufs[i] = nil
}
b.bufs = b.bufs[:0] b.bufs = b.bufs[:0]
b.dsts = b.dsts[:0] b.dsts = b.dsts[:0]
b.nextSlot = 0 b.ecns = b.ecns[:0]
} b.backing = b.backing[:0]
return err
func (b *SendBatch) Len() int {
return len(b.bufs)
}
func (b *SendBatch) Cap() int {
return b.batchCap
}
func (b *SendBatch) Get() ([][]byte, []netip.AddrPort) {
return b.bufs, b.dsts
} }

View File

@@ -5,65 +5,120 @@ import (
"testing" "testing"
) )
func TestSendBatchBookkeeping(t *testing.T) { type fakeBatchWriter struct {
b := NewSendBatch(4, 32) bufs [][]byte
if b.Len() != 0 || b.Cap() != 4 { addrs []netip.AddrPort
t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap()) ecns []byte
}
func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error {
// Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch
// returns, so tests must capture data before that happens.
w.bufs = make([][]byte, len(bufs))
for i, b := range bufs {
cp := make([]byte, len(b))
copy(cp, b)
w.bufs[i] = cp
} }
w.addrs = append(w.addrs[:0], addrs...)
w.ecns = append(w.ecns[:0], ecns...)
return nil
}
func TestSendBatchReserveCommitFlush(t *testing.T) {
fw := &fakeBatchWriter{}
b := NewSendBatch(fw, 4, 32)
ap := netip.MustParseAddrPort("10.0.0.1:4242") ap := netip.MustParseAddrPort("10.0.0.1:4242")
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
slot := b.Next() slot := b.Reserve(32)
if slot == nil { if cap(slot) != 32 {
t.Fatalf("slot %d: Next returned nil before cap", i) t.Fatalf("slot %d: cap=%d want 32", i, cap(slot))
} }
if cap(slot) != 32 || len(slot) != 0 { pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2))
t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot)) b.Commit(pkt, ap, 0)
} }
// Write a marker byte. if err := b.Flush(); err != nil {
slot = append(slot, byte(i), byte(i+1), byte(i+2)) t.Fatalf("Flush: %v", err)
b.Commit(len(slot), ap)
} }
if b.Next() != nil { if len(fw.bufs) != 4 {
t.Fatalf("Next should return nil when full") t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs))
} }
if b.Len() != 4 { for i, buf := range fw.bufs {
t.Fatalf("Len=%d want 4", b.Len())
}
for i, buf := range b.bufs {
if len(buf) != 3 || buf[0] != byte(i) { if len(buf) != 3 || buf[0] != byte(i) {
t.Errorf("buf %d: %x", i, buf) t.Errorf("buf %d: %x", i, buf)
} }
if b.dsts[i] != ap { if fw.addrs[i] != ap {
t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap) t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap)
} }
} }
// Reset returns empty and Next works again. // Flush again with nothing committed — should be a no-op.
b.Reset() fw.bufs = nil
if b.Len() != 0 { if err := b.Flush(); err != nil {
t.Fatalf("after Reset Len=%d want 0", b.Len()) t.Fatalf("empty Flush: %v", err)
} }
slot := b.Next() if fw.bufs != nil {
if slot == nil || cap(slot) != 32 { t.Fatalf("empty Flush triggered WriteBatch")
t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot)) }
// Reuse after Flush.
slot := b.Reserve(32)
if cap(slot) != 32 {
t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot))
} }
} }
func TestSendBatchSlotsDoNotOverlap(t *testing.T) { func TestSendBatchSlotsDoNotOverlap(t *testing.T) {
b := NewSendBatch(3, 8) fw := &fakeBatchWriter{}
b := NewSendBatch(fw, 3, 8)
ap := netip.MustParseAddrPort("10.0.0.1:80") ap := netip.MustParseAddrPort("10.0.0.1:80")
// Fill three slots, each with its own sentinel byte.
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
s := b.Next() s := b.Reserve(8)
s = append(s, byte(0xA0+i), byte(0xB0+i)) pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i))
b.Commit(len(s), ap) b.Commit(pkt, ap, 0)
}
if err := b.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
} }
for i, buf := range b.bufs { for i, buf := range fw.bufs {
if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) {
t.Errorf("slot %d corrupted: %x", i, buf) t.Errorf("slot %d corrupted: %x", i, buf)
} }
} }
} }
func TestSendBatchGrowPreservesCommitted(t *testing.T) {
fw := &fakeBatchWriter{}
// Tiny initial backing forces a grow on the second Reserve.
b := NewSendBatch(fw, 1, 4)
ap := netip.MustParseAddrPort("10.0.0.1:80")
s1 := b.Reserve(4)
pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44)
b.Commit(pkt1, ap, 0)
s2 := b.Reserve(8) // exceeds remaining cap, triggers grow
pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE)
b.Commit(pkt2, ap, 0)
// pkt1 must still be intact even though backing reallocated.
if pkt1[0] != 0x11 || pkt1[3] != 0x44 {
t.Fatalf("first packet corrupted by grow: %x", pkt1)
}
if err := b.Flush(); err != nil {
t.Fatalf("Flush: %v", err)
}
if len(fw.bufs) != 2 {
t.Fatalf("got %d bufs want 2", len(fw.bufs))
}
if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 {
t.Errorf("first packet on the wire: %x", fw.bufs[0])
}
if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE {
t.Errorf("second packet on the wire: %x", fw.bufs[1])
}
}

View File

@@ -0,0 +1,342 @@
package batch
import (
"encoding/binary"
"io"
"github.com/slackhq/nebula/overlay/tio"
)
// ipProtoUDP is the IANA protocol number for UDP.
const ipProtoUDP = 17

// udpCoalesceBufSize caps total bytes (headers + all payloads) per UDP
// superpacket. Mirrors the kernel's gso_max_size; payloads beyond this
// are emitted as-is.
const udpCoalesceBufSize = 65535

// udpCoalesceMaxSegs caps how many segments we'll coalesce. Kernel UDP-GSO
// accepts up to 64 segments per skb (UDP_MAX_SEGMENTS); stay under that.
const udpCoalesceMaxSegs = 64

// udpCoalesceHdrCap is the scratch space we copy a seed's IP+UDP header
// into. IPv6 (40) + UDP (8) = 48; rounded up to 64 for safety margin.
const udpCoalesceHdrCap = 64
// udpSlot is one entry in the UDPCoalescer's ordered event queue. Same
// passthrough-vs-coalesced shape as the TCP coalescer's slot, but no
// seq/PSH/CWR bookkeeping — UDP segments only need 5-tuple + length
// matching to coalesce.
type udpSlot struct {
	passthrough bool
	rawPkt      []byte // borrowed when passthrough; caller keeps it valid until Flush
	fk          flowKey
	hdrBuf      [udpCoalesceHdrCap]byte // copy of the seed's IP+UDP header
	hdrLen      int                     // bytes of hdrBuf in use: ipHdrLen + 8
	ipHdrLen    int
	isV6        bool
	gsoSize     int // per-segment UDP payload length, fixed by the seed
	numSeg      int
	totalPay    int // sum of all appended payload lengths
	// sealed closes the chain: set when a sub-gsoSize segment is appended
	// (kernel UDP-GSO requires every segment but the last to be exactly
	// gsoSize) or when limits are hit. No further appends after.
	sealed  bool
	payIovs [][]byte // borrowed payload views, one per coalesced segment
}
// UDPCoalescer accumulates adjacent in-flow UDP datagrams across multiple
// concurrent flows and emits each flow's run as a single GSO_UDP_L4
// superpacket via tio.GSOWriter. Falls back to per-packet writes when the
// underlying writer doesn't support USO.
//
// All output — coalesced or not — is deferred until Flush so per-flow
// arrival order is preserved on the wire. Cross-flow order is NOT preserved
// across the TCP/UDP/passthrough split when this coalescer runs alongside
// others — see multi_coalesce.go. Per-flow order is preserved because a
// single 5-tuple only ever lands in one lane and each lane preserves its
// own slot order.
//
// Owns no locks; one coalescer per TUN write queue.
type UDPCoalescer struct {
	plainW io.Writer
	gsoW   tio.GSOWriter // nil when the queue can't accept GSO_UDP_L4
	// slots is the ordered event queue flushed (and cleared) by Flush.
	slots []*udpSlot
	// openSlots maps a flow to its still-extendable slot, if any.
	openSlots map[flowKey]*udpSlot
	// pool recycles udpSlot objects between batches to avoid allocation.
	pool []*udpSlot
	// backing is the arena Reserve carves caller scratch buffers from.
	backing []byte
}
// NewUDPCoalescer wraps w. The caller must only construct this when the
// underlying Queue advertises USO in its Capabilities; otherwise the
// kernel may reject GSO_UDP_L4 writes. If w doesn't implement
// tio.GSOWriter at all (single-packet Queue), the coalescer degrades to
// plain per-packet Writes — the same defensive posture as the TCP
// coalescer.
func NewUDPCoalescer(w io.Writer) *UDPCoalescer {
	c := &UDPCoalescer{
		plainW:    w,
		slots:     make([]*udpSlot, 0, initialSlots),
		openSlots: make(map[flowKey]*udpSlot, initialSlots),
		pool:      make([]*udpSlot, 0, initialSlots),
		backing:   make([]byte, 0, initialSlots*udpCoalesceBufSize),
	}
	gw, ok := tio.SupportsGSO(w, tio.GSOProtoUDP)
	if ok {
		c.gsoW = gw
	}
	return c
}
// parsedUDP holds the fields extracted from a single parse so later steps
// (admission, slot lookup, canAppend) don't re-walk the header.
type parsedUDP struct {
	fk       flowKey
	ipHdrLen int
	hdrLen   int // ipHdrLen + 8 (fixed-size UDP header)
	payLen   int // UDP payload bytes, from the UDP length field minus 8
}
// parseUDP extracts the flow key and IP/UDP offsets from pkt. ok=false
// for non-UDP, malformed, or unsupported header shapes (IPv4 with options
// or fragmentation, IPv6 with extension headers — rejected by
// parseIPPrologue).
func parseUDP(pkt []byte) (parsedUDP, bool) {
	var out parsedUDP
	ip, ok := parseIPPrologue(pkt, ipProtoUDP)
	if !ok {
		return out, false
	}
	pkt = ip.pkt
	out.fk = ip.fk
	out.ipHdrLen = ip.ipHdrLen
	udpStart := out.ipHdrLen
	if len(pkt) < udpStart+8 {
		return out, false
	}
	out.hdrLen = udpStart + 8
	// The UDP length field must cover at least the header and must not
	// claim more bytes than the packet actually carries past the IP header.
	wireLen := int(binary.BigEndian.Uint16(pkt[udpStart+4 : udpStart+6]))
	if wireLen < 8 || wireLen > len(pkt)-udpStart {
		return out, false
	}
	out.payLen = wireLen - 8
	out.fk.sport = binary.BigEndian.Uint16(pkt[udpStart : udpStart+2])
	out.fk.dport = binary.BigEndian.Uint16(pkt[udpStart+2 : udpStart+4])
	return out, true
}
// Reserve hands the caller a sz-byte scratch slice carved from the
// coalescer's backing arena (see reserveFromBacking). The arena is reset
// by Flush, so the slice — and any packet built in it — is only valid
// until then.
func (c *UDPCoalescer) Reserve(sz int) []byte {
	return reserveFromBacking(&c.backing, sz)
}
// Commit borrows pkt; the caller must keep it valid until the next Flush.
// Non-UDP or unparseable packets, and everything when GSO is unavailable,
// go straight to the passthrough queue.
func (c *UDPCoalescer) Commit(pkt []byte) error {
	if c.gsoW == nil {
		c.addPassthrough(pkt)
		return nil
	}
	if info, ok := parseUDP(pkt); ok {
		return c.commitParsed(pkt, info)
	}
	c.addPassthrough(pkt)
	return nil
}
// commitParsed is the post-parse half of Commit; the caller guarantees
// parseUDP already succeeded. MultiCoalescer.Commit calls this directly
// to avoid re-walking the IP/UDP header.
func (c *UDPCoalescer) commitParsed(pkt []byte, info parsedUDP) error {
	if c.gsoW == nil {
		c.addPassthrough(pkt)
		return nil
	}
	if open, exists := c.openSlots[info.fk]; exists {
		if !c.canAppend(open, pkt, info) {
			// The open slot can't take this packet — seal it and fall
			// through to seed a fresh one.
			delete(c.openSlots, info.fk)
		} else {
			c.appendPayload(open, pkt, info)
			if open.sealed {
				delete(c.openSlots, info.fk)
			}
			return nil
		}
	}
	c.seed(pkt, info)
	return nil
}
// Flush emits every queued slot in order — passthrough packets via the
// plain writer, coalesced slots via flushSlot — then resets all per-batch
// state (slot queue, open-slot map, backing arena). The first error seen
// is returned; later slots are still attempted.
func (c *UDPCoalescer) Flush() error {
	var firstErr error
	for i, s := range c.slots {
		var err error
		if s.passthrough {
			_, err = c.plainW.Write(s.rawPkt)
		} else {
			err = c.flushSlot(s)
		}
		if firstErr == nil && err != nil {
			firstErr = err
		}
		c.release(s)
		c.slots[i] = nil // drop the reference so the pooled slot isn't pinned
	}
	c.slots = c.slots[:0]
	clear(c.openSlots)
	c.backing = c.backing[:0]
	return firstErr
}
// addPassthrough queues pkt for an as-is write at Flush, preserving its
// position relative to the coalesced slots around it.
func (c *UDPCoalescer) addPassthrough(pkt []byte) {
	slot := c.take()
	slot.passthrough = true
	slot.rawPkt = pkt
	c.slots = append(c.slots, slot)
}
// seed opens a new coalescing slot with pkt as its first segment. The
// seed's payload length becomes the slot's gsoSize. Packets whose header
// doesn't fit the scratch buffer or that already exceed the superpacket
// budget are queued as passthrough instead.
func (c *UDPCoalescer) seed(pkt []byte, info parsedUDP) {
	tooBig := info.hdrLen+info.payLen > udpCoalesceBufSize
	if tooBig || info.hdrLen > udpCoalesceHdrCap {
		c.addPassthrough(pkt)
		return
	}
	slot := c.take()
	slot.passthrough = false
	slot.rawPkt = nil
	slot.fk = info.fk
	slot.isV6 = info.fk.isV6
	slot.ipHdrLen = info.ipHdrLen
	slot.hdrLen = info.hdrLen
	copy(slot.hdrBuf[:], pkt[:info.hdrLen])
	slot.gsoSize = info.payLen
	slot.numSeg = 1
	slot.totalPay = info.payLen
	slot.sealed = false
	slot.payIovs = append(slot.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen])
	c.slots = append(c.slots, slot)
	c.openSlots[info.fk] = slot
}
// canAppend reports whether info's packet may extend slot s. Kernel
// UDP-GSO requires every segment except possibly the last to be exactly
// gsoSize; the last may be shorter (appendPayload seals the slot in that
// case). Header shape, segment count, and total-size budget must all
// still fit.
func (c *UDPCoalescer) canAppend(s *udpSlot, pkt []byte, info parsedUDP) bool {
	switch {
	case s.sealed,
		info.hdrLen != s.hdrLen,
		s.numSeg >= udpCoalesceMaxSegs,
		info.payLen > s.gsoSize,
		s.hdrLen+s.totalPay+info.payLen > udpCoalesceBufSize:
		return false
	}
	return udpHeadersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen)
}
// appendPayload attaches pkt's payload to slot s, merges any IP-level CE
// mark into the seed header (same trick the TCP coalescer uses), and
// seals the slot when this segment is shorter than gsoSize — only the
// final GSO segment may be short.
func (c *UDPCoalescer) appendPayload(s *udpSlot, pkt []byte, info parsedUDP) {
	pay := pkt[info.hdrLen : info.hdrLen+info.payLen]
	s.payIovs = append(s.payIovs, pay)
	s.numSeg++
	s.totalPay += info.payLen
	mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6)
	if info.payLen < s.gsoSize {
		s.sealed = true
	}
}
// take pops a recycled slot from the pool, or allocates a fresh one when
// the pool is empty.
func (c *UDPCoalescer) take() *udpSlot {
	last := len(c.pool) - 1
	if last < 0 {
		return &udpSlot{}
	}
	s := c.pool[last]
	c.pool[last] = nil // don't let the pool slot pin the object twice
	c.pool = c.pool[:last]
	return s
}
// release scrubs every borrowed reference out of s (so the pooled slot
// doesn't pin packet buffers across batches) and returns it to the pool.
func (c *UDPCoalescer) release(s *udpSlot) {
	s.passthrough = false
	s.rawPkt = nil
	clear(s.payIovs) // nil out each borrowed payload view
	s.payIovs = s.payIovs[:0]
	s.numSeg = 0
	s.totalPay = 0
	s.sealed = false
	c.pool = append(c.pool, s)
}
// flushSlot patches the IP header total length / IPv6 payload length and
// the UDP length to the *total* across all coalesced segments, then seeds
// the UDP checksum field with the pseudo-header partial (single-fold, not
// inverted) per virtio NEEDS_CSUM. The kernel's ip_rcv_core (v4) and
// ip6_rcv_core (v6) trim the skb to those length fields, so per-segment
// values would silently drop everything but the first segment. The kernel
// then walks each segment in __udp_gso_segment, recomputing per-segment
// uh->len / iph->tot_len / IPv6 plen and adjusting the checksum via
// `check = csum16_add(csum16_sub(uh->check, uh->len), newlen)` — meaning
// our seed's uh->check must be consistent with the seed's uh->len, which
// is what passing the total to both pseudoSum and the UDP length field
// guarantees.
func (c *UDPCoalescer) flushSlot(s *udpSlot) error {
	hdr := s.hdrBuf[:s.hdrLen]
	total := s.hdrLen + s.totalPay // full IP+UDP+all_payloads bytes
	l4Len := total - s.ipHdrLen    // total UDP (8 + sum of payloads)
	if s.isV6 {
		// IPv6 payload length excludes the fixed 40-byte IP header.
		binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len))
	} else {
		binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
		// Zero the checksum field before recomputing over the mutated header.
		hdr[10] = 0
		hdr[11] = 0
		binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen]))
	}
	// UDP length field (offset 4 inside the UDP header) = total UDP size.
	binary.BigEndian.PutUint16(hdr[s.ipHdrLen+4:s.ipHdrLen+6], uint16(l4Len))
	var psum uint32
	if s.isV6 {
		// v6 addresses live at [8:24] (src) and [24:40] (dst).
		psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoUDP, l4Len)
	} else {
		// v4 addresses live at [12:16] (src) and [16:20] (dst).
		psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoUDP, l4Len)
	}
	udpCsumOff := s.ipHdrLen + 6
	binary.BigEndian.PutUint16(hdr[udpCsumOff:udpCsumOff+2], foldOnceNoInvert(psum))
	return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoUDP)
}
// udpHeadersMatch reports whether two IP+UDP header prefixes agree on
// every field that must be identical across coalesced segments. Length
// fields and the ECN bits of the IP TOS/TC byte are deliberately ignored:
// appendPayload merges CE into the seed and flushSlot rewrites the
// lengths and checksums.
func udpHeadersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool {
	if len(a) != len(b) {
		return false
	}
	if !ipHeadersMatch(a, b, isV6) {
		return false
	}
	// Only the port pair ([0:4] of the UDP header) must match; the length
	// [4:6] and checksum [6:8] are rewritten at flush time. Comparing the
	// four bytes as one big-endian word is equivalent to byte equality.
	return binary.BigEndian.Uint32(a[ipHdrLen:]) == binary.BigEndian.Uint32(b[ipHdrLen:])
}

View File

@@ -0,0 +1,383 @@
package batch
import (
"encoding/binary"
"testing"
)
// buildUDPv4 builds a minimal IPv4+UDP packet (10.0.0.1 -> 10.0.0.2) with
// the given ports and payload. The UDP checksum is left zero.
func buildUDPv4(sport, dport uint16, payload []byte) []byte {
	const (
		ipHdrLen  = 20
		udpHdrLen = 8
	)
	udpLen := udpHdrLen + len(payload)
	pkt := make([]byte, ipHdrLen+udpLen)
	pkt[0] = 0x45 // v4, ihl=5
	binary.BigEndian.PutUint16(pkt[2:4], uint16(len(pkt)))
	binary.BigEndian.PutUint16(pkt[6:8], 0x4000) // DF, no fragmentation
	pkt[8] = 64 // TTL
	pkt[9] = ipProtoUDP
	copy(pkt[12:16], []byte{10, 0, 0, 1})
	copy(pkt[16:20], []byte{10, 0, 0, 2})
	binary.BigEndian.PutUint16(pkt[20:22], sport)
	binary.BigEndian.PutUint16(pkt[22:24], dport)
	binary.BigEndian.PutUint16(pkt[24:26], uint16(udpLen))
	copy(pkt[28:], payload)
	return pkt
}
// buildUDPv6 builds a minimal IPv6+UDP packet (fe80::1 -> fe80::2) with
// the given ports and payload. The UDP checksum is left zero.
func buildUDPv6(sport, dport uint16, payload []byte) []byte {
	const (
		ipHdrLen  = 40
		udpHdrLen = 8
	)
	udpLen := udpHdrLen + len(payload)
	pkt := make([]byte, ipHdrLen+udpLen)
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen))
	pkt[6] = ipProtoUDP
	pkt[7] = 64 // hop limit
	// src = fe80::1
	pkt[8], pkt[9], pkt[23] = 0xfe, 0x80, 1
	// dst = fe80::2
	pkt[24], pkt[25], pkt[39] = 0xfe, 0x80, 2
	binary.BigEndian.PutUint16(pkt[40:42], sport)
	binary.BigEndian.PutUint16(pkt[42:44], dport)
	binary.BigEndian.PutUint16(pkt[44:46], uint16(udpLen))
	copy(pkt[48:], payload)
	return pkt
}
// TestUDPCoalescerPassthroughWhenGSOUnavailable verifies that a writer
// without GSO support makes the coalescer buffer the packet and emit it
// as a single plain Write at Flush — never a WriteGSO, and never an
// eager write at Commit time.
func TestUDPCoalescerPassthroughWhenGSOUnavailable(t *testing.T) {
	w := &fakeTunWriter{gsoEnabled: false}
	c := NewUDPCoalescer(w)
	pkt := buildUDPv4(1000, 53, make([]byte, 100))
	if err := c.Commit(pkt); err != nil {
		t.Fatal(err)
	}
	// Commit must only queue — no output until Flush.
	if len(w.writes) != 0 || len(w.gsoWrites) != 0 {
		t.Fatalf("no Add-time writes: writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
	}
	if err := c.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(w.writes) != 1 || len(w.gsoWrites) != 0 {
		t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites))
	}
}
// Non-UDP traffic must never be coalesced, even with GSO available.
func TestUDPCoalescerNonUDPPassthrough(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	// Hand-rolled minimal ICMP packet (proto 1): no L4 the coalescer handles.
	icmp := make([]byte, 28)
	icmp[0] = 0x45
	binary.BigEndian.PutUint16(icmp[2:4], 28)
	icmp[9] = 1
	copy(icmp[12:16], []byte{10, 0, 0, 1})
	copy(icmp[16:20], []byte{10, 0, 0, 2})
	if err := coal.Commit(icmp); err != nil {
		t.Fatal(err)
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.writes) != 1 || len(writer.gsoWrites) != 0 {
		t.Fatalf("ICMP must pass through unchanged: writes=%d gso=%d", len(writer.writes), len(writer.gsoWrites))
	}
}
// A single seeded datagram flushed alone still travels the GSO path.
func TestUDPCoalescerSeedThenFlushAlone(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	if err := coal.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil {
		t.Fatal(err)
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	// Single-segment flush goes through WriteGSO; the writer infers GSO_NONE
	// from len(pays)==1 and the kernel fills in the UDP csum (NEEDS_CSUM).
	if len(writer.gsoWrites) != 1 || len(writer.writes) != 0 {
		t.Fatalf("single-seg flush: writes=%d gso=%d", len(writer.writes), len(writer.gsoWrites))
	}
}
// Three equal-sized datagrams of one flow collapse into one superpacket.
func TestUDPCoalescerCoalescesEqualSized(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 1200)
	for seg := 0; seg < 3; seg++ {
		if err := coal.Commit(buildUDPv4(1000, 53, payload)); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 1 {
		t.Fatalf("want 1 gso write, got %d (plain=%d)", len(writer.gsoWrites), len(writer.writes))
	}
	super := writer.gsoWrites[0]
	if super.gsoSize != 1200 {
		t.Errorf("gsoSize=%d want 1200", super.gsoSize)
	}
	if len(super.pays) != 3 {
		t.Errorf("pay count=%d want 3", len(super.pays))
	}
	if super.csumStart != 20 {
		t.Errorf("csumStart=%d want 20", super.csumStart)
	}
	// IP totalLen and UDP length must be the TOTAL across all segments —
	// the kernel's ip_rcv_core trims skbs to iph->tot_len, so a per-segment
	// value would silently drop everything but the first segment. Total =
	// IP(20) + UDP(8) + 3*1200 = 3628.
	if totalLen := binary.BigEndian.Uint16(super.hdr[2:4]); totalLen != 3628 {
		t.Errorf("ipv4 total_len=%d want 3628 (must be total across segments)", totalLen)
	}
	if udpLen := binary.BigEndian.Uint16(super.hdr[20+4 : 20+6]); udpLen != 8+3*1200 {
		t.Errorf("udp len=%d want %d", udpLen, 8+3*1200)
	}
}
// Last segment may be shorter, sealing the chain.
func TestUDPCoalescerShortLastSegmentSeals(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	full := make([]byte, 1200)
	tail := make([]byte, 600)
	// Two full segments, a short third (seals), then a fourth same-sized
	// packet that must NOT join the sealed chain.
	for _, payload := range [][]byte{full, full, tail, full} {
		if err := coal.Commit(buildUDPv4(1000, 53, payload)); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 2 {
		t.Fatalf("want 2 gso writes (sealed + new seed), got %d", len(writer.gsoWrites))
	}
	if len(writer.gsoWrites[0].pays) != 3 {
		t.Errorf("first super: want 3 pays, got %d", len(writer.gsoWrites[0].pays))
	}
	if len(writer.gsoWrites[1].pays) != 1 {
		t.Errorf("second super: want 1 pay (re-seed), got %d", len(writer.gsoWrites[1].pays))
	}
}
// A larger-than-gsoSize packet cannot extend the slot — it reseeds.
func TestUDPCoalescerLargerThanSeedReseeds(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	for _, size := range []int{800, 1200} {
		if err := coal.Commit(buildUDPv4(1000, 53, make([]byte, size))); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 2 {
		t.Fatalf("want 2 separate seeds, got %d", len(writer.gsoWrites))
	}
}
// Different 5-tuples must not coalesce.
func TestUDPCoalescerDifferentFlowsKeepSeparate(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 800)
	// Interleave two source ports: A B A B.
	for _, sport := range []uint16{1000, 2000, 1000, 2000} {
		if err := coal.Commit(buildUDPv4(sport, 53, payload)); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	// Two flows × 2 datagrams each = 2 superpackets of 2 segments.
	if len(writer.gsoWrites) != 2 {
		t.Fatalf("want 2 gso writes (one per flow), got %d", len(writer.gsoWrites))
	}
	for i, super := range writer.gsoWrites {
		if len(super.pays) != 2 {
			t.Errorf("super %d: want 2 pays, got %d", i, len(super.pays))
		}
	}
}
// Caps at udpCoalesceMaxSegs.
func TestUDPCoalescerCapsAtMaxSegs(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 100)
	for n := 0; n < udpCoalesceMaxSegs+5; n++ {
		if err := coal.Commit(buildUDPv4(1000, 53, payload)); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	// First superpacket holds udpCoalesceMaxSegs segments; the spillover
	// reseeds a new one.
	if len(writer.gsoWrites) != 2 {
		t.Fatalf("want 2 gso writes (cap then reseed), got %d", len(writer.gsoWrites))
	}
	if len(writer.gsoWrites[0].pays) != udpCoalesceMaxSegs {
		t.Errorf("first super: pays=%d want %d", len(writer.gsoWrites[0].pays), udpCoalesceMaxSegs)
	}
	if len(writer.gsoWrites[1].pays) != 5 {
		t.Errorf("second super: pays=%d want 5", len(writer.gsoWrites[1].pays))
	}
}
// CE marks on appended segments must be merged into the seed's IP TOS.
func TestUDPCoalescerMergesCEMark(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 800)
	marked := buildUDPv4(1000, 53, payload)
	marked[1] = 0x03 // CE
	// Seed has ECN=00; the middle datagram carries CE; the tail is clean.
	for _, p := range [][]byte{buildUDPv4(1000, 53, payload), marked, buildUDPv4(1000, 53, payload)} {
		if err := coal.Commit(p); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 1 {
		t.Fatalf("want 1 merged gso write, got %d (plain=%d)", len(writer.gsoWrites), len(writer.writes))
	}
	if writer.gsoWrites[0].hdr[1]&0x03 != 0x03 {
		t.Errorf("CE not merged into seed (tos=%#x)", writer.gsoWrites[0].hdr[1])
	}
}
// IPv6 path: same flow, equal-sized → coalesced.
func TestUDPCoalescerIPv6Coalesces(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 1200)
	for seg := 0; seg < 3; seg++ {
		if err := coal.Commit(buildUDPv6(1000, 53, payload)); err != nil {
			t.Fatal(err)
		}
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 1 {
		t.Fatalf("want 1 gso write, got %d", len(writer.gsoWrites))
	}
	super := writer.gsoWrites[0]
	if !super.isV6 {
		t.Errorf("expected v6 write")
	}
	if super.csumStart != 40 {
		t.Errorf("csumStart=%d want 40", super.csumStart)
	}
	// IPv6 payload_len and UDP length must be TOTAL — kernel's
	// ip6_rcv_core trims to payload_len + ipv6 hdr size. Total UDP = 8 +
	// 3*1200 = 3608.
	if plen := binary.BigEndian.Uint16(super.hdr[4:6]); plen != 8+3*1200 {
		t.Errorf("ipv6 payload_len=%d want %d (must be total)", plen, 8+3*1200)
	}
	if udpLen := binary.BigEndian.Uint16(super.hdr[40+4 : 40+6]); udpLen != 8+3*1200 {
		t.Errorf("udp len=%d want %d", udpLen, 8+3*1200)
	}
}
// DSCP differences must reseed (headers don't match outside ECN).
func TestUDPCoalescerDSCPMismatchReseeds(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	payload := make([]byte, 800)
	first := buildUDPv4(1000, 53, payload)
	second := buildUDPv4(1000, 53, payload)
	second[1] = 0xb8 // EF DSCP, ECN=0
	if err := coal.Commit(first); err != nil {
		t.Fatal(err)
	}
	if err := coal.Commit(second); err != nil {
		t.Fatal(err)
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.gsoWrites) != 2 {
		t.Fatalf("want 2 separate seeds (different DSCP), got %d", len(writer.gsoWrites))
	}
}
// Fragmented IPv4 must not be coalesced.
func TestUDPCoalescerFragmentedIPv4PassesThrough(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	frag := buildUDPv4(1000, 53, make([]byte, 200))
	binary.BigEndian.PutUint16(frag[6:8], 0x2000) // MF=1
	if err := coal.Commit(frag); err != nil {
		t.Fatal(err)
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.writes) != 1 || len(writer.gsoWrites) != 0 {
		t.Fatalf("frag must pass through plain, got writes=%d gso=%d", len(writer.writes), len(writer.gsoWrites))
	}
}
// IPv4 with options is not admissible (we require IHL=5).
func TestUDPCoalescerIPv4WithOptionsPassesThrough(t *testing.T) {
	writer := &fakeTunWriter{gsoEnabled: true}
	coal := NewUDPCoalescer(writer)
	withOpts := buildUDPv4(1000, 53, make([]byte, 200))
	withOpts[0] = 0x46 // IHL = 6 (24-byte IPv4 header — has options)
	if err := coal.Commit(withOpts); err != nil {
		t.Fatal(err)
	}
	if err := coal.Flush(); err != nil {
		t.Fatal(err)
	}
	if len(writer.writes) != 1 || len(writer.gsoWrites) != 0 {
		t.Fatalf("ipv4-with-options must pass through plain, got writes=%d gso=%d", len(writer.writes), len(writer.gsoWrites))
	}
}

View File

@@ -18,7 +18,7 @@ type Device interface {
Networks() []netip.Prefix Networks() []netip.Prefix
Name() string Name() string
RoutesFor(netip.Addr) routing.Gateways RoutesFor(netip.Addr) routing.Gateways
SupportsMultiqueue() bool //todo remove? SupportsMultiqueue() bool
NewMultiQueueReader() error NewMultiQueueReader() error
Readers() []tio.Queue Readers() []tio.Queue
} }

View File

@@ -31,7 +31,7 @@ func (NoopTun) Name() string {
return "noop" return "noop"
} }
func (NoopTun) Read() ([][]byte, error) { func (NoopTun) Read() ([]tio.Packet, error) {
return nil, nil return nil, nil
} }

View File

@@ -0,0 +1,79 @@
package tio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
)
// offloadQueueSet owns a group of Offload queues plus the shared shutdown
// eventfd that Close uses to wake readers parked in poll.
type offloadQueueSet struct {
	// pq holds the concrete queues so Close can iterate them directly.
	pq []*Offload
	// pqi is exactly the same as pq, but stored as the interface type
	// so Queues() can return it without a per-call conversion loop.
	pqi []Queue
	// shutdownFd is a shared eventfd; writing 1 to it (wakeForShutdown)
	// makes poll on it readable in every queue, unblocking readers.
	shutdownFd int
	// usoEnabled is true when newTun successfully negotiated TUN_F_USO4|6
	// with the kernel. Queues created by Add inherit this and surface it
	// via Offload.USOSupported so coalescers can gate USO emission.
	usoEnabled bool
}
// NewOffloadQueueSet creates a QueueSet that uses virtio_net_hdr to do
// TSO segmentation in userspace. usoEnabled tells downstream queues whether
// the kernel agreed to deliver/accept GSO_UDP_L4 superpackets — coalescers
// should fall back to per-packet writes when this is false.
func NewOffloadQueueSet(usoEnabled bool) (QueueSet, error) {
	efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("failed to create eventfd: %w", err)
	}
	return &offloadQueueSet{
		pq:         []*Offload{},
		pqi:        []Queue{},
		shutdownFd: efd,
		usoEnabled: usoEnabled,
	}, nil
}
// Queues returns every queue added so far, as the Queue interface type.
func (c *offloadQueueSet) Queues() []Queue {
	return c.pqi
}
// Add wraps fd as an offload-capable queue and registers it with the set.
// The queue inherits the set's shutdown eventfd and USO capability flag.
func (c *offloadQueueSet) Add(fd int) error {
	q, err := newOffload(fd, c.shutdownFd, c.usoEnabled)
	if err != nil {
		return err
	}
	// Keep the concrete and interface-typed views in lockstep.
	c.pq = append(c.pq, q)
	c.pqi = append(c.pqi, q)
	return nil
}
// wakeForShutdown increments the shared eventfd, making it readable and
// waking every reader parked in poll on it.
func (c *offloadQueueSet) wakeForShutdown() error {
	var one [8]byte
	binary.NativeEndian.PutUint64(one[:], 1)
	_, err := unix.Write(c.shutdownFd, one[:])
	return err
}
// Close wakes all readers blocked in poll via the shutdown eventfd, then
// closes each queue, returning every error encountered joined together.
func (c *offloadQueueSet) Close() error {
	errs := []error{}
	// Signal all readers blocked in poll to wake up and exit
	if err := c.wakeForShutdown(); err != nil {
		errs = append(errs, err)
	}
	for _, x := range c.pq {
		if err := x.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	// NOTE(review): shutdownFd itself is never closed, so each QueueSet
	// leaks one eventfd descriptor. Closing it here looks safe once every
	// queue's Close has returned, but confirm no reader goroutine can still
	// be parked in poll on it before adding the unix.Close call.
	return errors.Join(errs...)
}

View File

@@ -8,20 +8,20 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type pollContainer struct { type pollQueueSet struct {
pq []*Poll pq []*Poll
// pqi is exactly the same as pq, but stored as the interface type // pqi is exactly the same as pq, but stored as the interface type
pqi []Queue pqi []Queue
shutdownFd int shutdownFd int
} }
func NewPollContainer() (Container, error) { func NewPollQueueSet() (QueueSet, error) {
shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create eventfd: %w", err) return nil, fmt.Errorf("failed to create eventfd: %w", err)
} }
out := &pollContainer{ out := &pollQueueSet{
pq: []*Poll{}, pq: []*Poll{},
pqi: []Queue{}, pqi: []Queue{},
shutdownFd: shutdownFd, shutdownFd: shutdownFd,
@@ -30,11 +30,11 @@ func NewPollContainer() (Container, error) {
return out, nil return out, nil
} }
func (c *pollContainer) Queues() []Queue { func (c *pollQueueSet) Queues() []Queue {
return c.pqi return c.pqi
} }
func (c *pollContainer) Add(fd int) error { func (c *pollQueueSet) Add(fd int) error {
x, err := newPoll(fd, c.shutdownFd) x, err := newPoll(fd, c.shutdownFd)
if err != nil { if err != nil {
return err return err
@@ -45,14 +45,14 @@ func (c *pollContainer) Add(fd int) error {
return nil return nil
} }
func (c *pollContainer) wakeForShutdown() error { func (c *pollQueueSet) wakeForShutdown() error {
var buf [8]byte var buf [8]byte
binary.NativeEndian.PutUint64(buf[:], 1) binary.NativeEndian.PutUint64(buf[:], 1)
_, err := unix.Write(int(c.shutdownFd), buf[:]) _, err := unix.Write(int(c.shutdownFd), buf[:])
return err return err
} }
func (c *pollContainer) Close() error { func (c *pollQueueSet) Close() error {
errs := []error{} errs := []error{}
if err := c.wakeForShutdown(); err != nil { if err := c.wakeForShutdown(); err != nil {

View File

@@ -0,0 +1,65 @@
//go:build linux && !android && !e2e_testing
package tio
import "testing"
// fakeBatch stands in for batch.TxBatcher inside the bench — same shape
// of pointer-capturing closure that sendInsideMessage builds.
type fakeBatch struct{ buf [65536]byte }

// Reserve hands back a prefix of the fixed scratch buffer; no allocation.
func (b *fakeBatch) Reserve(sz int) []byte { return b.buf[:sz] }

// Commit is a no-op — the bench only measures allocation behavior.
func (b *fakeBatch) Commit([]byte) {}
// fakeHostInfo mimics the pointer-bearing host state the production closure
// captures: a fixed remote index plus a counter the closure mutates.
type fakeHostInfo struct {
	remoteIndexId uint32
	counter       uint64
}

// fakeIface mimics the interface receiver captured by the closure; it
// carries a small field read per segment and a pointer to fakeHostInfo.
type fakeIface struct {
	rebindCount uint8
	hi          *fakeHostInfo
}
// BenchmarkSegmentSuperpacketAllocsTSO measures allocation per
// SegmentSuperpacket call when a closure captures pointer-bearing
// receivers — the realistic shape of sendInsideMessage's closure.
func BenchmarkSegmentSuperpacketAllocsTSO(b *testing.B) {
	const (
		mss    = 1400
		numSeg = 32
	)
	super := buildTSOv6(mss*numSeg, mss)
	p := Packet{
		Bytes: super,
		GSO: GSOInfo{
			Size:      mss,
			HdrLen:    60, // 40 (IPv6) + 20 (TCP)
			CsumStart: 40,
			Proto:     GSOProtoTCP,
		},
	}
	hi := &fakeHostInfo{remoteIndexId: 0xdeadbeef}
	f := &fakeIface{rebindCount: 7, hi: hi}
	fb := &fakeBatch{}
	// SegmentSuperpacket consumes its input destructively; refresh from a
	// pristine master copy each iteration (matches the production pattern
	// where every TUN read hands the segmenter a fresh kernel-supplied
	// buffer).
	master := append([]byte(nil), super...)
	work := make([]byte, len(super))
	p.Bytes = work
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		copy(work, master)
		if err := SegmentSuperpacket(p, func(seg []byte) error {
			out := fb.Reserve(16 + len(seg) + 16)
			out[0] = byte(f.rebindCount)
			out[1] = byte(hi.counter)
			hi.counter++
			fb.Commit(out)
			return nil
		}); err != nil {
			b.Fatalf("SegmentSuperpacket: %v", err)
		}
	}
}

View File

@@ -0,0 +1,18 @@
//go:build !linux || android || e2e_testing
package tio
import "fmt"
// SegmentSuperpacket invokes fn once per segment of pkt. On non-Linux
// builds (and Android/e2e_testing) this package does not provide a Queue
// implementation, so any caller that does construct a Packet here can only
// be operating on non-superpacket bytes and the stub forwards them
// directly. A non-zero GSO field is a programming error from the caller
// and returns an explicit error rather than silently misbehaving.
//
// The signature must stay in lockstep with the Linux implementation: the
// Linux-only benchmark calls SegmentSuperpacket(pkt, fn) with two
// arguments, so the previous extra (and unused) scratch parameter here
// would have broken compilation of any shared call site on non-Linux
// targets.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
	if pkt.GSO.IsSuperpacket() {
		return fmt.Errorf("tio: GSO superpacket on platform without segmentation support")
	}
	return fn(pkt.Bytes)
}

View File

@@ -1,56 +1,170 @@
package tio package tio
import "io" import (
"io"
)
// defaultBatchBufSize is the per-Queue scratch size for Read on backends // QueueSet holds one or many Queue objects and helps close them in an orderly way.
// that don't do TSO segmentation. 65535 covers any single IP packet. type QueueSet interface {
const defaultBatchBufSize = 65535
// Container holds one or many Queue objects and helps close them in an orderly way
type Container interface {
io.Closer io.Closer
Queues() []Queue Queues() []Queue
// Add takes a tun fd, adds it to the container, and prepares it for use as a Queue // Add takes a tun fd, adds it to the set, and prepares it for use as a Queue.
Add(fd int) error Add(fd int) error
} }
// Capabilities advertises which kernel offload features a Queue
// successfully negotiated. Callers consult this to decide which coalescers
// to wire onto the write path — a Queue without TSO can't usefully accept a
// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer.
type Capabilities struct {
// TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed
// to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe.
TSO bool
// USO means the kernel additionally agreed to TUN_F_USO4|USO6, so
// WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2.
USO bool
}
// Queue is a readable/writable Poll queue. One Queue is driven by a single // Queue is a readable/writable Poll queue. One Queue is driven by a single
// read goroutine plus concurrent writers (see Write / WriteReject below). // read goroutine plus a single writer (see Write below).
type Queue interface { type Queue interface {
io.Closer io.Closer
// Read returns one or more packets. The returned slices are borrowed // Read returns one or more packets. The returned Packet.Bytes slices
// from the Queue's internal buffer and are only valid until the next // are borrowed from the Queue's internal buffer and are only valid
// Read or Close on this Queue - callers must encrypt or copy each // until the next Read or Close on this Queue - callers must encrypt
// slice before the next call. Not safe for concurrent Reads. // or copy each slice before the next call. A Packet may carry a
Read() ([][]byte, error) // GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is
// true the caller must segment Bytes before treating it as a single
// IP datagram. Not safe for concurrent Reads.
Read() ([]Packet, error)
// Write emits a single packet on the plaintext (outside→inside) // Write emits a single packet on the plaintext (outside→inside)
// delivery path. Not safe for concurrent Writes. // delivery path. Not safe for concurrent Writes.
Write(p []byte) (int, error) Write(p []byte) (int, error)
} }
// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket // Packet is the unit Queue.Read returns. Bytes points into the queue's
// internal buffer and is only valid until the next Read or Close on the
// queue that produced it. GSO is the zero value for an already-segmented
// IP datagram; when non-zero it describes a kernel-supplied TSO/USO
// superpacket the caller must segment before consuming.
type Packet struct {
Bytes []byte
GSO GSOInfo
}
// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes.
// The zero value means "not a superpacket" — Bytes is one regular IP
// datagram and no segmentation is required.
type GSOInfo struct {
// Size is the GSO segment size: max payload bytes per segment
// (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means
// not a superpacket.
Size uint16
// HdrLen is the total L3+L4 header length within Bytes (already
// corrected via correctHdrLen, so safe to slice on).
HdrLen uint16
// CsumStart is the L4 header offset inside Bytes (== L3 header
// length).
CsumStart uint16
// Proto picks the L4 protocol (TCP or UDP) so the segmenter knows
// which checksum/header layout to apply.
Proto GSOProto
}
// IsSuperpacket reports whether g describes a multi-segment GSO/USO
// superpacket that needs segmentation before its bytes can be encrypted
// and sent on the wire.
func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 }
// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes,
// safe to retain past the next Read or Close on the originating Queue.
// GSO metadata is copied verbatim. Use this only when a caller genuinely
// needs to outlive the borrowed-slice contract — the hot path reads should
// continue to consume the borrow synchronously to avoid the allocation.
func (p Packet) Clone() Packet {
if p.Bytes == nil {
return p
}
cp := make([]byte, len(p.Bytes))
copy(cp, p.Bytes)
return Packet{Bytes: cp, GSO: p.GSO}
}
// CapsProvider is an optional interface implemented by Queues that
// successfully negotiated kernel offload features at open time. Callers
// pick a write-path coalescer based on the result. Queues that don't
// implement it are treated as having no offload capability — callers must
// fall back to plain per-packet writes.
type CapsProvider interface {
Capabilities() Capabilities
}
// QueueCapabilities returns q's negotiated offload capabilities, or the
// zero value when q does not advertise any.
func QueueCapabilities(q Queue) Capabilities {
if cp, ok := q.(CapsProvider); ok {
return cp.Capabilities()
}
return Capabilities{}
}
// GSOProto selects the L4 protocol for a GSO superpacket. Determines which
// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset
// inside the transport header virtio NEEDS_CSUM expects.
type GSOProto uint8
const (
GSOProtoTCP GSOProto = iota
GSOProtoUDP
)
// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket
// assembled from a header prefix plus one or more borrowed payload // assembled from a header prefix plus one or more borrowed payload
// fragments, in a single vectored write (writev with a leading // fragments, in a single vectored write (writev with a leading
// virtio_net_hdr). This lets the coalescer avoid copying payload bytes // virtio_net_hdr). This lets the coalescer avoid copying payload bytes
// between the caller's decrypt buffer and the TUN. Backends without GSO // between the caller's decrypt buffer and the TUN. Backends without GSO
// support return false from GSOSupported and coalescing is skipped. // support do not implement this interface and coalescing is skipped.
// //
// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will // hdr contains the IPv4/IPv6 header prefix (mutable - callers will have
// have filled in total length and pseudo-header partial). pays are // filled in total length and IP csum). transportHdr is the TCP or UDP
// non-overlapping payload fragments whose concatenation is the full // header (mutable - the L4 checksum field must hold the pseudo-header
// superpacket payload; they are read-only from the writer's perspective // partial, single-fold not inverted, per virtio NEEDS_CSUM semantics).
// and must remain valid until the call returns. gsoSize is the MSS: // pays are non-overlapping payload fragments whose concatenation is the
// every segment except possibly the last is exactly that many bytes. // full superpacket payload; they are read-only from the writer's
// csumStart is the byte offset where the TCP header begins within hdr. // perspective and must remain valid until the call returns. Every segment
// in pays except possibly the last is exactly the same size. proto picks
// the L4 protocol so the writer knows which GSOType / CsumOffset to set.
// //
// # TODO fold into Queue // Callers should also consult CapsProvider (via SupportsGSO or
// // QueueCapabilities) for the per-protocol negotiated capability; an
// hdr's TCP checksum field must already hold the pseudo-header partial // implementation of GSOWriter is necessary but not sufficient since USO
// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. // may not have been negotiated even when TSO was.
type GSOWriter interface { type GSOWriter interface {
WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error
GSOSupported() bool }
// SupportsGSO reports whether w implements GSOWriter and the underlying
// queue advertises the negotiated capability for `want`. A writer that
// implements GSOWriter but not CapsProvider is treated as permissive
// (used by tests and fakes that don't negotiate).
func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) {
gw, ok := w.(GSOWriter)
if !ok {
return nil, false
}
cp, ok := w.(CapsProvider)
if !ok {
return gw, true
}
caps := cp.Capabilities()
switch want {
case GSOProtoTCP:
return gw, caps.TSO
case GSOProtoUDP:
return gw, caps.USO
}
return gw, false
} }

View File

@@ -0,0 +1,461 @@
package tio
import (
"fmt"
"io"
"log/slog"
"os"
"sync"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// tunRxBufSize is the per-Read worst-case footprint inside rxBuf: one
// kernel-supplied packet body, which is at most ~64 KiB (tunReadBufSize).
// Segmentation happens at encrypt time on a per-routine MTU-sized scratch
// (see SegmentSuperpacket), so rxBuf only holds raw kernel-supplied bytes.
// We round up to give comfortable margin for the drain headroom check
// below.
//
// NOTE(review): readPacket always offers tunReadBufSize bytes of rxBuf to
// the kernel, so this constant must stay >= tunReadBufSize or a max-size
// read could overrun the drain-gated headroom — confirm against the
// declaration of tunReadBufSize (not visible in this file section).
const tunRxBufSize = 64 * 1024

// tunRxBufCap is the total size we allocate for the per-reader rx
// buffer. With reads landing directly in rxBuf, each drain iteration
// consumes up to tunRxBufSize of headroom for the kernel-supplied bytes.
// Sized to two such iterations so the initial blocking read plus one
// drain read both fit without partial-drop.
const tunRxBufCap = tunRxBufSize * 2

// tunDrainCap caps how many packets a single Read will accumulate via
// the post-wake drain loop. Sized to soak up a burst of small ACKs while
// bounding how much work a single caller holds before handing off.
const tunDrainCap = 64

// gsoMaxIovs caps the iovec budget WriteGSO assembles per call: 3 fixed
// entries (virtio_net_hdr, IP hdr, transport hdr) plus up to gsoMaxIovs-3
// payload fragments. Sized comfortably above the typical kernel GSO
// segment cap (Linux UDP_GRO is 64) so realistic coalesced bursts never
// touch the limit. iovecs are tiny (16 bytes), so the entire scratch is
// 4 KiB — fine to keep resident on every queue. WriteGSO returns an error
// rather than reallocating when a caller exceeds this budget.
const gsoMaxIovs = 256

// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN
// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb
// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum
// verification. All packets that reach the plain Write paths already carry
// a valid L4 checksum (either supplied by a remote peer whose ciphertext we
// AEAD-authenticated, produced by segmentTCPYield/segmentUDPYield during
// superpacket segmentation, or built locally by CreateRejectPacket), so
// trusting them is safe.
var validVnetHdr = [virtio.Size]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID}
// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking.
// A shared eventfd allows Close to wake all readers blocked in poll.
type Offload struct {
	fd         int
	shutdownFd int
	readPoll   [2]unix.PollFd
	writePoll  [2]unix.PollFd
	// writeLock serializes blockOnWrite's read+clear of writePoll[*].Revents.
	// Any goroutine that calls Write may end up parked in poll(2); without
	// the lock concurrent waiters could race the Revents reset and lose
	// events.
	writeLock sync.Mutex
	closed    atomic.Bool
	// The rx state below (rxBuf, rxOff, pending, readVnetScratch, readIovs)
	// is owned by the single reader goroutine — Read is documented as not
	// safe for concurrent use, so no locking is applied here.
	rxBuf   []byte   // backing store for kernel-handed packets read this drain
	rxOff   int      // cursor into rxBuf for the current Read drain
	pending []Packet // packets returned from the most recent Read
	// readVnetScratch holds the 10-byte virtio_net_hdr split off the front of
	// every TUN read via readv(2). Decoupling the header from the packet body
	// lets us read the body directly into rxBuf at the current rxOff with
	// no userspace copy on the GSO_NONE fast path.
	readVnetScratch [virtio.Size]byte
	// readIovs is the readv(2) iovec scratch wired once at construction —
	// iovec[0] points at readVnetScratch; iovec[1].Base/Len is updated per
	// read to address the current rxBuf slot.
	readIovs [2]unix.Iovec
	// usoEnabled records whether the kernel agreed to TUN_F_USO* on this FD,
	// so writers can decide whether emitting GSO_UDP_L4 superpackets is safe.
	usoEnabled bool
	// gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted
	// by WriteGSO. Kept separate from the read-only package-level validVnetHdr
	// so non-GSO Writes can ship that constant directly while WriteGSO
	// rewrites this scratch on every call.
	gsoHdrBuf [virtio.Size]byte
	// gsoIovs is the writev iovec scratch for WriteGSO. Pre-sized to
	// gsoMaxIovs at construction; never grown. WriteGSO returns an error
	// (and drops the call) if a caller hands it more fragments than fit.
	gsoIovs []unix.Iovec
}
// newOffload wraps fd (switched to non-blocking) in an Offload queue that
// shares shutdownFd with its sibling queues so Close can wake pollers.
func newOffload(fd int, shutdownFd int, usoEnabled bool) (*Offload, error) {
	if err := unix.SetNonblock(fd, true); err != nil {
		return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err)
	}
	r := &Offload{
		fd:         fd,
		shutdownFd: shutdownFd,
		usoEnabled: usoEnabled,
		rxBuf:      make([]byte, tunRxBufCap),
		gsoIovs:    make([]unix.Iovec, 2, gsoMaxIovs),
	}
	// Both poll sets watch the shutdown eventfd alongside the tun fd so a
	// parked reader/writer wakes when Close signals it.
	r.readPoll = [2]unix.PollFd{
		{Fd: int32(fd), Events: unix.POLLIN},
		{Fd: int32(shutdownFd), Events: unix.POLLIN},
	}
	r.writePoll = [2]unix.PollFd{
		{Fd: int32(fd), Events: unix.POLLOUT},
		{Fd: int32(shutdownFd), Events: unix.POLLIN},
	}
	// gsoIovs[0] permanently addresses the per-queue virtio_net_hdr scratch.
	r.gsoIovs[0].Base = &r.gsoHdrBuf[0]
	r.gsoIovs[0].SetLen(virtio.Size)
	// readIovs[0] is wired once to the virtio_net_hdr scratch; per-read we
	// only repoint readIovs[1] at the next rxBuf slot (see readPacket).
	r.readIovs[0].Base = &r.readVnetScratch[0]
	r.readIovs[0].SetLen(virtio.Size)
	return r, nil
}
// blockOnRead parks in poll(2) until the tun fd is readable, a problem
// flag is raised, or the shutdown eventfd fires. Returns os.ErrClosed on
// shutdown/hangup so callers can unwind.
func (r *Offload) blockOnRead() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		_, pollErr = unix.Poll(r.readPoll[:], -1)
		if pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear Revents unconditionally so a stale bit can never
	// satisfy a later wait.
	tunEv := r.readPoll[0].Revents
	downEv := r.readPoll[1].Revents
	r.readPoll[0].Revents = 0
	r.readPoll[1].Revents = 0
	// Check the poll error before trusting the (possibly bogus) event bits.
	if pollErr != nil {
		return pollErr
	}
	if downEv&(unix.POLLIN|problemFlags) != 0 || tunEv&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// blockOnWrite parks in poll(2) until the tun fd is writable, a problem
// flag is raised, or the shutdown eventfd fires. Returns os.ErrClosed on
// shutdown/hangup so callers can unwind.
func (r *Offload) blockOnWrite() error {
	const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR
	var pollErr error
	for {
		_, pollErr = unix.Poll(r.writePoll[:], -1)
		if pollErr != unix.EINTR {
			break
		}
	}
	// Snapshot and clear Revents under writeLock: multiple writers may be
	// parked here concurrently, and an unlocked read+reset could drop an
	// event another waiter still needed to observe.
	r.writeLock.Lock()
	tunEv := r.writePoll[0].Revents
	downEv := r.writePoll[1].Revents
	r.writePoll[0].Revents = 0
	r.writePoll[1].Revents = 0
	r.writeLock.Unlock()
	// Check the poll error before trusting the (possibly bogus) event bits.
	if pollErr != nil {
		return pollErr
	}
	if downEv&(unix.POLLIN|problemFlags) != 0 || tunEv&problemFlags != 0 {
		return os.ErrClosed
	}
	return nil
}
// readPacket issues a single readv(2) splitting the virtio_net_hdr off
// into readVnetScratch and reading the packet body directly into rxBuf at
// the current rxOff. Returns the body length (zero virtio header bytes,
// just the IP packet/superpacket). block controls whether EAGAIN is
// retried via poll: the initial read of a drain blocks; subsequent drain
// reads do not.
//
// The body iovec capacity is always tunReadBufSize; callers (the Read
// drain loop) gate entry on tunRxBufCap-rxOff >= tunRxBufSize, sized to
// hold one worst-case kernel-supplied packet body. Without that gate the
// body iovec could be smaller than the next inbound packet and the
// kernel would truncate.
//
// NOTE(review): the gate only protects rxBuf if tunReadBufSize <=
// tunRxBufSize — confirm against tunReadBufSize's declaration, since a
// larger value would let the kernel write past the drain headroom.
func (r *Offload) readPacket(block bool) (int, error) {
	for {
		// Repoint the body iovec at the current cursor; iovec[0] (the
		// virtio header scratch) was wired once in newOffload.
		r.readIovs[1].Base = &r.rxBuf[r.rxOff]
		r.readIovs[1].SetLen(tunReadBufSize)
		n, _, errno := syscall.Syscall(unix.SYS_READV, uintptr(r.fd), uintptr(unsafe.Pointer(&r.readIovs[0])), uintptr(len(r.readIovs)))
		if errno == 0 {
			if int(n) < virtio.Size {
				// Read was too short to even contain the virtio header.
				// NOTE(review): io.ErrShortWrite is a surprising sentinel
				// for a short *read* (io.ErrUnexpectedEOF would read
				// better); callers only propagate it, so behavior is
				// unaffected.
				return 0, io.ErrShortWrite
			}
			// Strip the virtio header length; callers want IP bytes only.
			return int(n) - virtio.Size, nil
		}
		if errno == unix.EAGAIN {
			if !block {
				return 0, errno
			}
			// Blocking mode: park in poll until readable (or shut down).
			if err := r.blockOnRead(); err != nil {
				return 0, err
			}
			continue
		}
		if errno == unix.EINTR {
			continue // interrupted by a signal; retry the readv
		}
		if errno == unix.EBADF {
			return 0, os.ErrClosed
		}
		return 0, errno
	}
}
// Read returns one or more packets from the tun. Each Packet either
// carries a single ready-to-use IP datagram (GSO zero) or a TSO/USO
// superpacket plus the GSOInfo a caller needs to segment it (see
// SegmentSuperpacket). The first read blocks via poll; once the fd is
// known readable, additional packets are drained non-blocking until the
// kernel queue is empty (EAGAIN), tunDrainCap packets are collected, or
// rxBuf headroom runs out. This amortizes one poll wake over bursts of
// small packets (e.g. TCP ACKs). Packet.Bytes slices point into the
// Offload's internal buffer and are only valid until the next Read or
// Close on this Queue.
func (r *Offload) Read() ([]Packet, error) {
	r.pending = r.pending[:0]
	r.rxOff = 0
	// Blocking first read; decode failures drop the packet and retry so
	// one bad packet cannot stall the reader.
	for {
		n, err := r.readPacket(true)
		if err != nil {
			return nil, err
		}
		if decErr := r.decodeRead(n); decErr == nil {
			break
		}
	}
	// Non-blocking drain: stop on EAGAIN/any error, on a short read, on a
	// decode failure, once tunDrainCap packets are queued, or when rxBuf
	// can no longer hold a worst-case (tunRxBufSize) kernel packet. We
	// already have a valid batch from the first read either way.
	for len(r.pending) < tunDrainCap && tunRxBufCap-r.rxOff >= tunRxBufSize {
		n, err := r.readPacket(false)
		if err != nil || n <= 0 {
			break
		}
		if r.decodeRead(n) != nil {
			break
		}
	}
	return r.pending, nil
}
// decodeRead turns the packet currently at rxBuf[rxOff:rxOff+pktLen]
// into a pending Packet. The bytes stay in rxBuf — GSO_NONE bodies are
// handed out as plain IP datagrams (after FinishChecksum when
// NEEDS_CSUM is set); TSO/USO superpackets are validated, get their
// kernel-supplied HdrLen corrected for the forward path, and carry the
// GSO metadata so the caller can segment lazily at encrypt time. rxOff
// advances past the kernel-supplied body and nothing else, since
// segmentation no longer writes back into rxBuf.
func (r *Offload) decodeRead(pktLen int) error {
	if pktLen <= 0 {
		return fmt.Errorf("short tun read: %d", pktLen)
	}
	var vh virtio.Hdr
	vh.Decode(r.readVnetScratch[:])
	body := r.rxBuf[r.rxOff : r.rxOff+pktLen]
	if vh.GSOType != unix.VIRTIO_NET_HDR_GSO_NONE {
		// Superpacket: validate, fix HdrLen (CorrectHdrLen), map the L4
		// proto, attach metadata. Segmentation is deferred to
		// SegmentSuperpacket at encrypt time.
		if err := virtio.CheckValid(body, vh); err != nil {
			return err
		}
		if err := virtio.CorrectHdrLen(body, &vh); err != nil {
			return err
		}
		proto, err := protoFromGSOType(vh.GSOType)
		if err != nil {
			return err
		}
		r.pending = append(r.pending, Packet{
			Bytes: body,
			GSO: GSOInfo{
				Size:      vh.GSOSize,
				HdrLen:    vh.HdrLen,
				CsumStart: vh.CsumStart,
				Proto:     proto,
			},
		})
		r.rxOff += pktLen
		return nil
	}
	// Plain datagram: finish the partial checksum if the kernel asked.
	if vh.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
		if err := virtio.FinishChecksum(body, vh); err != nil {
			return err
		}
	}
	r.pending = append(r.pending, Packet{Bytes: body})
	r.rxOff += pktLen
	return nil
}
// Write writes one already-formed IP packet to the tun, prepending the
// shared "no offloads" virtio_net_hdr. Returns the number of packet
// bytes written (the virtio header is excluded from the count).
func (r *Offload) Write(buf []byte) (int, error) {
	// Guard before taking &buf[0]: previously an empty buf panicked here
	// (index out of range) before writeWithScratch's own len(buf)==0
	// short-circuit could run. An empty write is a successful no-op.
	if len(buf) == 0 {
		return 0, nil
	}
	iovs := [2]unix.Iovec{
		{Base: &validVnetHdr[0]},
		{Base: &buf[0]},
	}
	iovs[0].SetLen(virtio.Size)
	iovs[1].SetLen(len(buf))
	return r.writeWithScratch(buf, &iovs)
}
// writeWithScratch repoints the caller-supplied two-entry iovec pair at
// buf and submits it via rawWrite. Empty writes are a successful no-op.
func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) {
	if len(buf) == 0 {
		return 0, nil
	}
	iovs[1].Base = &buf[0]
	iovs[1].SetLen(len(buf))
	return r.rawWrite(iovs[:])
}
// rawWrite issues writev(2) over iovs, retrying EINTR and waiting out
// EAGAIN via blockOnWrite. On success it returns the byte count minus
// the virtio header, or io.ErrShortWrite when fewer than virtio.Size
// bytes went out.
func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) {
	for {
		n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs)))
		switch errno {
		case 0:
			if int(n) < virtio.Size {
				return 0, io.ErrShortWrite
			}
			return int(n) - virtio.Size, nil
		case unix.EAGAIN:
			if err := r.blockOnWrite(); err != nil {
				return 0, err
			}
		case unix.EINTR:
			// interrupted by a signal: retry the writev
		case unix.EBADF:
			return 0, os.ErrClosed
		default:
			return 0, errno
		}
	}
}
// Capabilities reports the offload features negotiated for this Queue.
// TSO is always true for Offload (we only construct it on IFF_VNET_HDR
// FDs); USO is true only when the kernel agreed to TUN_F_USO4|6 at open
// time (Linux ≥ 6.2).
func (r *Offload) Capabilities() Capabilities {
	caps := Capabilities{TSO: true}
	caps.USO = r.usoEnabled
	return caps
}
// WriteGSO writes one (possibly multi-segment) packet to the tun as a
// virtio superpacket: hdr is the L3 header, transportHdr the L4 header,
// and pays the per-segment payload chunks. With more than one payload
// the kernel is asked to segment (TSO/USO); with exactly one this
// degrades to a plain NEEDS_CSUM write. Empty inputs are a no-op.
func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error {
	if len(hdr) == 0 || len(transportHdr) == 0 || len(pays) == 0 {
		return nil
	}
	// L4 checksum offset inside transportHdr: UDP=6 (after
	// sport/dport/length), TCP=16 (the `check` field after
	// seq/ack/dataoff/flags/window).
	csumOff := uint16(16)
	if proto == GSOProtoUDP {
		csumOff = 6
	}
	vhdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_NONE,
		HdrLen:     uint16(len(hdr) + len(transportHdr)),
		GSOSize:    uint16(len(pays[0])),
		CsumStart:  uint16(len(hdr)),
		CsumOffset: csumOff,
	}
	if len(pays) > 1 {
		switch ipVer := hdr[0] >> 4; {
		case proto == GSOProtoUDP && (ipVer == 4 || ipVer == 6):
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
		case ipVer == 6:
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6
		case ipVer == 4:
			vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4
		default:
			vhdr.GSOSize = 0 // unknown IP version: no kernel GSO
		}
	} else {
		vhdr.GSOSize = 0 // single segment: nothing to split
	}
	vhdr.Encode(r.gsoHdrBuf[:])
	// iovec layout: [virtio_hdr, hdr, transportHdr, pays...]. Slot 0 is
	// wired to gsoHdrBuf at construction and never changes.
	need := 3 + len(pays)
	if need > cap(r.gsoIovs) {
		slog.Default().Warn("tio: WriteGSO iovec budget exceeded; dropping superpacket",
			"need", need, "cap", cap(r.gsoIovs), "segments", len(pays))
		return fmt.Errorf("tio: WriteGSO needs %d iovecs but cap is %d", need, cap(r.gsoIovs))
	}
	r.gsoIovs = r.gsoIovs[:need]
	r.gsoIovs[1].Base = &hdr[0]
	r.gsoIovs[1].SetLen(len(hdr))
	r.gsoIovs[2].Base = &transportHdr[0]
	r.gsoIovs[2].SetLen(len(transportHdr))
	for i := range pays {
		r.gsoIovs[3+i].Base = &pays[i][0]
		r.gsoIovs[3+i].SetLen(len(pays[i]))
	}
	_, err := r.rawWrite(r.gsoIovs)
	return err
}
// Close releases the tun fd exactly once; repeated calls return nil.
func (r *Offload) Close() error {
	if r.closed.Swap(true) {
		return nil // already closed by a previous call
	}
	// shutdownFd is owned by the container, so we should not close it
	if r.fd < 0 {
		return nil
	}
	err := unix.Close(r.fd)
	r.fd = -1
	return err
}

View File

@@ -21,7 +21,7 @@ type Poll struct {
closed atomic.Bool closed atomic.Bool
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]Packet
} }
func newPoll(fd int, shutdownFd int) (*Poll, error) { func newPoll(fd int, shutdownFd int) (*Poll, error) {
@@ -97,12 +97,12 @@ func (t *Poll) blockOnWrite() error {
return nil return nil
} }
func (t *Poll) Read() ([][]byte, error) { func (t *Poll) Read() ([]Packet, error) {
n, err := t.readOne(t.readBuf) n, err := t.readOne(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -15,7 +15,7 @@ import (
) )
// newReadPipe returns a read fd. The matching write fd is registered for cleanup. // newReadPipe returns a read fd. The matching write fd is registered for cleanup.
// The caller takes ownership of the read fd (pass it to newOffload / newFriend). // The caller takes ownership of the read fd (pass it into a QueueSet).
func newReadPipe(t *testing.T) int { func newReadPipe(t *testing.T) int {
t.Helper() t.Helper()
var fds [2]int var fds [2]int
@@ -29,7 +29,7 @@ func newReadPipe(t *testing.T) int {
func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) { func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) {
pipe1 := newReadPipe(t) pipe1 := newReadPipe(t)
pipe2 := newReadPipe(t) pipe2 := newReadPipe(t)
parent, err := NewPollContainer() parent, err := NewPollQueueSet()
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, parent.Add(pipe1)) require.NoError(t, parent.Add(pipe1))
require.NoError(t, parent.Add(pipe2)) require.NoError(t, parent.Add(pipe2))

View File

@@ -0,0 +1,51 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"fmt"
"golang.org/x/sys/unix"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// protoFromGSOType maps a virtio_net_hdr GSOType onto the GSOProto used
// by the segment-time helpers. GSO_NONE and anything unrecognized (e.g.
// legacy GSO_UDP/UFO) yield an error — callers should only invoke this
// on a confirmed superpacket.
func protoFromGSOType(t uint8) (GSOProto, error) {
	if t == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
		return GSOProtoUDP, nil
	}
	if t == unix.VIRTIO_NET_HDR_GSO_TCPV4 || t == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		return GSOProtoTCP, nil
	}
	return 0, fmt.Errorf("unsupported virtio gso type: %d", t)
}
// SegmentSuperpacket invokes fn once per segment of pkt. For non-GSO pkts
// fn is called once with pkt.Bytes (no segmentation, no copy). For GSO/USO
// superpackets fn is called once per segment with a slice of pkt.Bytes
// holding that segment's plaintext (a freshly-patched L3+L4 header sliced
// in front of the original payload chunk). Segmentation is destructive:
// pkt is consumed by this call and its bytes are in an undefined state when
// SegmentSuperpacket returns. Callers must not retain pkt or any earlier
// seg slice past fn's return for that segment. Aborts and returns the
// first error from fn or from per-segment construction.
func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error {
	if !pkt.GSO.IsSuperpacket() {
		return fn(pkt.Bytes)
	}
	switch pkt.GSO.Proto {
	case GSOProtoTCP:
		return virtio.SegmentTCP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
	case GSOProtoUDP:
		return virtio.SegmentUDP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn)
	default:
		return fmt.Errorf("unsupported gso proto: %d", pkt.GSO.Proto)
	}
}

View File

@@ -0,0 +1,794 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package tio
import (
"encoding/binary"
"os"
"testing"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"github.com/slackhq/nebula/overlay/tio/virtio"
)
// testSegScratchSize is a generous segmentation scratch sized to fit any
// of the synthetic TSO/USO superpackets these tests generate (one
// worst-case 64 KiB superpacket plus replicated per-segment headers).
// It is passed to segmentForTest's scratch parameter, which is unused on
// the destructive segmentation path but kept for signature parity.
const testSegScratchSize = 192 * 1024
// verifyChecksum confirms that the one's-complement sum across `b`,
// seeded with a folded pseudo-header sum, equals all-ones (valid).
func verifyChecksum(b []byte, pseudo uint16) bool {
	sum := checksum.Checksum(b, pseudo)
	return sum == 0xffff
}
// segmentForTest is the test-only counterpart to the production
// SegmentSuperpacket path. GSO_NONE packets are copied out (with an
// optional FinishChecksum when NEEDS_CSUM is set); GSO superpackets are
// dispatched through SegmentSuperpacket and each yielded segment is
// deep-copied into out so callers can inspect the batch after the call
// returns. Tests pre-set hdr.HdrLen correctly, so correctHdrLen is not
// invoked here; scratch is accepted for signature parity but unused.
func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, scratch []byte) error {
	if hdr.GSOType != unix.VIRTIO_NET_HDR_GSO_NONE {
		proto, err := protoFromGSOType(hdr.GSOType)
		if err != nil {
			return err
		}
		p := Packet{
			Bytes: pkt,
			GSO: GSOInfo{
				Size:      hdr.GSOSize,
				HdrLen:    hdr.HdrLen,
				CsumStart: hdr.CsumStart,
				Proto:     proto,
			},
		}
		return SegmentSuperpacket(p, func(seg []byte) error {
			*out = append(*out, append([]byte(nil), seg...))
			return nil
		})
	}
	cp := append([]byte(nil), pkt...)
	if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
		if err := virtio.FinishChecksum(cp, hdr); err != nil {
			return err
		}
	}
	*out = append(*out, cp)
	return nil
}
// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each.
func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 {
	sum := uint32(checksum.Checksum(src, 0))
	sum += uint32(checksum.Checksum(dst, 0))
	sum += uint32(proto)
	sum += uint32(l4Len)
	// Fold carries back in until the sum fits 16 bits.
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}
// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a
// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each.
func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 {
	sum := uint32(checksum.Checksum(src, 0))
	sum += uint32(checksum.Checksum(dst, 0))
	sum += uint32(l4Len >> 16)
	sum += uint32(l4Len & 0xffff)
	sum += uint32(proto)
	// Fold carries back in until the sum fits 16 bits.
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return uint16(sum)
}
// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of
// `payLen` bytes split at `mss`. Returns the packet bytes plus the matching
// virtio_net_hdr (NEEDS_CSUM, TCPV4 GSO, checksum rooted at the TCP header).
func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) {
	t.Helper()
	const ipLen = 20
	const tcpLen = 20
	pkt := make([]byte, ipLen+tcpLen+payLen)
	// IPv4 header
	pkt[0] = 0x45 // version 4, IHL 5
	// total length is meaningless for TSO but set it anyway
	binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen))
	binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
	pkt[8] = 64 // TTL
	pkt[9] = unix.IPPROTO_TCP
	copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
	copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
	// TCP header
	binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
	binary.BigEndian.PutUint16(pkt[22:24], 80)    // dport
	binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq
	binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack
	pkt[32] = 0x50 // data offset 5 words
	pkt[33] = 0x18 // ACK | PSH
	binary.BigEndian.PutUint16(pkt[34:36], 65535) // window
	// payload: repeating byte pattern so segment boundaries are visible
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i & 0xff)
	}
	return pkt, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16, // offset of the TCP `check` field within the TCP header
	}
}
// TestSegmentTCPv4 segments a 3-MSS IPv4/TCP superpacket and checks each
// segment's length, total_length, incrementing IP ID, advancing TCP seq,
// PSH-only-on-last flag handling, and both checksums.
func TestSegmentTCPv4(t *testing.T) {
	const mss = 100
	const numSeg = 3
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		if len(seg) != 40+mss {
			t.Errorf("seg %d: unexpected len %d", i, len(seg))
		}
		totalLen := binary.BigEndian.Uint16(seg[2:4])
		if totalLen != uint16(40+mss) {
			t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss)
		}
		// IPv4 ID must increment per segment from the seed's 0x4242.
		id := binary.BigEndian.Uint16(seg[4:6])
		if id != 0x4242+uint16(i) {
			t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i))
		}
		// TCP sequence number advances by one MSS per segment.
		seq := binary.BigEndian.Uint32(seg[24:28])
		wantSeq := uint32(10000 + i*mss)
		if seq != wantSeq {
			t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq)
		}
		flags := seg[33]
		wantFlags := byte(0x10) // ACK only, PSH cleared
		if i == numSeg-1 {
			wantFlags = 0x18 // ACK | PSH preserved on last
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		// IPv4 header checksum must verify against itself.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		// TCP checksum must verify against the pseudo-header.
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentTCPv4OddTail checks that a payload not evenly divisible by
// the MSS yields a short final segment with valid lengths and checksums.
func TestSegmentTCPv4OddTail(t *testing.T) {
	// 250-byte payload at MSS 100 must yield segments of 100, 100, 50.
	pkt, hdr := buildTSOv4(t, 250, 100)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, make([]byte, testSegScratchSize)); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if got := len(out); got != 3 {
		t.Fatalf("want 3 segments, got %d", got)
	}
	for i, want := range []int{100, 100, 50} {
		seg := out[i]
		if got := len(seg) - 40; got != want {
			t.Errorf("seg %d: pay len %d want %d", i, got, want)
		}
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+want)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentTCPv6 segments a 2-MSS IPv6/TCP superpacket and checks
// per-segment lengths, payload_length, seq advance, FIN/PSH deferral to
// the last segment, and the TCP checksum against the v6 pseudo-header.
func TestSegmentTCPv6(t *testing.T) {
	const ipLen = 40
	const tcpLen = 20
	const mss = 120
	const numSeg = 2
	payLen := mss * numSeg
	pkt := make([]byte, ipLen+tcpLen+payLen)
	// IPv6 header
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
	pkt[6] = unix.IPPROTO_TCP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2
	// TCP header
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 80)
	binary.BigEndian.PutUint32(pkt[44:48], 7)
	binary.BigEndian.PutUint32(pkt[48:52], 99)
	pkt[52] = 0x50
	pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too
	binary.BigEndian.PutUint16(pkt[54:56], 65535)
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(ipLen + tcpLen),
		GSOSize:    uint16(mss),
		CsumStart:  uint16(ipLen),
		CsumOffset: 16,
	}
	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		if len(seg) != ipLen+tcpLen+mss {
			t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss)
		}
		// IPv6 payload_length covers the TCP header plus this segment.
		pl := binary.BigEndian.Uint16(seg[4:6])
		if pl != uint16(tcpLen+mss) {
			t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss)
		}
		seq := binary.BigEndian.Uint32(seg[44:48])
		if seq != uint32(7+i*mss) {
			t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss)
		}
		flags := seg[53]
		// Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be
		// cleared on all but the last; ACK(0x10) always preserved.
		wantFlags := byte(0x10)
		if i == numSeg-1 {
			wantFlags = 0x19
		}
		if flags != wantFlags {
			t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss)
		if !verifyChecksum(seg[ipLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// TestSegmentGSONonePassesThrough checks that a GSO_NONE packet with no
// checksum work comes back as a single, length-preserved datagram.
func TestSegmentGSONonePassesThrough(t *testing.T) {
	pkt, hdr := buildTSOv4(t, 100, 100)
	// Mark the packet as non-GSO with no NEEDS_CSUM so it is untouched.
	hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE
	hdr.Flags = 0
	var out [][]byte
	err := segmentForTest(pkt, hdr, &out, make([]byte, testSegScratchSize))
	if err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	switch {
	case len(out) != 1:
		t.Fatalf("want 1 segment, got %d", len(out))
	case len(out[0]) != len(pkt):
		t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt))
	}
}
// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is
// still rejected; only modern GSO_UDP_L4 (USO) is supported.
func TestSegmentRejectsLegacyUDPGSO(t *testing.T) {
hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP}
var out [][]byte
if err := segmentForTest(nil, hdr, &out, nil); err == nil {
t.Fatalf("expected rejection for legacy UDP GSO")
}
}
// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of
// payLen bytes, segmented at gsoSize. The UDP length/checksum fields are
// left zero here and are expected to be filled in per segment.
func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) {
	t.Helper()
	const ipLen = 20
	const udpLen = 8
	pkt := make([]byte, ipLen+udpLen+payLen)
	// IPv4 header
	pkt[0] = 0x45 // version 4, IHL 5
	binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+udpLen+payLen))
	binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID
	pkt[8] = 64 // TTL
	pkt[9] = unix.IPPROTO_UDP
	copy(pkt[12:16], []byte{10, 0, 0, 1}) // src
	copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst
	// UDP header (length + checksum filled in per segment by segmentUDPYield)
	binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport
	binary.BigEndian.PutUint16(pkt[22:24], 53)    // dport
	for i := 0; i < payLen; i++ {
		pkt[ipLen+udpLen+i] = byte(i & 0xff)
	}
	return pkt, virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6, // offset of the UDP checksum field within the UDP header
	}
}
// TestSegmentUDPv4 segments a 3-chunk IPv4/UDP USO superpacket and checks
// per-segment lengths, the constant IPv4 ID, UDP length, and checksums.
func TestSegmentUDPv4(t *testing.T) {
	const gso = 100
	const numSeg = 3
	pkt, hdr := buildUSOv4(t, gso*numSeg, gso)
	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		if len(seg) != 28+gso {
			t.Errorf("seg %d: len %d want %d", i, len(seg), 28+gso)
		}
		totalLen := binary.BigEndian.Uint16(seg[2:4])
		if totalLen != uint16(28+gso) {
			t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 28+gso)
		}
		// kernel UDP-GSO does NOT bump the IPv4 ID across segments; every
		// segment carries the same ID as the seed.
		id := binary.BigEndian.Uint16(seg[4:6])
		if id != 0x4242 {
			t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242)
		}
		udpLen := binary.BigEndian.Uint16(seg[24:26])
		if udpLen != uint16(8+gso) {
			t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+gso)
		}
		// IPv4 header checksum must verify against itself.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		// UDP checksum must verify against the pseudo-header.
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+gso)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad UDP checksum", i)
		}
	}
}
// TestSegmentUDPv4OddTail checks that a payload not evenly divisible by
// the USO chunk size yields a short final segment with consistent UDP
// lengths and valid checksums.
func TestSegmentUDPv4OddTail(t *testing.T) {
	// 250-byte payload at gsoSize 100 must yield segments of 100, 100, 50.
	pkt, hdr := buildUSOv4(t, 250, 100)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, make([]byte, testSegScratchSize)); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if got := len(out); got != 3 {
		t.Fatalf("want 3 segments, got %d", got)
	}
	for i, want := range []int{100, 100, 50} {
		seg := out[i]
		if got := len(seg) - 28; got != want {
			t.Errorf("seg %d: pay len %d want %d", i, got, want)
		}
		if got := binary.BigEndian.Uint16(seg[24:26]); got != uint16(8+want) {
			t.Errorf("seg %d: udp len=%d want %d", i, got, 8+want)
		}
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+want)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad UDP checksum", i)
		}
	}
}
// TestSegmentUDPv6 segments a 2-chunk IPv6/UDP USO superpacket and checks
// per-segment lengths, IPv6 payload_length, UDP length, and the UDP
// checksum against the v6 pseudo-header.
func TestSegmentUDPv6(t *testing.T) {
	const ipLen = 40
	const udpLen = 8
	const gso = 120
	const numSeg = 2
	payLen := gso * numSeg
	pkt := make([]byte, ipLen+udpLen+payLen)
	// IPv6 header
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen))
	pkt[6] = unix.IPPROTO_UDP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2
	// UDP sport/dport; length + checksum are per-segment work
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 53)
	for i := 0; i < payLen; i++ {
		pkt[ipLen+udpLen+i] = byte(i)
	}
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4,
		HdrLen:     uint16(ipLen + udpLen),
		GSOSize:    uint16(gso),
		CsumStart:  uint16(ipLen),
		CsumOffset: 6,
	}
	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		if len(seg) != ipLen+udpLen+gso {
			t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+udpLen+gso)
		}
		// IPv6 payload_length covers the UDP header plus this chunk.
		pl := binary.BigEndian.Uint16(seg[4:6])
		if pl != uint16(udpLen+gso) {
			t.Errorf("seg %d: payload_length=%d want %d", i, pl, udpLen+gso)
		}
		ul := binary.BigEndian.Uint16(seg[ipLen+4 : ipLen+6])
		if ul != uint16(udpLen+gso) {
			t.Errorf("seg %d: udp len=%d want %d", i, ul, udpLen+gso)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_UDP, udpLen+gso)
		if !verifyChecksum(seg[ipLen:], psum) {
			t.Errorf("seg %d: bad UDP checksum", i)
		}
	}
}
// TestSegmentUDPCEPropagates confirms an IP-level CE mark on the seed
// shows up on every emitted segment. UDP has no transport-level CWR/ECE,
// so the whole IP TOS/TC byte simply rides along in each copied segment
// prefix.
func TestSegmentUDPCEPropagates(t *testing.T) {
	pkt, hdr := buildUSOv4(t, 200, 100)
	pkt[1] = 0x03 // set the CE codepoint in the IP-ECN field
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, make([]byte, testSegScratchSize)); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if got := len(out); got != 2 {
		t.Fatalf("want 2 segments, got %d", got)
	}
	for i := range out {
		seg := out[i]
		if seg[1]&0x03 != 0x03 {
			t.Errorf("seg %d: CE missing (tos=%#x)", i, seg[1])
		}
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
	}
}
// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO
// burst's seed has CWR set, only the first emitted segment carries CWR.
// ECE is preserved on every segment (different signal, persistent state),
// and PSH still moves to the last segment only.
func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) {
	const mss = 100
	const numSeg = 3
	pkt, hdr := buildTSOv4(t, mss*numSeg, mss)
	// Seed flags: CWR | ECE | ACK | PSH.
	pkt[33] = 0x80 | 0x40 | 0x10 | 0x08
	scratch := make([]byte, testSegScratchSize)
	var out [][]byte
	if err := segmentForTest(pkt, hdr, &out, scratch); err != nil {
		t.Fatalf("segmentForTest: %v", err)
	}
	if len(out) != numSeg {
		t.Fatalf("expected %d segments, got %d", numSeg, len(out))
	}
	for i, seg := range out {
		flags := seg[33]
		hasCwr := flags&0x80 != 0
		hasEce := flags&0x40 != 0
		hasPsh := flags&0x08 != 0
		wantCwr := i == 0         // CWR only on the first segment
		wantPsh := i == numSeg-1  // PSH only on the last segment
		if hasCwr != wantCwr {
			t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", i, hasCwr, wantCwr, flags)
		}
		if !hasEce {
			t.Errorf("seg %d: ECE missing (flags=%#x)", i, flags)
		}
		if hasPsh != wantPsh {
			t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", i, hasPsh, wantPsh, flags)
		}
		// IP and TCP checksums must still verify after the flag rewrite.
		if !verifyChecksum(seg[:20], 0) {
			t.Errorf("seg %d: bad IPv4 header checksum", i)
		}
		psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss)
		if !verifyChecksum(seg[20:], psum) {
			t.Errorf("seg %d: bad TCP checksum", i)
		}
	}
}
// BenchmarkSegmentTCPv4 measures segmentation throughput for IPv4/TCP
// TSO superpackets at several superpacket sizes with a fixed MSS of 1460.
func BenchmarkSegmentTCPv4(b *testing.B) {
	sizes := []struct {
		name   string
		payLen int
		mss    int
	}{
		{"64KiB_MSS1460", 65000, 1460},
		{"16KiB_MSS1460", 16384, 1460},
		{"4KiB_MSS1460", 4096, 1460},
	}
	for _, sz := range sizes {
		b.Run(sz.name, func(b *testing.B) {
			// Hand-build the superpacket (same layout as buildTSOv4, but
			// without *testing.T so it can run inside a benchmark).
			const ipLen = 20
			const tcpLen = 20
			pkt := make([]byte, ipLen+tcpLen+sz.payLen)
			pkt[0] = 0x45
			binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen))
			binary.BigEndian.PutUint16(pkt[4:6], 0x4242)
			pkt[8] = 64
			pkt[9] = unix.IPPROTO_TCP
			copy(pkt[12:16], []byte{10, 0, 0, 1})
			copy(pkt[16:20], []byte{10, 0, 0, 2})
			binary.BigEndian.PutUint16(pkt[20:22], 12345)
			binary.BigEndian.PutUint16(pkt[22:24], 80)
			binary.BigEndian.PutUint32(pkt[24:28], 10000)
			binary.BigEndian.PutUint32(pkt[28:32], 20000)
			pkt[32] = 0x50
			pkt[33] = 0x18
			binary.BigEndian.PutUint16(pkt[34:36], 65535)
			for i := 0; i < sz.payLen; i++ {
				pkt[ipLen+tcpLen+i] = byte(i)
			}
			hdr := virtio.Hdr{
				Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
				GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV4,
				HdrLen:     uint16(ipLen + tcpLen),
				GSOSize:    uint16(sz.mss),
				CsumStart:  uint16(ipLen),
				CsumOffset: 16,
			}
			scratch := make([]byte, testSegScratchSize)
			out := make([][]byte, 0, 64)
			// SegmentSuperpacket consumes its input destructively; restore
			// pkt from a master copy each iteration. The restore mirrors the
			// kernel→userspace copy that hands a fresh GSO blob to the
			// segmenter in production, so it's representative cost rather
			// than bench overhead.
			master := append([]byte(nil), pkt...)
			work := make([]byte, len(pkt))
			b.SetBytes(int64(len(pkt)))
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				copy(work, master)
				out = out[:0]
				if err := segmentForTest(work, hdr, &out, scratch); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}
// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path
// write is allocation-free. /dev/null accepts every write synchronously.
func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) {
	fd, err := unix.Open("/dev/null", os.O_WRONLY, 0)
	if err != nil {
		t.Fatalf("open /dev/null: %v", err)
	}
	t.Cleanup(func() { _ = unix.Close(fd) })
	tf := &Offload{fd: fd}
	pkt := make([]byte, 1400)
	// One warm-up call absorbs any lazy one-time allocations elsewhere.
	if _, err = tf.Write(pkt); err != nil {
		t.Fatalf("Write: %v", err)
	}
	got := testing.AllocsPerRun(1000, func() {
		if _, werr := tf.Write(pkt); werr != nil {
			t.Fatalf("Write: %v", werr)
		}
	})
	if got != 0 {
		t.Fatalf("Write allocated %.1f times per call, want 0", got)
	}
}
// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes
// of payload, segmented at gso. Returns the packet bytes only; the
// virtio_net_hdr is the caller's responsibility.
func buildTSOv6(payLen, gso int) []byte {
	const ipLen = 40
	const tcpLen = 20
	pkt := make([]byte, ipLen+tcpLen+payLen)
	pkt[0] = 0x60 // version 6
	binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen))
	pkt[6] = unix.IPPROTO_TCP
	pkt[7] = 64
	// src/dst fe80::1 / fe80::2
	pkt[8] = 0xfe
	pkt[9] = 0x80
	pkt[23] = 1
	pkt[24] = 0xfe
	pkt[25] = 0x80
	pkt[39] = 2
	// TCP header: sport/dport, seq/ack, data offset, flags, window
	binary.BigEndian.PutUint16(pkt[40:42], 12345)
	binary.BigEndian.PutUint16(pkt[42:44], 80)
	binary.BigEndian.PutUint32(pkt[44:48], 7)
	binary.BigEndian.PutUint32(pkt[48:52], 99)
	pkt[52] = 0x50
	pkt[53] = 0x10 // ACK only
	binary.BigEndian.PutUint16(pkt[54:56], 65535)
	for i := 0; i < payLen; i++ {
		pkt[ipLen+tcpLen+i] = byte(i)
	}
	return pkt
}
// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is
// correct: when rxOff is at the maximum value the drain headroom check
// allows, decodeRead must still be able to absorb a worst-case 64KiB
// TSO superpacket without dropping the burst. With segmentation deferred
// to encrypt time, decodeRead writes only the kernel-supplied bytes into
// rxBuf, so the size requirement is just "fit one worst-case input."
//
// Regression history: in a prior layout the rx buffer doubled as the
// segmentation output, a near-threshold drain read returned "scratch too
// small", the whole 45-segment TSO burst was dropped, and the remote's TCP
// fast-retransmit collapsed cwnd. Keeping this test in the new layout
// guards against re-introducing a drain headroom shortfall.
func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) {
	const ipv6HdrLen = 40
	const tcpHdrLen = 20
	const headerLen = ipv6HdrLen + tcpHdrLen
	// Maximum TUN read body. The tunReadBufSize cap on readv's body iovec
	// is what bounds the kernel's superpacket length.
	pktLen := tunReadBufSize
	payLen := pktLen - headerLen
	// Round the segment size up so the payload splits into targetSegs
	// (or fewer) segments.
	const targetSegs = 64
	gsoSize := (payLen + targetSegs - 1) / targetSegs
	pkt := buildTSOv6(payLen, gsoSize)
	if len(pkt) != pktLen {
		t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen)
	}
	o := &Offload{
		rxBuf: make([]byte, tunRxBufCap),
	}
	// rxOff at the maximum value the drain headroom check permits before
	// it would refuse another read. Any drain-time read up to this
	// threshold MUST still process correctly.
	o.rxOff = tunRxBufCap - tunRxBufSize
	// Stage the body in rxBuf as if readv(2) just placed it there.
	copy(o.rxBuf[o.rxOff:], pkt)
	// Encode the matching virtio_net_hdr.
	hdr := virtio.Hdr{
		Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
		GSOType:    unix.VIRTIO_NET_HDR_GSO_TCPV6,
		HdrLen:     uint16(headerLen),
		GSOSize:    uint16(gsoSize),
		CsumStart:  uint16(ipv6HdrLen),
		CsumOffset: 16,
	}
	hdr.Encode(o.readVnetScratch[:])
	startRxOff := o.rxOff
	if err := o.decodeRead(pktLen); err != nil {
		t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+
			"tunRxBufSize=%d must hold one worst-case input (%d)",
			err, tunRxBufSize, pktLen)
	}
	if len(o.pending) != 1 {
		t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending))
	}
	// The pending entry must carry the superpacket metadata verbatim.
	got := o.pending[0]
	if !got.GSO.IsSuperpacket() {
		t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO)
	}
	if got.GSO.Proto != GSOProtoTCP {
		t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto)
	}
	if got.GSO.Size != uint16(gsoSize) {
		t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize)
	}
	if got.GSO.HdrLen != uint16(headerLen) {
		t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen)
	}
	if got.GSO.CsumStart != uint16(ipv6HdrLen) {
		t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen)
	}
	if len(got.Bytes) != pktLen {
		t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen)
	}
	// rxOff advances exactly by the kernel-supplied body length — no
	// segmentation output to account for any more.
	if o.rxOff != startRxOff+pktLen {
		t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen)
	}
	if o.rxOff > tunRxBufCap {
		t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap)
	}
	// Validate that segmenting the returned superpacket reproduces the
	// expected per-segment IPv6 payload length and TCP checksum.
	wantSegs := (payLen + gsoSize - 1) / gsoSize
	gotSegs := 0
	if err := SegmentSuperpacket(got, func(seg []byte) error {
		defer func() { gotSegs++ }()
		if len(seg) < headerLen+1 {
			t.Errorf("seg %d too short: %d", gotSegs, len(seg))
			return nil
		}
		if seg[0]>>4 != 6 {
			t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0])
		}
		segPay := len(seg) - headerLen
		gotPL := binary.BigEndian.Uint16(seg[4:6])
		if gotPL != uint16(tcpHdrLen+segPay) {
			t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay)
		}
		psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay)
		if !verifyChecksum(seg[ipv6HdrLen:], psum) {
			t.Errorf("seg %d: bad TCP checksum", gotSegs)
		}
		return nil
	}); err != nil {
		t.Fatalf("SegmentSuperpacket: %v", err)
	}
	if gotSegs != wantSegs {
		t.Fatalf("got %d segments, want %d", gotSegs, wantSegs)
	}
}

View File

@@ -0,0 +1,43 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
package virtio
import "encoding/binary"
// Size is the byte length of the legacy struct virtio_net_hdr that the
// kernel prepends/expects on a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ
// not set).
const Size = 10

// Hdr is the Go view of the legacy virtio_net_hdr.
type Hdr struct {
	Flags      uint8  // VIRTIO_NET_HDR_F_* bits
	GSOType    uint8  // VIRTIO_NET_HDR_GSO_* value
	HdrLen     uint16 // combined L3+L4 header length (see CorrectHdrLen)
	GSOSize    uint16 // payload bytes per segment
	CsumStart  uint16 // offset where the checksummed region begins
	CsumOffset uint16 // offset of the checksum field, relative to CsumStart
}

// Decode fills h from the first Size bytes of b. Fields are read in host
// byte order: that is the TUN default — we never issue TUNSETVNETLE, so
// the kernel's endianness matches ours.
func (h *Hdr) Decode(b []byte) {
	h.Flags, h.GSOType = b[0], b[1]
	h.HdrLen = binary.NativeEndian.Uint16(b[2:])
	h.GSOSize = binary.NativeEndian.Uint16(b[4:])
	h.CsumStart = binary.NativeEndian.Uint16(b[6:])
	h.CsumOffset = binary.NativeEndian.Uint16(b[8:])
}

// Encode serializes h into b (which must be at least Size bytes); it is
// the exact inverse of Decode. Used to emit a TSO superpacket on egress.
func (h *Hdr) Encode(b []byte) {
	b[0], b[1] = h.Flags, h.GSOType
	binary.NativeEndian.PutUint16(b[2:], h.HdrLen)
	binary.NativeEndian.PutUint16(b[4:], h.GSOSize)
	binary.NativeEndian.PutUint16(b[6:], h.CsumStart)
	binary.NativeEndian.PutUint16(b[8:], h.CsumOffset)
}

View File

@@ -0,0 +1,401 @@
//go:build linux && !android && !e2e_testing
// +build linux,!android,!e2e_testing
// Package virtio implements the pure validation, header-correction, and
// per-segment slicing logic for kernel-supplied TSO/USO superpackets on
// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte
// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes,
// so it can be unit-tested in isolation from the tio Queue runtime.
package virtio
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
)
// Protocol header size bounds used to validate / cap kernel-supplied offsets.
const (
	ipv4HeaderMinLen = 20 // IHL=5, no options
	ipv4HeaderMaxLen = 60 // IHL=15, max options
	ipv6FixedLen     = 40 // IPv6 base header; extensions would extend this
	tcpHeaderMinLen  = 20 // data-offset=5, no options
	tcpHeaderMaxLen  = 60 // data-offset=15, max options
)

// Byte offsets inside an IPv4 header.
const (
	ipv4TotalLenOff = 2  // 16-bit total length (header + payload)
	ipv4IDOff       = 4  // 16-bit identification field, bumped per TSO segment
	ipv4ChecksumOff = 10 // 16-bit header checksum
	ipv4SrcOff      = 12 // start of src address; dst follows immediately
	ipv4AddrsEnd    = 20 // end of dst address (ipv4SrcOff + 2*4)
)

// Byte offsets inside an IPv6 header.
const (
	ipv6PayloadLenOff = 4  // 16-bit payload length (everything after the base header)
	ipv6SrcOff        = 8  // start of src address; dst follows immediately
	ipv6AddrsEnd      = 40 // end of dst address (ipv6SrcOff + 2*16)
)

// Byte offsets inside a TCP header (relative to its start, i.e. csumStart).
const (
	tcpSeqOff      = 4  // 32-bit sequence number
	tcpDataOffOff  = 12 // upper nibble is header len in 32-bit words
	tcpFlagsOff    = 13 // low byte of the flags word (CWR..FIN)
	tcpChecksumOff = 16 // 16-bit TCP checksum
)

// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}.
const (
	udpHeaderLen   = 8
	udpLengthOff   = 4 // 16-bit UDP length (header + payload)
	udpChecksumOff = 6 // 16-bit UDP checksum (0 means "not computed" on v4)
)

// tcpFinPshMask is cleared on every segment except the last of a TSO burst.
const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08)

// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168
// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved
// its window) and must appear on the first segment of a TSO burst only.
const tcpCwrFlag = 0x80
// CheckValid rejects packets whose virtio_net_hdr/IP combination would
// cause a downstream miscompute: RSC metadata we cannot interpret, a
// truncated header, or a GSO type that disagrees with the IP version
// nibble.
func CheckValid(pkt []byte, hdr Hdr) error {
	// With RSC_INFO set, csum_start/csum_offset carry coalescing info
	// instead of checksum offsets. A TUN writing via IFF_VNET_HDR should
	// never emit this, but if it did we would silently miscompute the
	// segment checksums — refuse the packet instead.
	if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 {
		return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads")
	}
	// Guarantee at least a minimal IPv4 header before touching pkt[0].
	if len(pkt) < ipv4HeaderMinLen {
		return fmt.Errorf("packet too short")
	}
	ipVersion := pkt[0] >> 4
	var versionOK bool
	switch hdr.GSOType {
	case unix.VIRTIO_NET_HDR_GSO_TCPV4:
		versionOK = ipVersion == 4
	case unix.VIRTIO_NET_HDR_GSO_TCPV6:
		versionOK = ipVersion == 6
	default:
		// USO (VIRTIO_NET_HDR_GSO_UDP_L4) and every other type may carry
		// either family; the leading nibble disambiguates.
		versionOK = ipVersion == 4 || ipVersion == 6
	}
	if !versionOK {
		return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType)
	}
	return nil
}
// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header
// length read out of pkt, then validates that the kernel-supplied offsets
// all land inside pkt.
//
// Thank you wireguard-go for documenting these edge-cases:
// Don't trust hdr.HdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the transport header length and add it onto
// CsumStart, which is synonymous for IP header length.
func CorrectHdrLen(pkt []byte, hdr *Hdr) error {
	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
		// UDP header is fixed at 8 bytes; nothing to parse.
		hdr.HdrLen = hdr.CsumStart + 8
	} else {
		// Widen to int before adding so a hostile CsumStart near 65535
		// cannot wrap the uint16 guard and pass a too-short packet.
		if len(pkt) <= int(hdr.CsumStart)+tcpDataOffOff {
			return errors.New("packet is too short")
		}
		// Upper nibble of the data-offset byte is the TCP header length
		// in 32-bit words.
		tcpHLen := uint16(pkt[int(hdr.CsumStart)+tcpDataOffOff]>>4) * 4
		if tcpHLen < 20 || tcpHLen > 60 {
			// A TCP header must be between 20 and 60 bytes in length.
			return fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
		}
		hdr.HdrLen = hdr.CsumStart + tcpHLen
	}
	if len(pkt) < int(hdr.HdrLen) {
		return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen)
	}
	if hdr.HdrLen < hdr.CsumStart {
		return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart)
	}
	// The 16-bit checksum field lives at csum_start + csum_offset.
	// (Fixed: this previously added CsumStart to itself, so the bounds
	// check validated the wrong byte position and could either falsely
	// reject valid packets or admit ones whose real checksum field lies
	// past the end of pkt.)
	cSumAt := int(hdr.CsumStart) + int(hdr.CsumOffset)
	if cSumAt+1 >= len(pkt) {
		return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt))
	}
	return nil
}
// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a
// slice into pkt itself. Per-segment plaintext is laid out by sliding a
// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize,
// where it sits immediately before that segment's payload chunk in the
// original buffer. The slide is destructive: iter i's header write overwrites
// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time
// the next iteration begins. pkt is consumed by this call and must not be
// inspected by the caller after the final yield.
//
// hdrLenU/csumStartU are expected to have been validated/corrected by
// CheckValid + CorrectHdrLen before this call; only gso_size/csum_start
// zero-checks and the IPv4 IHL are re-validated here.
func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}
	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	isV4 := pkt[0]>>4 == 4
	// Upper nibble of the data-offset byte is the TCP header length in
	// 32-bit words.
	tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4
	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		// Header-only superpacket: still emit exactly one segment.
		numSeg = 1
	}
	origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4])
	origFlags := pkt[csumStart+tcpFlagsOff]
	// Pre-sum the TCP header once with the per-segment variable fields
	// (sequence, flags, checksum) zeroed; each iteration folds its own
	// values back in rather than re-summing the whole header (RFC 1071
	// incremental update).
	var tmp [tcpHeaderMaxLen]byte
	copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen])
	tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0
	tmp[tcpFlagsOff] = 0
	tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0
	baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0))
	// Pseudo-header partial sum: src+dst addresses plus protocol number.
	// The per-segment TCP length term is added inside the loop.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_TCP)
	var origIPID uint16
	var baseIPHdrSum uint32
	if isV4 {
		origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2])
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		// Same pre-sum trick for the IPv4 header, with total_len / ID /
		// checksum zeroed out of the base sum.
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}
	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize
		// Slide the header into place immediately before this segment's
		// payload. Iter 0's header is already at pkt[:headerLen]; for
		// i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen]
		// survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id
		// are touched), and iter 0's stale variable-field values are
		// overwritten by the per-segment patches below.
		// NOTE(review): if gsoSize < headerLen this copy's destination
		// overlaps the pkt[:headerLen] template itself, so iterations ≥ 2
		// would copy a partially clobbered header. Kernel-produced MSS
		// should never be that small, but confirm callers guarantee
		// gsoSize ≥ headerLen (or add a guard upstream).
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]
		segSeq := origSeq + uint32(segStart)
		segFlags := origFlags
		if i != 0 {
			// CWR marks a one-shot congestion event; first segment only.
			segFlags &^= tcpCwrFlag
		}
		if i != numSeg-1 {
			// FIN/PSH belong on the final segment of the burst only.
			segFlags &^= tcpFinPshMask
		}
		totalLen := segLen
		if isV4 {
			// IPv4: per-segment total_len and an ID that increments per
			// segment, folded into the pre-summed header checksum.
			segID := origIPID + uint16(i)
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID)
			ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			// IPv6: only payload_len changes per segment. headerLen may
			// exceed ipv6FixedLen+tcpHdrLen if extension headers exist.
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}
		binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq)
		seg[csumStart+tcpFlagsOff] = segFlags
		tcpLen := tcpHdrLen + segPayLen
		// Payload bytes still live at their original offset in pkt. The
		// header slide above only writes into pkt[i*G : i*G+H], which is
		// the tail of seg_{i-1}'s payload (already consumed) and never
		// overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G].
		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// Accumulate in 64 bits, then fold twice back to 32: segSeq and
		// segFlags contribute at their natural 16-bit word positions
		// (seq spans two words; flags is the low byte of word 12-13),
		// and tcpLen is the pseudo-header length term.
		wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide)))
		if err := yield(seg); err != nil {
			return err
		}
	}
	return nil
}
// SegmentUDP walks a USO superpacket, sliding a per-segment-patched
// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen]
// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6
// payload_len) plus the UDP length and checksum. pkt is consumed
// destructively; see SegmentTCP for the layout reasoning.
//
// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not
// bump it), which is why the IP-level per-segment work is limited to
// total_len + IPv4 header checksum (v4) or payload_len (v6).
func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error {
	if gsoSizeU == 0 {
		return fmt.Errorf("gso_size is zero")
	}
	if csumStartU == 0 {
		return fmt.Errorf("csum_start is zero")
	}
	isV4 := pkt[0]>>4 == 4
	headerLen := int(hdrLenU)
	csumStart := int(csumStartU)
	// For UDP the transport header is exactly 8 bytes; anything else means
	// hdrLen/csumStart disagree and all offsets below would be wrong.
	if headerLen-csumStart != udpHeaderLen {
		return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart)
	}
	payLen := len(pkt) - headerLen
	gsoSize := int(gsoSizeU)
	numSeg := (payLen + gsoSize - 1) / gsoSize
	if numSeg == 0 {
		// Header-only superpacket: still emit exactly one segment.
		numSeg = 1
	}
	// Pre-sum the UDP header with its two per-segment fields (length,
	// checksum) zeroed; each iteration folds its own values back in
	// (RFC 1071 incremental update).
	var udpTmp [udpHeaderLen]byte
	copy(udpTmp[:], pkt[csumStart:headerLen])
	udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0
	udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0
	baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0))
	// Pseudo-header partial sum: src+dst addresses plus protocol number.
	// The per-segment UDP length term is added inside the loop.
	var baseProtoSum uint32
	if isV4 {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0))
	} else {
		baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0))
	}
	baseProtoSum += uint32(unix.IPPROTO_UDP)
	var baseIPHdrSum uint32
	if isV4 {
		ihl := int(pkt[0]&0x0f) * 4
		if ihl < ipv4HeaderMinLen || ihl > csumStart {
			return fmt.Errorf("bad IPv4 IHL: %d", ihl)
		}
		// Pre-sum the IPv4 header with total_len and checksum zeroed.
		// (No ID zeroing: UDP-GSO keeps the ID constant across segments.)
		var ipTmp [ipv4HeaderMaxLen]byte
		copy(ipTmp[:ihl], pkt[:ihl])
		ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0
		ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0
		baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0))
	}
	for i := 0; i < numSeg; i++ {
		segStart := i * gsoSize
		segEnd := segStart + gsoSize
		if segEnd > payLen {
			segEnd = payLen
		}
		segPayLen := segEnd - segStart
		segLen := headerLen + segPayLen
		headerOff := i * gsoSize
		// Slide the (constant-field) header template down to sit just
		// before this segment's payload; see SegmentTCP for why this is
		// safe and for the gsoSize < headerLen caveat, which applies
		// here identically.
		if i > 0 {
			copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen])
		}
		seg := pkt[headerOff : headerOff+segLen]
		totalLen := segLen
		udpLen := udpHeaderLen + segPayLen
		if isV4 {
			binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen))
			ipSum := baseIPHdrSum + uint32(totalLen)
			binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum))
		} else {
			binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen))
		}
		binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen))
		paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0))
		// udpLen is folded in twice on purpose: once for the pseudo-header
		// length term and once for the UDP header's own length field, both
		// of which were zeroed out of baseUDPHdrSum above.
		wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum)
		wide += uint64(udpLen) + uint64(udpLen)
		wide = (wide & 0xffffffff) + (wide >> 32)
		wide = (wide & 0xffffffff) + (wide >> 32)
		csum := foldComplement(uint32(wide))
		// RFC 768: an all-zero transmitted checksum means "no checksum";
		// a computed zero must be sent as 0xffff instead.
		if csum == 0 {
			csum = 0xffff
		}
		binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum)
		if err := yield(seg); err != nil {
			return err
		}
	}
	return nil
}
// FinishChecksum computes the L4 checksum for a non-GSO packet that the
// kernel handed us with NEEDS_CSUM set. CsumStart/CsumOffset locate the
// 16-bit checksum field, which the kernel pre-loads with the pseudo-header
// partial sum; that value seeds the full sum over seg[CsumStart:] and the
// complemented result is stored back in place.
func FinishChecksum(seg []byte, hdr Hdr) error {
	start := int(hdr.CsumStart)
	off := int(hdr.CsumOffset)
	fieldAt := start + off
	if fieldAt+2 > len(seg) {
		return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", start, off, len(seg))
	}
	// Lift the kernel's pseudo-header partial out of the field, zero the
	// field so it doesn't contribute to the sum, then fold the partial in
	// as the seed for the L4 region starting at CsumStart.
	seed := binary.BigEndian.Uint16(seg[fieldAt : fieldAt+2])
	seg[fieldAt], seg[fieldAt+1] = 0, 0
	binary.BigEndian.PutUint16(seg[fieldAt:fieldAt+2], ^checksum.Checksum(seg[start:], seed))
	return nil
}
// foldComplement reduces a 32-bit one's-complement partial sum to 16 bits
// (repeatedly adding the carry back in) and complements the result,
// yielding the on-wire Internet checksum value.
func foldComplement(sum uint32) uint16 {
	for sum > 0xffff {
		sum = (sum >> 16) + (sum & 0xffff)
	}
	return ^uint16(sum)
}

View File

@@ -27,15 +27,15 @@ type tun struct {
l *slog.Logger l *slog.Logger
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf) n, err := t.rwc.Read(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -37,7 +37,7 @@ type tun struct {
out []byte out []byte
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
type ifReq struct { type ifReq struct {
@@ -516,12 +516,12 @@ func (t *tun) readOne(to []byte) (int, error) {
return n - 4, err return n - 4, err
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf) n, err := t.readOne(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -23,23 +23,41 @@ type disabledTun struct {
rx metrics.Counter rx metrics.Counter
l *slog.Logger l *slog.Logger
numReaders int numReaders int
batchRet [1][]byte
} }
func (t *disabledTun) Read() ([][]byte, error) { // disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue
r, ok := <-t.read // owns a private batchRet so concurrent Read calls from different reader
// goroutines do not race on the returned slice.
type disabledQueue struct {
parent *disabledTun
batchRet [1]tio.Packet
}
func (q *disabledQueue) Read() ([]tio.Packet, error) {
r, ok := <-q.parent.read
if !ok { if !ok {
return nil, io.EOF return nil, io.EOF
} }
t.tx.Inc(1) q.parent.tx.Inc(1)
if t.l.Enabled(context.Background(), slog.LevelDebug) { if q.parent.l.Enabled(context.Background(), slog.LevelDebug) {
t.l.Debug("Write payload", "raw", prettyPacket(r)) q.parent.l.Debug("Write payload", "raw", prettyPacket(r))
} }
t.batchRet[0] = r q.batchRet[0] = tio.Packet{Bytes: r}
return t.batchRet[:], nil return q.batchRet[:], nil
}
// Write on a queue forwards to the underlying disabledTun. All queues share
// one ICMP-handling/log path so this is a thin pass-through.
func (q *disabledQueue) Write(b []byte) (int, error) {
return q.parent.Write(b)
}
// Close on a queue is a no-op. The shared channel and metrics are owned by
// the disabledTun; Close on the device tears them down once for everybody.
func (q *disabledQueue) Close() error {
return nil
} }
func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun {
@@ -120,7 +138,7 @@ func (t *disabledTun) NewMultiQueueReader() error {
func (t *disabledTun) Readers() []tio.Queue { func (t *disabledTun) Readers() []tio.Queue {
out := make([]tio.Queue, t.numReaders) out := make([]tio.Queue, t.numReaders)
for i := range t.numReaders { for i := range t.numReaders {
out[i] = t out[i] = &disabledQueue{parent: t}
} }
return out return out
} }

View File

@@ -104,7 +104,7 @@ type tun struct {
closed atomic.Bool closed atomic.Bool
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
// blockOnRead waits until the tun fd is readable or shutdown has been signaled. // blockOnRead waits until the tun fd is readable or shutdown has been signaled.
@@ -159,12 +159,12 @@ func (t *tun) blockOnWrite() error {
return nil return nil
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf) n, err := t.readOne(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -29,15 +29,15 @@ type tun struct {
l *slog.Logger l *slog.Logger
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.rwc.Read(t.readBuf) n, err := t.rwc.Read(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -25,7 +25,7 @@ import (
) )
type tun struct { type tun struct {
readers tio.Container readers tio.QueueSet
closeLock sync.Mutex closeLock sync.Mutex
Device string Device string
vpnNetworks []netip.Prefix vpnNetworks []netip.Prefix
@@ -34,6 +34,14 @@ type tun struct {
TXQueueLen int TXQueueLen int
deviceIndex int deviceIndex int
ioctlFd uintptr ioctlFd uintptr
vnetHdr bool
// routeFeatureECN, when true, sets RTAX_FEATURE_ECN on every route we
// install for the tun. The kernel then actively negotiates ECN for
// connections destined to those prefixes (equivalent to `ip route
// change ... features ecn`) regardless of net.ipv4.tcp_ecn, so flows
// across the nebula mesh use ECN even when the host default is the
// passive setting (=2). Disable via tunnels.ecn=false.
routeFeatureECN bool
Routes atomic.Pointer[[]Route] Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]] routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -72,7 +80,9 @@ type ifreqQLEN struct {
} }
func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) // We don't know what flags the caller opened this fd with and can't turn
// on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds.
t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -117,6 +127,18 @@ func tunSetIff(fd int, name string, flags uint16) (string, error) {
return strings.Trim(string(req.Name[:]), "\x00"), nil return strings.Trim(string(req.Name[:]), "\x00"), nil
} }
// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a
// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO.
// TSO_ECN tells the kernel we propagate ECN correctly through coalesce and
// segmentation, so it can deliver superpackets whose seed has CWR/ECE set
// or whose IP-level codepoint is CE.
const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 | unix.TUN_F_TSO_ECN
// usoOffloadFlags adds UDP Segmentation Offload to tsoOffloadFlags. Requires
// Linux ≥ 6.2; older kernels reject it and we fall back to TCP-only TSO via
// tsoOffloadFlags.
const usoOffloadFlags = tsoOffloadFlags | unix.TUN_F_USO4 | unix.TUN_F_USO6
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI)
if multiqueue { if multiqueue {
@@ -124,17 +146,51 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
} }
nameStr := c.GetString("tun.dev", "") nameStr := c.GetString("tun.dev", "")
// First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_*
// offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets.
// We try TSO+USO first, fall back to TSO-only on kernels without USO
// (Linux < 6.2), and finally give up on virtio headers entirely and
// reopen as a plain TUN if neither offload mask is accepted.
fd, err := openTunDev() fd, err := openTunDev()
if err != nil { if err != nil {
return nil, err return nil, err
} }
name, err := tunSetIff(fd, nameStr, baseFlags) vnetHdr := true
usoEnabled := false
name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR)
if err != nil {
_ = unix.Close(fd)
vnetHdr = false
} else {
// Try TSO+USO first. On kernels without USO support (Linux < 6.2)
// the ioctl returns EINVAL; fall back to the TCP-only mask before
// giving up on VNET_HDR entirely.
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(usoOffloadFlags)); err == nil {
usoEnabled = true
} else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
l.Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers", "error", err)
_ = unix.Close(fd)
vnetHdr = false
}
}
if !vnetHdr {
fd, err = openTunDev()
if err != nil {
return nil, err
}
name, err = tunSetIff(fd, nameStr, baseFlags)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, &NameError{Name: nameStr, Underlying: err} return nil, &NameError{Name: nameStr, Underlying: err}
} }
}
t, err := newTunGeneric(c, l, fd, vpnNetworks) if vnetHdr {
l.Info("TUN offload enabled", "tso", true, "uso", usoEnabled)
}
t, err := newTunGeneric(c, l, fd, vnetHdr, usoEnabled, vpnNetworks)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -145,25 +201,34 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue
} }
// newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error.
func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) {
container, err := tio.NewPollContainer() var qs tio.QueueSet
var err error
if vnetHdr {
qs, err = tio.NewOffloadQueueSet(usoEnabled)
} else {
qs, err = tio.NewPollQueueSet()
}
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return nil, err
} }
err = container.Add(fd) err = qs.Add(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return nil, err return nil, err
} }
t := &tun{ t := &tun{
readers: container, readers: qs,
closeLock: sync.Mutex{}, closeLock: sync.Mutex{},
vnetHdr: vnetHdr,
vpnNetworks: vpnNetworks, vpnNetworks: vpnNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500), TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
routeFeatureECN: c.GetBool("tunnels.ecn", true),
routesFromSystem: map[netip.Prefix]routing.Gateways{}, routesFromSystem: map[netip.Prefix]routing.Gateways{},
l: l, l: l,
} }
@@ -271,11 +336,21 @@ func (t *tun) NewMultiQueueReader() error {
} }
flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE)
if t.vnetHdr {
flags |= unix.IFF_VNET_HDR
}
if _, err = tunSetIff(fd, t.Device, flags); err != nil { if _, err = tunSetIff(fd, t.Device, flags); err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
return err return err
} }
if t.vnetHdr {
if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil {
_ = unix.Close(fd)
return fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err)
}
}
err = t.readers.Add(fd) err = t.readers.Add(fd)
if err != nil { if err != nil {
_ = unix.Close(fd) _ = unix.Close(fd)
@@ -450,6 +525,18 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
Table: unix.RT_TABLE_MAIN, Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST, Type: unix.RTN_UNICAST,
} }
// Match the metric the kernel uses for its auto-installed connected
// route, so RouteReplace overwrites it in place instead of adding a
// second route at a worse metric. IPv6 connected routes are installed
// at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. Without this, the
// kernel route wins lookups and our MTU / AdvMSS / Features never
// apply on v6.
if cidr.Addr().Is6() {
nr.Priority = 256
}
if t.routeFeatureECN {
nr.Features |= unix.RTAX_FEATURE_ECN
}
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {
t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr)
@@ -499,6 +586,9 @@ func (t *tun) addRoutes(logErrors bool) error {
if r.Metric > 0 { if r.Metric > 0 {
nr.Priority = r.Metric nr.Priority = r.Metric
} }
if t.routeFeatureECN {
nr.Features |= unix.RTAX_FEATURE_ECN
}
err := netlink.RouteReplace(&nr) err := netlink.RouteReplace(&nr)
if err != nil { if err != nil {

View File

@@ -68,15 +68,15 @@ type tun struct {
fd int fd int
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf) n, err := t.readOne(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -61,15 +61,15 @@ type tun struct {
out []byte out []byte
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (t *tun) Read() ([][]byte, error) { func (t *tun) Read() ([]tio.Packet, error) {
n, err := t.readOne(t.readBuf) n, err := t.readOne(t.readBuf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -30,7 +30,7 @@ type TestTun struct {
rxPackets chan []byte // Packets to receive into nebula rxPackets chan []byte // Packets to receive into nebula
TxPackets chan []byte // Packets transmitted outside by nebula TxPackets chan []byte // Packets transmitted outside by nebula
batchRet [1][]byte batchRet [1]tio.Packet
} }
func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
@@ -51,7 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T
l: l, l: l,
rxPackets: make(chan []byte, 10), rxPackets: make(chan []byte, 10),
TxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10),
batchRet: [1][]byte{make([]byte, udp.MTU)}, batchRet: [1]tio.Packet{
tio.Packet{Bytes: make([]byte, udp.MTU)},
},
}, nil }, nil
} }
@@ -166,13 +168,13 @@ func (t *TestTun) Close() error {
return nil return nil
} }
func (t *TestTun) Read() ([][]byte, error) { func (t *TestTun) Read() ([]tio.Packet, error) {
t.batchRet[0] = t.batchRet[0][:udp.MTU] t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU]
n, err := t.read(t.batchRet[0]) n, err := t.read(t.batchRet[0].Bytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.batchRet[0][:n] t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n]
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -47,15 +47,15 @@ type winTun struct {
tun *wintun.NativeTun tun *wintun.NativeTun
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (t *winTun) Read() ([][]byte, error) { func (t *winTun) Read() ([]tio.Packet, error) {
n, err := t.tun.Read(t.readBuf, 0) n, err := t.tun.Read(t.readBuf, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.batchRet[0] = t.readBuf[:n] t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]}
return t.batchRet[:], nil return t.batchRet[:], nil
} }

View File

@@ -39,10 +39,10 @@ type UserDevice struct {
inboundWriter *io.PipeWriter inboundWriter *io.PipeWriter
readBuf []byte readBuf []byte
batchRet [1][]byte batchRet [1]tio.Packet
} }
func (d *UserDevice) Read() ([][]byte, error) { func (d *UserDevice) Read() ([]tio.Packet, error) {
if d.readBuf == nil { if d.readBuf == nil {
d.readBuf = make([]byte, defaultBatchBufSize) d.readBuf = make([]byte, defaultBatchBufSize)
} }
@@ -50,7 +50,7 @@ func (d *UserDevice) Read() ([][]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
d.batchRet[0] = d.readBuf[:n] d.batchRet[0] = tio.Packet{Bytes: d.readBuf[:n]}
return d.batchRet[:], nil return d.batchRet[:], nil
} }

View File

@@ -11,12 +11,25 @@ const MTU = 9001
// MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is // MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is
// required to accept. Callers SHOULD NOT pass more than this per call; Linux // required to accept. Callers SHOULD NOT pass more than this per call; Linux
// backends preallocate sendmmsg scratch sized to this value, so exceeding it // backends preallocate sendmmsg scratch sized to this value, so exceeding it
// only costs a chunked retry. // only costs additional sendmmsg chunks within a single WriteBatch call.
const MaxWriteBatch = 128 const MaxWriteBatch = 128
// RxMeta carries per-packet metadata extracted from the RX path (ancillary
// data, kernel offload state, etc.) and passed to EncReader callbacks.
// Backends that do not produce a particular signal leave its zero value.
//
// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier
// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero
// means Not-ECT, which is also the value backends without ECN RX support
// supply on every packet.
type RxMeta struct {
	// OuterECN holds the carrier datagram's ECN bits; 0 = Not-ECT.
	OuterECN byte
}
type EncReader func( type EncReader func(
addr netip.AddrPort, addr netip.AddrPort,
payload []byte, payload []byte,
meta RxMeta,
) )
type Conn interface { type Conn interface {
@@ -30,11 +43,14 @@ type Conn interface {
ListenOut(r EncReader, flush func()) error ListenOut(r EncReader, flush func()) error
WriteTo(b []byte, addr netip.AddrPort) error WriteTo(b []byte, addr netip.AddrPort) error
// WriteBatch sends a contiguous batch of packets, each with its own // WriteBatch sends a contiguous batch of packets, each with its own
// destination. bufs and addrs must have the same length. Linux uses // destination. bufs and addrs must have the same length. outerECNs may
// sendmmsg(2) for a single syscall; other backends fall back to a // be nil (treated as all-zero / Not-ECT); when non-nil it must have the
// WriteTo loop. Returns on the first error; callers may observe a // same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN
// partial send if some packets went out before the error. // codepoint to set on packet i's outer header. Linux uses sendmmsg(2)
WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error // for a single syscall and attaches the value as IP_TOS / IPV6_TCLASS
// cmsg; other backends ignore it. Returns on the first error; callers
// may observe a partial send if some packets went out before the error.
WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error
ReloadConfig(c *config.C) ReloadConfig(c *config.C)
SupportsMultipleReaders() bool SupportsMultipleReaders() bool
Close() error Close() error
@@ -57,7 +73,7 @@ func (NoopConn) SupportsMultipleReaders() bool {
func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error {
return nil return nil
} }
func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error { func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error {
return nil return nil
} }
func (NoopConn) ReloadConfig(_ *config.C) { func (NoopConn) ReloadConfig(_ *config.C) {

62
udp/raw_sendmmsg_linux.go Normal file
View File

@@ -0,0 +1,62 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package udp
import (
"net"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// rawSendmmsg performs sendmmsg(2) over a syscall.RawConn without
// allocating a closure per call. The struct holds preallocated in/out
// scratch (chunk/sent/errno) and a method-value bound at construction so
// rawConn.Write receives a stable function pointer instead of a fresh
// closure on every send.
type rawSendmmsg struct {
	// msgs is the preallocated mmsghdr array the syscall reads from;
	// send() submits the first `chunk` entries of it.
	msgs []rawMessage
	// chunk is run()'s input: how many msgs entries to submit.
	chunk int
	// sent is run()'s output: how many entries the kernel accepted.
	sent int
	// errno is run()'s output: the raw syscall errno, 0 on success.
	errno syscall.Errno
	// callback is the method value bound once by bind() so rawConn.Write
	// receives a stable function pointer rather than a per-call closure.
	callback func(fd uintptr) bool
}
// bind stores the r.run method value in r.callback. Call it exactly once,
// after r.msgs has been populated; every later send reuses the stored
// callback so no closure is allocated on the hot path.
func (r *rawSendmmsg) bind() {
	r.callback = r.run
}
// run is the preallocated callback rawConn.Write invokes. It reads its
// input (r.chunk) and writes its outputs (r.sent, r.errno) through the
// rawSendmmsg fields so the method value does not capture per-call locals
// and therefore does not heap-allocate.
func (r *rawSendmmsg) run(fd uintptr) bool {
	// sendmmsg(fd, &msgs[0], vlen=chunk, flags=0); the trailing zeros fill
	// the unused Syscall6 argument slots.
	r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd,
		uintptr(unsafe.Pointer(&r.msgs[0])), uintptr(r.chunk),
		0, 0, 0,
	)
	// EAGAIN/EWOULDBLOCK: socket buffer full — return false so
	// rawConn.Write parks on writability and invokes us again.
	if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK {
		return false
	}
	r.sent = int(r1)
	r.errno = errno
	return true
}
// send submits the first n entries of r.msgs via sendmmsg over rc and
// reports how many entries the kernel processed plus any error, matching
// the contract of the original sendmmsg helper it replaced.
func (r *rawSendmmsg) send(rc syscall.RawConn, n int) (int, error) {
	r.chunk, r.sent, r.errno = n, 0, 0
	if err := rc.Write(r.callback); err != nil {
		return r.sent, err
	}
	if r.errno == 0 {
		return r.sent, nil
	}
	return r.sent, &net.OpError{Op: "sendmmsg", Err: r.errno}
}

86
udp/rx_reorder_linux.go Normal file
View File

@@ -0,0 +1,86 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package udp
import (
"cmp"
"net/netip"
"slices"
)
// rxSegment is one nebula packet pulled out of a recvmmsg entry — either a
// lone datagram or one segment of a GRO superpacket. cnt is the big-endian
// uint64 message counter at bytes [8:16] of the nebula header; 0 if the
// segment is too short to contain a header. ecn is the 2-bit IP-level ECN
// codepoint stamped on the carrier (one value per slot, since GRO requires
// equal ECN across coalesced datagrams).
type rxSegment struct {
	// src is the sender of the carrier datagram.
	src netip.AddrPort
	// cnt is the nebula message counter used as the sort key within a source.
	cnt uint64
	// buf aliases the recvmmsg slot's buffer; deliver() nils it afterwards.
	buf []byte
	// ecn is the outer IP-level ECN codepoint (low 2 bits only).
	ecn byte
}
// rxReorderBuffer accumulates one recvmmsg batch worth of segments,
// splits any GRO superpackets at gso_size boundaries, stable-sorts by
// (src, port, counter), then delivers in order. The reorder distance is
// bounded by len(buf), which the caller sizes to stay well within the
// receiver's ReplayWindow so older arrivals are not rejected as replays.
type rxReorderBuffer struct {
	// buf holds this batch's segments; reset() truncates it between batches
	// while keeping the backing array.
	buf []rxSegment
}
// newRxReorderBuffer returns an empty reorder buffer whose backing slice
// is pre-sized to hold initialCap segments before growing.
func newRxReorderBuffer(initialCap int) *rxReorderBuffer {
	r := &rxReorderBuffer{}
	r.buf = make([]rxSegment, 0, initialCap)
	return r
}
// reset empties the buffer for the next recvmmsg batch, retaining the
// backing array so repeated batches do not reallocate.
func (r *rxReorderBuffer) reset() {
	r.buf = r.buf[:0]
}
// addEntry turns one recvmmsg slot into rxSegments. A non-positive segSize,
// or one at least as large as the payload, means the kernel did not
// coalesce this slot: the payload becomes a single segment. Otherwise the
// GRO superpacket is walked in segSize strides — the kernel guarantees
// every piece is exactly segSize bytes except possibly the final one.
// The one ecn value applies to every produced segment, since GRO only
// coalesces datagrams whose ECN codepoints match.
func (r *rxReorderBuffer) addEntry(from netip.AddrPort, payload []byte, segSize int, ecn byte) {
	if segSize <= 0 || segSize >= len(payload) {
		r.buf = append(r.buf, rxSegment{src: from, cnt: headerCounter(payload), buf: payload, ecn: ecn})
		return
	}
	for start := 0; start < len(payload); start += segSize {
		stop := min(start+segSize, len(payload))
		part := payload[start:stop]
		r.buf = append(r.buf, rxSegment{src: from, cnt: headerCounter(part), buf: part, ecn: ecn})
	}
}
// sortStable orders the accumulated segments by (source address, source
// port, message counter). The comparison is total — address first, then
// port, then counter — so cross-source ordering is deterministic while each
// source's segments end up in counter order; stability preserves arrival
// order among true ties.
func (r *rxReorderBuffer) sortStable() {
	slices.SortStableFunc(r.buf, func(a, b rxSegment) int {
		return cmp.Or(
			a.src.Addr().Compare(b.src.Addr()),
			cmp.Compare(a.src.Port(), b.src.Port()),
			cmp.Compare(a.cnt, b.cnt),
		)
	})
}
// deliver calls fn for every segment in the current (sorted) order and then
// clears each segment's buf reference, so appends in the next batch cannot
// alias payloads already handed to the callback.
func (r *rxReorderBuffer) deliver(fn EncReader) {
	for k := range r.buf {
		seg := &r.buf[k]
		fn(seg.src, seg.buf, RxMeta{OuterECN: seg.ecn})
		seg.buf = nil
	}
}

View File

@@ -0,0 +1,203 @@
//go:build !android && !e2e_testing
// +build !android,!e2e_testing
package udp
import (
"encoding/binary"
"net/netip"
"testing"
)
// makeNebulaPkt returns a buffer whose [8:16] bytes encode the given
// counter big-endian, the rest left zero. Anything shorter than 16 bytes
// would yield counter 0; tests use this to simulate well-formed nebula
// headers (the rxReorderBuffer doesn't care about anything else).
func makeNebulaPkt(cnt uint64, payLen int) []byte {
if payLen < 16 {
payLen = 16
}
b := make([]byte, payLen)
binary.BigEndian.PutUint64(b[8:16], cnt)
return b
}
func srcOf(addr string, port uint16) netip.AddrPort {
return netip.AddrPortFrom(netip.MustParseAddr(addr), port)
}
// TestRxReorderBuffer_LonePassesThrough: a slot with segSize 0 (no GRO
// coalescing) must become exactly one segment that preserves the counter,
// the ECN codepoint, and the full payload length.
func TestRxReorderBuffer_LonePassesThrough(t *testing.T) {
	r := newRxReorderBuffer(8)
	pkt := makeNebulaPkt(42, 100)
	r.addEntry(srcOf("1.1.1.1", 4242), pkt, 0, 0x02)
	if got := len(r.buf); got != 1 {
		t.Fatalf("want 1 entry, got %d", got)
	}
	if r.buf[0].cnt != 42 {
		t.Errorf("counter=%d want 42", r.buf[0].cnt)
	}
	if r.buf[0].ecn != 0x02 {
		t.Errorf("ecn=%#x want 0x02", r.buf[0].ecn)
	}
	if len(r.buf[0].buf) != 100 {
		t.Errorf("buf len=%d want 100", len(r.buf[0].buf))
	}
}
// TestRxReorderBuffer_SegSizeGEPayloadIsLone: a segSize equal to or larger
// than the payload means no splitting — both cases must yield one segment.
func TestRxReorderBuffer_SegSizeGEPayloadIsLone(t *testing.T) {
	// segSize >= len(payload) means the kernel did not coalesce this slot.
	r := newRxReorderBuffer(8)
	pkt := makeNebulaPkt(7, 50)
	r.addEntry(srcOf("1.1.1.1", 1), pkt, 50, 0)
	if got := len(r.buf); got != 1 {
		t.Fatalf("segSize==len: want 1 entry, got %d", got)
	}
	r.reset()
	r.addEntry(srcOf("1.1.1.1", 1), pkt, 60, 0)
	if got := len(r.buf); got != 1 {
		t.Fatalf("segSize>len: want 1 entry, got %d", got)
	}
}
// TestRxReorderBuffer_GROSplitExactMultiple: a superpacket whose length is
// an exact multiple of segSize splits into equal segments, each carrying
// its own counter and the slot's single ECN value.
func TestRxReorderBuffer_GROSplitExactMultiple(t *testing.T) {
	// 3 segments of 80 bytes each, packed into one 240-byte GRO superpacket.
	const segSize = 80
	const numSeg = 3
	pkt := make([]byte, segSize*numSeg)
	for i := range numSeg {
		off := i * segSize
		binary.BigEndian.PutUint64(pkt[off+8:off+16], uint64(100+i))
	}
	r := newRxReorderBuffer(8)
	r.addEntry(srcOf("2.2.2.2", 5555), pkt, segSize, 0x03)
	if got := len(r.buf); got != numSeg {
		t.Fatalf("want %d segments, got %d", numSeg, got)
	}
	for i, seg := range r.buf {
		if seg.cnt != uint64(100+i) {
			t.Errorf("seg %d: cnt=%d want %d", i, seg.cnt, 100+i)
		}
		if len(seg.buf) != segSize {
			t.Errorf("seg %d: buf len=%d want %d", i, len(seg.buf), segSize)
		}
		if seg.ecn != 0x03 {
			t.Errorf("seg %d: ecn=%#x want 0x03 (uniform across GRO)", i, seg.ecn)
		}
	}
}
// TestRxReorderBuffer_GROSplitShortFinal: a payload that is not an exact
// multiple of segSize must end with one short final segment.
func TestRxReorderBuffer_GROSplitShortFinal(t *testing.T) {
	// 200-byte payload, segSize=80 → segments of 80, 80, 40.
	const segSize = 80
	pkt := make([]byte, 200)
	binary.BigEndian.PutUint64(pkt[8:16], 1)
	binary.BigEndian.PutUint64(pkt[80+8:80+16], 2)
	binary.BigEndian.PutUint64(pkt[160+8:160+16], 3)
	r := newRxReorderBuffer(8)
	r.addEntry(srcOf("3.3.3.3", 1), pkt, segSize, 0)
	if got := len(r.buf); got != 3 {
		t.Fatalf("want 3 segments, got %d", got)
	}
	wantLens := []int{80, 80, 40}
	for i, seg := range r.buf {
		if len(seg.buf) != wantLens[i] {
			t.Errorf("seg %d: len=%d want %d", i, len(seg.buf), wantLens[i])
		}
	}
}
// TestRxReorderBuffer_SortGroupsBySrcThenCounter: after sortStable,
// segments are grouped by source address and, within a source, ordered by
// message counter.
func TestRxReorderBuffer_SortGroupsBySrcThenCounter(t *testing.T) {
	r := newRxReorderBuffer(8)
	a := srcOf("1.1.1.1", 1)
	b := srcOf("2.2.2.2", 1)
	// Insert deliberately scrambled.
	r.addEntry(a, makeNebulaPkt(3, 16), 0, 0)
	r.addEntry(b, makeNebulaPkt(1, 16), 0, 0)
	r.addEntry(a, makeNebulaPkt(1, 16), 0, 0)
	r.addEntry(b, makeNebulaPkt(2, 16), 0, 0)
	r.addEntry(a, makeNebulaPkt(2, 16), 0, 0)
	r.sortStable()
	want := []struct {
		src netip.AddrPort
		cnt uint64
	}{
		{a, 1}, {a, 2}, {a, 3}, {b, 1}, {b, 2},
	}
	if got := len(r.buf); got != len(want) {
		t.Fatalf("len=%d want %d", got, len(want))
	}
	for i, w := range want {
		if r.buf[i].src != w.src || r.buf[i].cnt != w.cnt {
			t.Errorf("idx %d: got %v/%d want %v/%d",
				i, r.buf[i].src, r.buf[i].cnt, w.src, w.cnt)
		}
	}
}
// TestRxReorderBuffer_SortStableAcrossPorts: two flows sharing an address
// but differing in port are distinct sources and must sort into separate
// groups (port is part of the sort key after address).
func TestRxReorderBuffer_SortStableAcrossPorts(t *testing.T) {
	// Same source addr but different ports — must group by port.
	r := newRxReorderBuffer(8)
	addr := netip.MustParseAddr("4.4.4.4")
	p1 := netip.AddrPortFrom(addr, 1)
	p2 := netip.AddrPortFrom(addr, 2)
	r.addEntry(p2, makeNebulaPkt(10, 16), 0, 0)
	r.addEntry(p1, makeNebulaPkt(20, 16), 0, 0)
	r.addEntry(p2, makeNebulaPkt(5, 16), 0, 0)
	r.sortStable()
	// Expect: p1/20 then p2/5 then p2/10.
	if r.buf[0].src.Port() != 1 || r.buf[1].src.Port() != 2 || r.buf[2].src.Port() != 2 {
		t.Fatalf("port order broken: %v %v %v",
			r.buf[0].src.Port(), r.buf[1].src.Port(), r.buf[2].src.Port())
	}
	if r.buf[1].cnt != 5 || r.buf[2].cnt != 10 {
		t.Errorf("counter order in p2: %d %d (want 5 10)", r.buf[1].cnt, r.buf[2].cnt)
	}
}
// TestRxReorderBuffer_DeliverInOrderAndNilsRefs: deliver walks the sorted
// order, passes each segment's ECN through RxMeta, and nils every buf
// reference afterwards so the next batch cannot alias delivered payloads.
func TestRxReorderBuffer_DeliverInOrderAndNilsRefs(t *testing.T) {
	r := newRxReorderBuffer(4)
	a := srcOf("5.5.5.5", 1)
	r.addEntry(a, makeNebulaPkt(2, 32), 0, 0x01)
	r.addEntry(a, makeNebulaPkt(1, 32), 0, 0x01)
	r.sortStable()
	var seenCnts []uint64
	var seenECN []byte
	r.deliver(func(src netip.AddrPort, buf []byte, meta RxMeta) {
		seenCnts = append(seenCnts, binary.BigEndian.Uint64(buf[8:16]))
		seenECN = append(seenECN, meta.OuterECN)
	})
	if len(seenCnts) != 2 || seenCnts[0] != 1 || seenCnts[1] != 2 {
		t.Errorf("delivery order broken: %v", seenCnts)
	}
	if seenECN[0] != 0x01 || seenECN[1] != 0x01 {
		t.Errorf("ecn passed wrong: %v", seenECN)
	}
	for i := range r.buf {
		if r.buf[i].buf != nil {
			t.Errorf("buf[%d].buf not nil after deliver", i)
		}
	}
}
// TestRxReorderBuffer_ResetIsReusable: reset empties the buffer while
// leaving it fully usable for subsequent addEntry calls.
func TestRxReorderBuffer_ResetIsReusable(t *testing.T) {
	r := newRxReorderBuffer(2)
	r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(1, 16), 0, 0)
	r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(2, 16), 0, 0)
	r.reset()
	if got := len(r.buf); got != 0 {
		t.Fatalf("after reset len=%d want 0", got)
	}
	r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(7, 16), 0, 0)
	if r.buf[0].cnt != 7 {
		t.Errorf("after reset+add: cnt=%d want 7", r.buf[0].cnt)
	}
}

View File

@@ -140,7 +140,7 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error {
} }
} }
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
for i, b := range bufs { for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil { if err := u.WriteTo(b, addrs[i]); err != nil {
return err return err
@@ -188,7 +188,7 @@ func (u *StdConn) ListenOut(r EncReader, flush func()) error {
u.l.Error("unexpected udp socket receive error", "error", err) u.l.Error("unexpected udp socket receive error", "error", err)
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
flush() flush()
} }
} }

View File

@@ -0,0 +1,61 @@
//go:build linux && !android && !e2e_testing
package udp
import (
"net/netip"
"testing"
)
// TestPlanRunBreaksOnECNChange confirms that two same-destination, same-size
// packets with different outer ECN end up in separate sendmmsg entries (the
// kernel stamps one outer codepoint per entry, so a run that straddled the
// boundary would silently lose information).
//
// NOTE(review): planRun's trailing arguments appear to be a start index and
// a max-segment cap — confirm against planRun's definition.
func TestPlanRunBreaksOnECNChange(t *testing.T) {
	u := &StdConn{gsoSupported: true}
	dst := netip.MustParseAddrPort("10.0.0.1:4242")
	bufs := [][]byte{
		make([]byte, 1200),
		make([]byte, 1200),
		make([]byte, 1200),
	}
	addrs := []netip.AddrPort{dst, dst, dst}
	t.Run("uniform_ecn_runs_together", func(t *testing.T) {
		ecns := []byte{0x02, 0x02, 0x02}
		runLen, segSize := u.planRun(bufs, addrs, ecns, 0, 64)
		if runLen != 3 {
			t.Errorf("runLen=%d want 3 (uniform ECT(0))", runLen)
		}
		if segSize != 1200 {
			t.Errorf("segSize=%d want 1200", segSize)
		}
	})
	t.Run("ecn_change_truncates_run", func(t *testing.T) {
		// 0,0,3: first two run together, CE seeds a fresh entry.
		ecns := []byte{0x00, 0x00, 0x03}
		runLen, _ := u.planRun(bufs, addrs, ecns, 0, 64)
		if runLen != 2 {
			t.Errorf("runLen=%d want 2 (ECN changes at index 2)", runLen)
		}
	})
	t.Run("nil_ecns_runs_full", func(t *testing.T) {
		runLen, _ := u.planRun(bufs, addrs, nil, 0, 64)
		if runLen != 3 {
			t.Errorf("runLen=%d want 3 (nil ecns means no break)", runLen)
		}
	})
	t.Run("first_ecn_is_singleton", func(t *testing.T) {
		// Second packet has different ECN from the first → run halts at 1
		// (the first packet alone forms the run).
		ecns := []byte{0x00, 0x03, 0x03}
		runLen, _ := u.planRun(bufs, addrs, ecns, 0, 64)
		if runLen != 1 {
			t.Errorf("runLen=%d want 1 (different ECN immediately)", runLen)
		}
	})
}

View File

@@ -44,7 +44,7 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error {
return err return err
} }
func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
for i, b := range bufs { for i, b := range bufs {
if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil {
return err return err
@@ -102,7 +102,7 @@ func (u *GenericConn) ListenOut(r EncReader, flush func()) error {
continue continue
} }
r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{})
flush() flush()
} }
} }

View File

@@ -24,6 +24,58 @@ type StdConn struct {
isV4 bool isV4 bool
l *slog.Logger l *slog.Logger
batch int batch int
// sendmmsg scratch. Each queue has its own StdConn, so no locking is
// needed. Sized to MaxWriteBatch at construction; WriteBatch chunks
// larger inputs.
writeMsgs []rawMessage
writeIovs []iovec
writeNames [][]byte
// Per-entry cmsg scratch. writeCmsg is one contiguous slab of
// MaxWriteBatch * writeCmsgSpace bytes; each entry holds two cmsg
// headers (UDP_SEGMENT then IP_TOS / IPV6_TCLASS) pre-filled once in
// prepareWriteMessages. WriteBatch only rewrites the per-call data
// payloads and toggles Hdr.Control / Hdr.Controllen to point at
// whichever subset of the two cmsgs applies.
writeCmsg []byte
writeCmsgSpace int
writeCmsgSegSpace int
writeCmsgEcnSpace int
// writeEntryEnd[e] is the bufs index *after* the last packet packed
// into mmsghdr entry e. Used to rewind `i` on partial sendmmsg success.
writeEntryEnd []int
// rawSend wraps the sendmmsg(2) callback in a closure-free helper so
// the hot path doesn't heap-allocate a fresh closure per call.
rawSend rawSendmmsg
// UDP GSO (sendmsg with UDP_SEGMENT cmsg) support. gsoSupported is
// probed once at socket creation. When true, WriteBatch packs same-
// destination consecutive packets into a single sendmmsg entry with a
// UDP_SEGMENT cmsg; otherwise each packet is its own entry.
gsoSupported bool
// UDP GRO (recvmsg with UDP_GRO cmsg) support. groSupported is probed
// once at socket creation. When true, listenOutBatch allocates larger
// RX buffers and a per-entry cmsg slot so the kernel can coalesce
// consecutive same-flow datagrams into a single recvmmsg entry; the
// delivered cmsg carries the gso_size used to split them back apart.
groSupported bool
// ecnRecvSupported is true when IP_RECVTOS / IPV6_RECVTCLASS was
// successfully enabled — the kernel will deliver the outer IP-ECN of
// each arriving datagram as a per-slot cmsg, and listenOutBatch passes
// the parsed value to the EncReader callback for RFC 6040 combine.
ecnRecvSupported bool
// rxOrder is the per-batch scratch listenOutBatch uses to gather every
// segment in a recvmmsg call (after splitting GRO superpackets) and
// stable-sort by (source, message-counter) before delivery. Reordering
// fits within the receiver's replay window so briefly out-of-order
// arrivals do not get rejected as replays.
rxOrder *rxReorderBuffer
} }
func setReusePort(network, address string, c syscall.RawConn) error { func setReusePort(network, address string, c syscall.RawConn) error {
@@ -70,9 +122,196 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int)
} }
out.isV4 = af == unix.AF_INET out.isV4 = af == unix.AF_INET
out.prepareWriteMessages(MaxWriteBatch)
out.rawSend.msgs = out.writeMsgs
out.rawSend.bind()
out.prepareGSO()
// GRO delivers coalesced superpackets that need a cmsg to split back
// into segments. The single-packet RX path uses ReadFromUDPAddrPort
// and cannot see that cmsg, so only enable GRO for the batch path.
if batch > 1 {
out.prepareGRO()
}
// Best-effort: ask the kernel to deliver outer IP-ECN as ancillary data
// on every recvmmsg slot so the decap side can apply RFC 6040 combine.
// On older kernels these may not exist; failing here just means we get
// 0 (Not-ECT) on every slot, which is the same as ecn_mode=disable.
out.prepareECNRecv()
return out, nil return out, nil
} }
// prepareWriteMessages allocates one mmsghdr/iovec/sockaddr/cmsg scratch
// slot per sendmmsg entry. The iovec slab is sized to n so all entries'
// iovecs share one allocation; per-entry fan-out is further capped at
// maxGSOSegments. Hdr.Iov / Hdr.Iovlen / Hdr.Control / Hdr.Controllen are
// wired per call since each entry can span a variable number of iovecs
// and may or may not carry a cmsg.
//
// Per-mmsghdr cmsg layout. Each entry's slot of length writeCmsgSpace holds
// up to two cmsg headers placed at fixed offsets:
//
// [0 .. writeCmsgSegSpace) UDP_SEGMENT (gso_size, uint16)
// [writeCmsgSegSpace .. writeCmsgSpace) IP_TOS or IPV6_TCLASS (int32)
//
// Both headers are pre-filled once here; per-call we only rewrite the data
// payload and toggle Hdr.Control / Hdr.Controllen to point at whichever
// subset applies (none / segment-only / ecn-only / both).
func (u *StdConn) prepareWriteMessages(n int) {
	u.writeMsgs = make([]rawMessage, n)
	u.writeIovs = make([]iovec, n)
	u.writeNames = make([][]byte, n)
	u.writeEntryEnd = make([]int, n)
	// UDP_SEGMENT carries a uint16 gso_size; IP_TOS / IPV6_TCLASS carries a
	// 4-byte int. Each entry's slot is the two aligned cmsg spaces laid out
	// back to back inside the single writeCmsg slab.
	u.writeCmsgSegSpace = unix.CmsgSpace(2)
	u.writeCmsgEcnSpace = unix.CmsgSpace(4)
	u.writeCmsgSpace = u.writeCmsgSegSpace + u.writeCmsgEcnSpace
	u.writeCmsg = make([]byte, n*u.writeCmsgSpace)
	// v4 sockets stamp the outer codepoint with IP_TOS; v6 with IPV6_TCLASS.
	ecnLevel := int32(unix.IPPROTO_IP)
	ecnType := int32(unix.IP_TOS)
	if !u.isV4 {
		ecnLevel = unix.IPPROTO_IPV6
		ecnType = unix.IPV6_TCLASS
	}
	for k := 0; k < n; k++ {
		base := k * u.writeCmsgSpace
		// Pre-fill both cmsg headers once; per-call code only rewrites the
		// data payloads and points Hdr.Control/Controllen at the needed subset.
		seg := (*unix.Cmsghdr)(unsafe.Pointer(&u.writeCmsg[base]))
		seg.Level = unix.SOL_UDP
		seg.Type = unix.UDP_SEGMENT
		setCmsgLen(seg, unix.CmsgLen(2))
		ecn := (*unix.Cmsghdr)(unsafe.Pointer(&u.writeCmsg[base+u.writeCmsgSegSpace]))
		ecn.Level = ecnLevel
		ecn.Type = ecnType
		setCmsgLen(ecn, unix.CmsgLen(4))
	}
	for i := range u.writeMsgs {
		// Sockaddr storage is sized for the larger (IPv6) family so either fits.
		u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6)
		u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0]
	}
}
// maxGSOSegments caps the per-sendmsg GSO fan-out. Linux kernels have
// historically capped UDP_MAX_SEGMENTS at 64; newer kernels raise it to 128.
// We stay one below 64 because the kernel's check is
//
// if (cork->length > cork->gso_size * UDP_MAX_SEGMENTS) return -EINVAL;
//
// and cork->length includes the 8-byte UDP header (udp_sendmsg passes
// ulen = len + sizeof(udphdr) to ip_append_data). Packing exactly 64
// same-size segments puts cork->length at gso_size*64 + 8, which is one
// UDP-header over the bound and the kernel rejects the whole sendmmsg
// with EINVAL. 63 leaves room for the header for any segSize >= 8.
const maxGSOSegments = 63
// maxGSOBytes bounds the total payload per sendmsg() when UDP_SEGMENT is
// set. The kernel stitches all iovecs into a single skb whose length the
// UDP length field can represent, and also enforces sk_gso_max_size (which
// on most devices is 65536). We use 65000 to leave headroom under the
// 65535 UDP-length cap, avoiding EMSGSIZE on large TSO superpackets.
const maxGSOBytes = 65000
// prepareGSO probes whether the kernel accepts the UDP_SEGMENT socket
// option and records the outcome in u.gsoSupported plus a capability
// gauge. Best-effort: any failure simply leaves GSO off.
func (u *StdConn) prepareGSO() {
	var probeErr error
	ctlErr := u.rawConn.Control(func(fd uintptr) {
		probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT, 0)
	})
	switch {
	case ctlErr != nil:
		u.l.Info("udp: GSO disabled", "reason", "rawconn control failed", "error", ctlErr)
		recordCapability("udp.gso.enabled", false)
	case probeErr != nil:
		u.l.Info("udp: GSO disabled", "reason", "kernel rejected probe", "error", probeErr)
		recordCapability("udp.gso.enabled", false)
	default:
		u.gsoSupported = true
		u.l.Info("udp: GSO enabled")
		recordCapability("udp.gso.enabled", true)
	}
}
// udpGROBufferSize sizes the per-entry recvmmsg buffer when UDP_GRO is on.
// The kernel stitches a run of same-flow datagrams into a single skb whose
// length is bounded by sk_gso_max_size (typically 65535); anything larger
// would be MSG_TRUNCed. We use the maximum representable UDP length so a
// full superpacket always lands intact.
const udpGROBufferSize = 65535
// udpGROCmsgPayload is the size of the UDP_GRO cmsg data delivered by the
// kernel: a single int (gso_size in bytes). See udp_cmsg_recv() in
// net/ipv4/udp.c.
const udpGROCmsgPayload = 4
// prepareGRO enables UDP_GRO so the kernel may coalesce consecutive
// same-flow datagrams into one recvmmsg entry, delivering the gso_size via
// cmsg for userspace to split them back apart. Records the outcome in
// u.groSupported plus a capability gauge; failure leaves GRO off.
func (u *StdConn) prepareGRO() {
	var probeErr error
	ctlErr := u.rawConn.Control(func(fd uintptr) {
		probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
	})
	switch {
	case ctlErr != nil:
		u.l.Info("udp: GRO disabled", "reason", "rawconn control failed", "error", ctlErr)
		recordCapability("udp.gro.enabled", false)
	case probeErr != nil:
		u.l.Info("udp: GRO disabled", "reason", "kernel rejected probe", "error", probeErr)
		recordCapability("udp.gro.enabled", false)
	default:
		u.groSupported = true
		u.l.Info("udp: GRO enabled")
		recordCapability("udp.gro.enabled", true)
	}
}
// prepareECNRecv enables IP_RECVTOS (v4) or IPV6_RECVTCLASS (v6) so the
// outer IP-ECN of each arriving datagram is delivered as ancillary data;
// listenOutBatch extracts it via parseRecvCmsg and forwards it through the
// EncReader for RFC 6040 combine. Best-effort: on failure we keep running
// with 0 (Not-ECT) on every slot.
func (u *StdConn) prepareECNRecv() {
	var probeErr error
	ctlErr := u.rawConn.Control(func(fd uintptr) {
		if u.isV4 {
			probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
		} else {
			probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
		}
	})
	switch {
	case ctlErr != nil:
		u.l.Info("udp: outer-ECN RX disabled", "reason", "rawconn control failed", "error", ctlErr)
		recordCapability("udp.ecn_rx.enabled", false)
	case probeErr != nil:
		u.l.Info("udp: outer-ECN RX disabled", "reason", "kernel rejected probe", "error", probeErr)
		recordCapability("udp.ecn_rx.enabled", false)
	default:
		u.ecnRecvSupported = true
		u.l.Info("udp: outer-ECN RX enabled")
		recordCapability("udp.ecn_rx.enabled", true)
	}
}
// recordCapability publishes one kernel-feature probe result as a boolean
// gauge: 1 when the feature is enabled, 0 when it is not, so dashboards can
// spot partially supported kernels at a glance. Calling it again with the
// same name updates the existing gauge rather than registering a duplicate.
func recordCapability(name string, enabled bool) {
	var v int64
	if enabled {
		v = 1
	}
	metrics.GetOrRegisterGauge(name, nil).Update(v)
}
func (u *StdConn) SupportsMultipleReaders() bool { func (u *StdConn) SupportsMultipleReaders() bool {
return true return true
} }
@@ -183,7 +422,10 @@ func (u *StdConn) listenOutSingle(r EncReader, flush func()) error {
return err return err
} }
from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port())
r(from, buffer[:n]) // listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs,
// so the outer ECN field is not visible on this path. Zero RxMeta
// (Not-ECT) means RFC 6040 combine is a no-op.
r(from, buffer[:n], RxMeta{})
flush() flush()
} }
} }
@@ -194,7 +436,22 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
var operr error var operr error
bufSize := MTU bufSize := MTU
msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize) cmsgSpace := 0
if u.groSupported {
bufSize = udpGROBufferSize
cmsgSpace = unix.CmsgSpace(udpGROCmsgPayload)
}
if u.ecnRecvSupported {
// IP_TOS arrives as 1 byte; IPV6_TCLASS arrives as a 4-byte int.
// Reserve enough for the wider of the two so the same buffer fits
// either family alongside any UDP_GRO cmsg.
cmsgSpace += unix.CmsgSpace(4)
}
msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace)
if u.rxOrder == nil {
u.rxOrder = newRxReorderBuffer(u.batch * 64)
}
//reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read
//defining it outside the loop so it gets re-used //defining it outside the loop so it gets re-used
@@ -204,6 +461,11 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
} }
for { for {
if cmsgSpace > 0 {
for i := range msgs {
setMsgControllen(&msgs[i].Hdr, cmsgSpace)
}
}
err := u.rawConn.Read(reader) err := u.rawConn.Read(reader)
if err != nil { if err != nil {
return err return err
@@ -212,6 +474,9 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
return operr return operr
} }
// Phase 1: gather every segment from this recvmmsg into rxOrder,
// splitting GRO superpackets into their constituent segments.
u.rxOrder.reset()
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
// Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic
if u.isV4 { if u.isV4 {
@@ -222,14 +487,77 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error {
from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4]))
payload := buffers[i][:msgs[i].Len] payload := buffers[i][:msgs[i].Len]
r(from, payload) segSize := 0
outerECN := byte(0)
if cmsgSpace > 0 {
segSize, outerECN = parseRecvCmsg(&msgs[i].Hdr, u.groSupported, u.ecnRecvSupported, u.isV4)
} }
u.rxOrder.addEntry(from, payload, segSize, outerECN)
}
// Phase 2 + 3: stable-sort by (src, port, counter), then deliver in
// order. Reorder distance is bounded by len(u.rxOrder.buf), which
// stays well within the receiver's ReplayWindow (currently 8192) so
// older arrivals are not rejected as replays.
u.rxOrder.sortStable()
u.rxOrder.deliver(r)
// End-of-batch: let callers (e.g. TUN write coalescer) flush any // End-of-batch: let callers (e.g. TUN write coalescer) flush any
// state they accumulated across this batch. // state they accumulated across this batch.
flush() flush()
} }
} }
// headerCounter returns the big-endian uint64 message counter at bytes
// [8:16] of a nebula packet, or 0 if the buffer is too short.
func headerCounter(buf []byte) uint64 {
if len(buf) < 16 {
return 0
}
return binary.BigEndian.Uint64(buf[8:16])
}
// parseRecvCmsg walks the per-slot ancillary buffer once and extracts up to
// two values of interest in a single pass: the UDP_GRO gso_size (when
// wantGRO is true) and the outer IP-level ECN codepoint stamped on the
// carrier (when wantECN is true). Returns zeros for whichever field is not
// requested or not present. isV4 selects between IP_TOS (1-byte) and
// IPV6_TCLASS (4-byte int) cmsg payloads.
func parseRecvCmsg(hdr *msghdr, wantGRO, wantECN bool, isV4 bool) (gso int, ecn byte) {
	controllen := int(hdr.Controllen)
	// No control data at all (or too little for even one header): nothing
	// to parse.
	if controllen < unix.SizeofCmsghdr || hdr.Control == nil {
		return 0, 0
	}
	ctrl := unsafe.Slice(hdr.Control, controllen)
	off := 0
	for off+unix.SizeofCmsghdr <= len(ctrl) {
		ch := (*unix.Cmsghdr)(unsafe.Pointer(&ctrl[off]))
		clen := int(ch.Len)
		// A malformed or truncated cmsg terminates the walk; whatever was
		// parsed so far is still returned.
		if clen < unix.SizeofCmsghdr || off+clen > len(ctrl) {
			return gso, ecn
		}
		dataOff := off + unix.CmsgLen(0)
		switch {
		case wantGRO && ch.Level == unix.SOL_UDP && ch.Type == unix.UDP_GRO:
			if dataOff+udpGROCmsgPayload <= len(ctrl) {
				gso = int(int32(binary.NativeEndian.Uint32(ctrl[dataOff : dataOff+udpGROCmsgPayload])))
			}
		case wantECN && isV4 && ch.Level == unix.IPPROTO_IP && ch.Type == unix.IP_TOS:
			// IP_TOS arrives as a single byte; only the low 2 bits are ECN.
			if dataOff+1 <= len(ctrl) {
				ecn = ctrl[dataOff] & 0x03
			}
		case wantECN && !isV4 && ch.Level == unix.IPPROTO_IPV6 && ch.Type == unix.IPV6_TCLASS:
			// IPV6_TCLASS arrives as a 4-byte int; ECN is the low 2 bits.
			if dataOff+4 <= len(ctrl) {
				ecn = byte(binary.NativeEndian.Uint32(ctrl[dataOff:dataOff+4])) & 0x03
			}
		}
		// Advance by the aligned cmsg space.
		off += unix.CmsgSpace(clen - unix.CmsgLen(0))
	}
	return gso, ecn
}
func (u *StdConn) ListenOut(r EncReader, flush func()) error { func (u *StdConn) ListenOut(r EncReader, flush func()) error {
if u.batch == 1 { if u.batch == 1 {
return u.listenOutSingle(r, flush) return u.listenOutSingle(r, flush)
@@ -243,19 +571,255 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error {
return err return err
} }
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { // WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on
// StdConn. Consecutive packets to the same destination with matching segment
// sizes (all but possibly the last) are coalesced into a single mmsghdr entry
// carrying a UDP_SEGMENT cmsg, so one syscall can mix runs of GSO superpackets
// with plain one-off datagrams. Without GSO support every packet is its own
// entry, matching the prior behaviour.
//
// Chunks larger than the scratch are processed across multiple syscalls. If
// sendmmsg returns an error AND zero entries went out we fall back to
// per-packet WriteTo for that chunk so the caller still gets best-effort
// delivery; on a partial-success error we just replay the remainder.
func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error {
if len(bufs) != len(addrs) { if len(bufs) != len(addrs) {
return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs)) return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs))
} }
//todo use sendmmsg if ecns != nil && len(ecns) != len(bufs) {
for i := 0; i < len(bufs); i++ { return fmt.Errorf("WriteBatch: len(ecns)=%d != len(bufs)=%d", len(ecns), len(bufs))
if _, err := u.udpConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil { }
// Callers deliver same-destination packets contiguously and in counter
// order, so we run the GSO planner directly without a pre-sort. A
// sorting pass measurably hurt throughput in microbenchmarks while
// providing no observed reordering benefit.
i := 0
for i < len(bufs) {
baseI := i
entry := 0
iovIdx := 0
for entry < len(u.writeMsgs) && i < len(bufs) {
iovBudget := len(u.writeIovs) - iovIdx
if iovBudget < 1 {
break
}
runLen, segSize := u.planRun(bufs, addrs, ecns, i, iovBudget)
if runLen == 0 {
break
}
for k := 0; k < runLen; k++ {
b := bufs[i+k]
if len(b) == 0 {
u.writeIovs[iovIdx+k].Base = nil
setIovLen(&u.writeIovs[iovIdx+k], 0)
} else {
u.writeIovs[iovIdx+k].Base = &b[0]
setIovLen(&u.writeIovs[iovIdx+k], len(b))
}
}
nlen, err := writeSockaddr(u.writeNames[entry], addrs[i], u.isV4)
if err != nil {
return err return err
} }
hdr := &u.writeMsgs[entry].Hdr
hdr.Iov = &u.writeIovs[iovIdx]
setMsgIovlen(hdr, runLen)
hdr.Namelen = uint32(nlen)
var ecn byte
if ecns != nil {
ecn = ecns[i]
}
u.writeEntryCmsg(entry, runLen, segSize, ecn)
i += runLen
iovIdx += runLen
u.writeEntryEnd[entry] = i
entry++
}
if entry == 0 {
return fmt.Errorf("sendmmsg: no progress")
}
sent, serr := u.sendmmsg(entry)
if serr != nil && sent <= 0 {
// Nothing went out for this chunk; fall back to WriteTo for each
// packet that was queued this iteration. We only enter this path
// when sendmmsg returned an error AND zero entries succeeded —
// otherwise the partial-success advance below replays only the
// remainder, avoiding duplicates of already-sent packets.
//
// sent=-1 from sendmmsg means message 0 itself failed (partial
// success returns the count instead), so log entry 0's parameters
// — that's the entry the kernel rejected.
hdr0 := &u.writeMsgs[0].Hdr
runLen0 := u.writeEntryEnd[0] - baseI
seg0 := len(bufs[baseI])
ecn0 := byte(0)
if ecns != nil {
ecn0 = ecns[baseI]
}
u.l.Warn("sendmmsg had problem",
"sent", sent, "err", serr,
"entries", entry,
"entry0_runLen", runLen0,
"entry0_segSize", seg0,
"entry0_iovlen", hdr0.Iovlen,
"entry0_controllen", hdr0.Controllen,
"entry0_namelen", hdr0.Namelen,
"entry0_ecn", ecn0,
"entry0_dst", addrs[baseI],
"isV4", u.isV4,
"gso", u.gsoSupported,
"gro", u.groSupported,
)
for k := baseI; k < i; k++ {
if werr := u.WriteTo(bufs[k], addrs[k]); werr != nil {
return werr
}
}
continue
}
if sent == 0 {
return fmt.Errorf("sendmmsg made no progress")
}
// Rewind i to the end of the last successfully sent entry. For a
// full-success send this leaves i unchanged; for a partial send it
// replays the remainder on the next outer-loop iteration.
i = u.writeEntryEnd[sent-1]
} }
return nil return nil
} }
// planRun groups consecutive packets starting at `start` that can be sent as
// a single UDP GSO superpacket (one sendmmsg entry with UDP_SEGMENT cmsg).
// A run of length 1 means the entry carries no UDP_SEGMENT cmsg and the
// kernel treats it as a plain datagram. Returns the run length and the
// per-segment size (which equals len(bufs[start])). Without GSO support
// every call returns runLen=1. Outer ECN (when ecns != nil) is also a run
// boundary — the kernel stamps one outer codepoint per sendmsg entry, so
// mixing values inside a run would lose information.
func (u *StdConn) planRun(bufs [][]byte, addrs []netip.AddrPort, ecns []byte, start, iovBudget int) (int, int) {
	if start >= len(bufs) || iovBudget < 1 {
		return 0, 0
	}
	segSize := len(bufs[start])
	// Empty and oversized packets never coalesce, and without kernel GSO
	// support every packet is its own entry.
	if !u.gsoSupported || segSize == 0 || segSize > maxGSOBytes {
		return 1, segSize
	}
	limit := maxGSOSegments
	if limit > iovBudget {
		limit = iovBudget
	}
	dst := addrs[start]
	var runEcn byte
	if ecns != nil {
		runEcn = ecns[start]
	}
	runLen, total := 1, segSize
	for runLen < limit && start+runLen < len(bufs) {
		next := len(bufs[start+runLen])
		// Run boundaries: empty/growing segment, different destination,
		// different outer ECN, or the superpacket byte cap.
		if next == 0 || next > segSize ||
			addrs[start+runLen] != dst ||
			(ecns != nil && ecns[start+runLen] != runEcn) ||
			total+next > maxGSOBytes {
			return runLen, segSize
		}
		total += next
		runLen++
		if next < segSize {
			// A short packet must be the last segment of the run.
			break
		}
	}
	return runLen, segSize
}
// writeEntryCmsg sets up the per-mmsghdr Hdr.Control / Hdr.Controllen for one
// entry. It writes the UDP_SEGMENT payload when runLen >= 2 and the
// IP_TOS/IPV6_TCLASS payload when ecn != 0, then points hdr.Control at the
// smallest contiguous span that covers whichever cmsg(s) actually apply.
func (u *StdConn) writeEntryCmsg(entry, runLen, segSize int, ecn byte) {
	hdr := &u.writeMsgs[entry].Hdr
	useSeg := runLen >= 2
	useEcn := ecn != 0
	// Each entry owns a fixed writeCmsgSpace-sized slab of u.writeCmsg:
	// the UDP_SEGMENT cmsg first (writeCmsgSegSpace bytes), then the ECN
	// cmsg. Only the payloads are (re)written here — the cmsg headers
	// are presumably prefilled where u.writeCmsg is built; TODO confirm.
	base := entry * u.writeCmsgSpace
	if useSeg {
		// UDP_SEGMENT payload: per-segment size as a native-endian uint16.
		dataOff := base + unix.CmsgLen(0)
		binary.NativeEndian.PutUint16(u.writeCmsg[dataOff:dataOff+2], uint16(segSize))
	}
	if useEcn {
		// ECN payload lives right after the (aligned) segment cmsg slot.
		dataOff := base + u.writeCmsgSegSpace + unix.CmsgLen(0)
		binary.NativeEndian.PutUint32(u.writeCmsg[dataOff:dataOff+4], uint32(ecn))
	}
	// Point Control at the tightest span covering the active cmsg(s); the
	// two cmsgs are adjacent in the slab, so seg+ecn is one contiguous run.
	switch {
	case useSeg && useEcn:
		hdr.Control = &u.writeCmsg[base]
		setMsgControllen(hdr, u.writeCmsgSpace)
	case useSeg:
		hdr.Control = &u.writeCmsg[base]
		setMsgControllen(hdr, u.writeCmsgSegSpace)
	case useEcn:
		hdr.Control = &u.writeCmsg[base+u.writeCmsgSegSpace]
		setMsgControllen(hdr, u.writeCmsgEcnSpace)
	default:
		// Plain single datagram with default ECN: no ancillary data at all.
		hdr.Control = nil
		setMsgControllen(hdr, 0)
	}
}
// sendmmsg issues sendmmsg(2) over u.rawConn against the first n entries
// of u.writeMsgs. Routes through u.rawSend so the per-call kernel callback
// stays alloc-free. Returns the count of entries the kernel accepted (the
// caller uses it to replay any unsent tail) and any error from the raw
// send path.
func (u *StdConn) sendmmsg(n int) (int, error) {
	return u.rawSend.send(u.rawConn, n)
}
// writeSockaddr encodes addr into buf (which must be at least
// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is
// true and addr is not a v4 (or v4-in-v6) address, returns an error.
func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) {
ap := addr.Addr().Unmap()
if isV4 {
if !ap.Is4() {
return 0, ErrInvalidIPv6RemoteForSocket
}
// struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) }
// sa_family is host endian.
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET)
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
ip4 := ap.As4()
copy(buf[4:8], ip4[:])
for j := 8; j < 16; j++ {
buf[j] = 0
}
return unix.SizeofSockaddrInet4, nil
}
// struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) }
binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6)
binary.BigEndian.PutUint16(buf[2:4], addr.Port())
binary.NativeEndian.PutUint32(buf[4:8], 0)
ip6 := addr.Addr().As16()
copy(buf[8:24], ip6[:])
binary.NativeEndian.PutUint32(buf[24:28], 0)
return unix.SizeofSockaddrInet6, nil
}
func (u *StdConn) ReloadConfig(c *config.C) { func (u *StdConn) ReloadConfig(c *config.C) {
b := c.GetInt("listen.read_buffer", 0) b := c.GetInt("listen.read_buffer", 0)
if b > 0 { if b > 0 {

View File

@@ -30,11 +30,16 @@ type rawMessage struct {
Len uint32 Len uint32
} }
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
var cmsgs []byte
if cmsgSpace > 0 {
cmsgs = make([]byte, n*cmsgSpace)
}
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, bufSize) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
@@ -48,9 +53,14 @@ func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, []
msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if cmsgSpace > 0 {
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
msgs[i].Hdr.Controllen = uint32(cmsgSpace)
}
} }
return msgs, buffers, names return msgs, buffers, names, cmsgs
} }
func setIovLen(v *iovec, n int) { func setIovLen(v *iovec, n int) {

View File

@@ -33,11 +33,16 @@ type rawMessage struct {
Pad0 [4]byte Pad0 [4]byte
} }
func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) {
msgs := make([]rawMessage, n) msgs := make([]rawMessage, n)
buffers := make([][]byte, n) buffers := make([][]byte, n)
names := make([][]byte, n) names := make([][]byte, n)
var cmsgs []byte
if cmsgSpace > 0 {
cmsgs = make([]byte, n*cmsgSpace)
}
for i := range msgs { for i := range msgs {
buffers[i] = make([]byte, bufSize) buffers[i] = make([]byte, bufSize)
names[i] = make([]byte, unix.SizeofSockaddrInet6) names[i] = make([]byte, unix.SizeofSockaddrInet6)
@@ -51,9 +56,14 @@ func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, []
msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Name = &names[i][0]
msgs[i].Hdr.Namelen = uint32(len(names[i])) msgs[i].Hdr.Namelen = uint32(len(names[i]))
if cmsgSpace > 0 {
msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace]
msgs[i].Hdr.Controllen = uint64(cmsgSpace)
}
} }
return msgs, buffers, names return msgs, buffers, names, cmsgs
} }
func setIovLen(v *iovec, n int) { func setIovLen(v *iovec, n int) {

View File

@@ -161,7 +161,7 @@ func (u *RIOConn) ListenOut(r EncReader, flush func()) error {
continue continue
} }
r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{})
flush() flush()
} }
} }
@@ -317,7 +317,7 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error {
return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
} }
func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
for i, b := range bufs { for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil { if err := u.WriteTo(b, addrs[i]); err != nil {
return err return err

View File

@@ -157,7 +157,7 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error {
return nil return nil
} }
} }
func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error {
for i, b := range bufs { for i, b := range bufs {
if err := u.WriteTo(b, addrs[i]); err != nil { if err := u.WriteTo(b, addrs[i]); err != nil {
return err return err
@@ -172,7 +172,7 @@ func (u *TesterConn) ListenOut(r EncReader, flush func()) error {
case <-u.done: case <-u.done:
return os.ErrClosed return os.ErrClosed
case p := <-u.RxPackets: case p := <-u.RxPackets:
r(p.From, p.Data) r(p.From, p.Data, RxMeta{})
p.Release() p.Release()
flush() flush()
} }