From 6a46a2913a18b30349f734f63e9c36201c27e515 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Fri, 17 Apr 2026 10:25:05 -0500 Subject: [PATCH] GSO/GRO offloads, with TCP+ECN and UDP support --- connection_state.go | 2 +- cpupin_linux.go | 23 + cpupin_other.go | 11 + ecn_inner_test.go | 125 ++ firewall/cache.go | 6 +- firewall/cache_test.go | 15 +- go.mod | 1 + go.sum | 2 + inside.go | 159 ++- interface.go | 93 +- main.go | 49 + outside.go | 65 +- overlay/batch/batch.go | 27 +- overlay/batch/coalesce_core.go | 163 +++ overlay/batch/multi_coalesce.go | 133 +++ overlay/batch/multi_coalesce_test.go | 94 ++ overlay/batch/tcp_coalesce.go | 722 ++++++++++++ overlay/batch/tcp_coalesce_bench_test.go | 173 +++ overlay/batch/tcp_coalesce_test.go | 1022 +++++++++++++++++ overlay/batch/tx_batch.go | 81 +- overlay/batch/tx_batch_test.go | 121 +- overlay/batch/udp_coalesce.go | 342 ++++++ overlay/batch/udp_coalesce_test.go | 383 ++++++ overlay/device.go | 2 +- overlay/overlaytest/noop.go | 2 +- overlay/tio/queueset_gso_linux.go | 79 ++ ...r_poll_linux.go => queueset_poll_linux.go} | 14 +- overlay/tio/segment_bench_test.go | 65 ++ overlay/tio/segment_other.go | 18 + overlay/tio/tio.go | 172 ++- overlay/tio/tio_gso_linux.go | 461 ++++++++ overlay/tio/tio_poll_linux.go | 6 +- overlay/tio/tun_file_linux_test.go | 4 +- overlay/tio/tun_linux_offload.go | 51 + overlay/tio/tun_linux_offload_test.go | 794 +++++++++++++ overlay/tio/virtio/header_linux.go | 43 + overlay/tio/virtio/segment_linux.go | 401 +++++++ overlay/tun_android.go | 6 +- overlay/tun_darwin.go | 6 +- overlay/tun_disabled.go | 38 +- overlay/tun_freebsd.go | 6 +- overlay/tun_ios.go | 6 +- overlay/tun_linux.go | 108 +- overlay/tun_netbsd.go | 6 +- overlay/tun_openbsd.go | 6 +- overlay/tun_tester.go | 14 +- overlay/tun_windows.go | 6 +- overlay/user.go | 6 +- udp/conn.go | 30 +- udp/raw_sendmmsg_linux.go | 62 + udp/rx_reorder_linux.go | 86 ++ udp/rx_reorder_linux_test.go | 203 ++++ udp/udp_darwin.go | 4 +- udp/udp_ecn_outer_linux_test.go | 61 + udp/udp_generic.go | 4 +- udp/udp_linux.go | 580 +++++++++- udp/udp_linux_32.go | 14 +- udp/udp_linux_64.go | 14 +- udp/udp_rio_windows.go | 4 +- udp/udp_tester.go | 4 +- 60 files changed, 6915 insertions(+), 283 deletions(-) create mode 100644 cpupin_linux.go create mode 100644 cpupin_other.go create mode 100644 ecn_inner_test.go create mode 100644 overlay/batch/coalesce_core.go create mode 100644 overlay/batch/multi_coalesce.go create mode 100644 overlay/batch/multi_coalesce_test.go create mode 100644 overlay/batch/tcp_coalesce.go create mode 100644 overlay/batch/tcp_coalesce_bench_test.go create mode 100644 overlay/batch/tcp_coalesce_test.go create mode 100644 overlay/batch/udp_coalesce.go create mode 100644 overlay/batch/udp_coalesce_test.go create mode 100644 overlay/tio/queueset_gso_linux.go rename overlay/tio/{container_poll_linux.go => queueset_poll_linux.go} (77%) create mode 100644 overlay/tio/segment_bench_test.go create mode 100644 overlay/tio/segment_other.go create mode 100644 overlay/tio/tio_gso_linux.go create mode 100644 overlay/tio/tun_linux_offload.go create mode 100644 overlay/tio/tun_linux_offload_test.go create mode 100644 overlay/tio/virtio/header_linux.go create mode 100644 overlay/tio/virtio/segment_linux.go create mode 100644 udp/raw_sendmmsg_linux.go create mode 100644 udp/rx_reorder_linux.go create mode 100644 udp/rx_reorder_linux_test.go create mode 100644 udp/udp_ecn_outer_linux_test.go diff --git a/connection_state.go b/connection_state.go index 0ae2d9be..a906b7a1 100644 --- 
a/connection_state.go +++ b/connection_state.go @@ -10,7 +10,7 @@ import ( "github.com/slackhq/nebula/noiseutil" ) -const ReplayWindow = 1024 +const ReplayWindow = 8192 type ConnectionState struct { eKey noiseutil.CipherState diff --git a/cpupin_linux.go b/cpupin_linux.go new file mode 100644 index 00000000..3080df6d --- /dev/null +++ b/cpupin_linux.go @@ -0,0 +1,23 @@ +//go:build linux && !android && !e2e_testing + +package nebula + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +// pinThreadToCPU restricts the calling OS thread to the given CPU via +// sched_setaffinity(2). Combined with runtime.LockOSThread on the +// goroutine, this prevents the kernel from migrating us across CPUs and +// in turn keeps every sendmmsg from this goroutine going through the +// same XPS-selected TX ring, eliminating the wire-side reorder that +// otherwise fragments one nebula flow across multiple rings. +func pinThreadToCPU(cpu int) error { + runtime.LockOSThread() + var set unix.CPUSet + set.Zero() + set.Set(cpu) + return unix.SchedSetaffinity(0, &set) +} diff --git a/cpupin_other.go b/cpupin_other.go new file mode 100644 index 00000000..4a472eae --- /dev/null +++ b/cpupin_other.go @@ -0,0 +1,11 @@ +//go:build !linux || android || e2e_testing + +package nebula + +// pinThreadToCPU is a no-op outside Linux: only Linux exposes a stable +// per-thread CPU affinity API and only Linux has XPS-driven TX ring +// selection in the first place. On every other platform there's nothing +// to fix here. +func pinThreadToCPU(_ int) error { + return nil +} diff --git a/ecn_inner_test.go b/ecn_inner_test.go new file mode 100644 index 00000000..7bba92ac --- /dev/null +++ b/ecn_inner_test.go @@ -0,0 +1,125 @@ +package nebula + +import ( + "io" + "log/slog" + "testing" +) + +func TestInnerECN(t *testing.T) { + cases := []struct { + name string + pkt []byte + want byte + }{ + {"empty", nil, 0}, + {"v4_NotECT", v4WithToS(0x00), 0x00}, + {"v4_ECT0", v4WithToS(0x02), 0x02}, + {"v4_ECT1", v4WithToS(0x01), 0x01}, + {"v4_CE", v4WithToS(0x03), 0x03}, + {"v4_DSCP_then_NotECT", v4WithToS(0x88 | 0x00), 0x00}, + {"v4_DSCP_then_CE", v4WithToS(0x88 | 0x03), 0x03}, + {"v6_NotECT", v6WithTC(0x00), 0x00}, + {"v6_ECT0", v6WithTC(0x02), 0x02}, + {"v6_CE", v6WithTC(0x03), 0x03}, + {"v6_DSCP_then_CE", v6WithTC(0x88 | 0x03), 0x03}, + {"unknown_version", []byte{0xa5, 0xff}, 0}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := innerECN(c.pkt) + if got != c.want { + t.Errorf("innerECN=0x%02x want 0x%02x", got, c.want) + } + }) + } +} + +// v4WithToS returns a 2-byte slice tall enough for innerECN: byte 0 carries +// version=4 in the high nibble, byte 1 is the full ToS so we exercise both +// the DSCP and ECN portions through the byte 1 mask. +func v4WithToS(tos byte) []byte { + return []byte{0x45, tos} +} + +// v6WithTC builds a 2-byte slice that places a known traffic class value +// across bytes 0 (high nibble of TC) and 1 (low nibble of TC). innerECN +// extracts ECN as (b[1]>>4)&0x03, which corresponds to TC[1:0]. +func v6WithTC(tc byte) []byte { + return []byte{0x60 | (tc>>4)&0x0f, (tc & 0x0f) << 4} +} + +func TestApplyOuterECN(t *testing.T) { + silent := slog.New(slog.NewTextHandler(io.Discard, nil)) + hi := &HostInfo{} + + // Build a v4 packet helper with a given inner ECN field. + v4 := func(innerECN byte) []byte { + // 20-byte minimal IPv4 header with ToS = innerECN (DSCP zeroed). 
+ return []byte{ + 0x45, innerECN, 0, 28, + 0, 0, 0x40, 0, + 64, 6, 0, 0, + 10, 0, 0, 1, + 10, 0, 0, 2, + } + } + // Build a v6 packet helper with a given inner ECN field. ECN occupies + // TC[1:0] which sit at byte 1 mask 0x30. + v6 := func(innerECN byte) []byte { + // 40-byte minimal IPv6 header with TC[1:0] = innerECN. + pkt := make([]byte, 40) + pkt[0] = 0x60 // version=6, TC[7:4]=0 + pkt[1] = (innerECN & 0x03) << 4 // TC[3:0]: low 2 bits = ECN, top 2 = DSCP-low (0) + return pkt + } + + type cell struct { + outer byte + inner byte + wantECN byte + wantSame bool // expect inner unchanged (true => verify the byte didn't move) + } + + // RFC 6040 normal-mode combine table. Only outer==CE causes mutation. + table := []cell{ + {ecnNotECT, ecnNotECT, ecnNotECT, true}, + {ecnNotECT, ecnECT0, ecnECT0, true}, + {ecnNotECT, ecnECT1, ecnECT1, true}, + {ecnNotECT, ecnCE, ecnCE, true}, + + {ecnECT0, ecnNotECT, ecnNotECT, true}, + {ecnECT0, ecnECT0, ecnECT0, true}, + {ecnECT0, ecnECT1, ecnECT1, true}, + {ecnECT0, ecnCE, ecnCE, true}, + + {ecnECT1, ecnNotECT, ecnNotECT, true}, + {ecnECT1, ecnECT0, ecnECT0, true}, + {ecnECT1, ecnECT1, ecnECT1, true}, + {ecnECT1, ecnCE, ecnCE, true}, + + {ecnCE, ecnNotECT, ecnNotECT, true}, // legacy: log, leave alone + {ecnCE, ecnECT0, ecnCE, false}, // CE folded in + {ecnCE, ecnECT1, ecnCE, false}, + {ecnCE, ecnCE, ecnCE, true}, + } + + for _, c := range table { + t.Run("v4", func(t *testing.T) { + pkt := v4(c.inner) + applyOuterECN(pkt, c.outer, hi, silent) + got := pkt[1] & 0x03 + if got != c.wantECN { + t.Errorf("v4 outer=0x%02x inner=0x%02x: got 0x%02x want 0x%02x", c.outer, c.inner, got, c.wantECN) + } + }) + t.Run("v6", func(t *testing.T) { + pkt := v6(c.inner) + applyOuterECN(pkt, c.outer, hi, silent) + got := (pkt[1] >> 4) & 0x03 + if got != c.wantECN { + t.Errorf("v6 outer=0x%02x inner=0x%02x: got 0x%02x want 0x%02x", c.outer, c.inner, got, c.wantECN) + } + }) + } +} diff --git a/firewall/cache.go b/firewall/cache.go index ba4b9732..3e34e6ea 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -5,6 +5,8 @@ import ( "log/slog" "sync/atomic" "time" + + "github.com/slackhq/nebula/logging" ) // ConntrackCache is used as a local routine cache to know if a given flow @@ -56,8 +58,8 @@ func (c *ConntrackCacheTicker) Get() ConntrackCache { if tick := c.cacheTick.Load(); tick != c.cacheV { c.cacheV = tick if ll := len(c.cache); ll > 0 { - if c.l.Enabled(context.Background(), slog.LevelDebug) { - c.l.Debug("resetting conntrack cache", "len", ll) + if c.l.Enabled(context.Background(), logging.LevelTrace) { + c.l.Log(context.Background(), logging.LevelTrace, "resetting conntrack cache", "len", ll) } c.cache = make(ConntrackCache, ll) } diff --git a/firewall/cache_test.go b/firewall/cache_test.go index ab807984..3baf2326 100644 --- a/firewall/cache_test.go +++ b/firewall/cache_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/slackhq/nebula/logging" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" ) @@ -30,27 +31,27 @@ func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheT func TestConntrackCacheTicker_Get_TextFormat(t *testing.T) { buf := &bytes.Buffer{} - l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + l := test.NewLoggerWithOutputAndLevel(buf, logging.LevelTrace) c := newFixedTicker(t, l, 3) c.Get() - assert.Equal(t, "level=DEBUG msg=\"resetting conntrack cache\" len=3\n", buf.String()) + assert.Equal(t, "level=DEBUG-4 msg=\"resetting conntrack cache\" len=3\n", buf.String()) } func 
TestConntrackCacheTicker_Get_JSONFormat(t *testing.T) { buf := &bytes.Buffer{} - l := test.NewJSONLoggerWithOutput(buf, slog.LevelDebug) + l := test.NewJSONLoggerWithOutput(buf, logging.LevelTrace) c := newFixedTicker(t, l, 2) c.Get() - assert.JSONEq(t, `{"level":"DEBUG","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) + assert.JSONEq(t, `{"level":"DEBUG-4","msg":"resetting conntrack cache","len":2}`, strings.TrimSpace(buf.String())) } -func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { +func TestConntrackCacheTicker_Get_QuietBelowTrace(t *testing.T) { buf := &bytes.Buffer{} - l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelInfo) + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) c := newFixedTicker(t, l, 5) c.Get() @@ -60,7 +61,7 @@ func TestConntrackCacheTicker_Get_QuietBelowDebug(t *testing.T) { func TestConntrackCacheTicker_Get_QuietWhenCacheEmpty(t *testing.T) { buf := &bytes.Buffer{} - l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + l := test.NewLoggerWithOutputAndLevel(buf, logging.LevelTrace) c := newFixedTicker(t, l, 0) c.Get() diff --git a/go.mod b/go.mod index 84728201..5d10edde 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/btree v1.1.2 // indirect + github.com/guptarohit/asciigraph v0.9.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect diff --git a/go.sum b/go.sum index 3b0b87df..ee81e74e 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/guptarohit/asciigraph v0.9.0 h1:MvCSRRVkT2XvU1IO6n92o7l7zqx1DiFaoszOUZQztbY= +github.com/guptarohit/asciigraph v0.9.0/go.mod h1:dYl5wwK4gNsnFf9Zp+l06rFiDZ5YtXM6x7SRWZ3KGag= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= diff --git a/inside.go b/inside.go index e528e892..1a865205 100644 --- a/inside.go +++ b/inside.go @@ -10,10 +10,23 @@ import ( "github.com/slackhq/nebula/iputil" "github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/overlay/batch" + "github.com/slackhq/nebula/overlay/tio" "github.com/slackhq/nebula/routing" ) -func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int, localCache firewall.ConntrackCache) { + // borrowed: pkt.Bytes is owned by the originating tio.Queue and is + // only valid until the next Read on that queue. Every consumer below + // (parse, self-forward, handshake cache, sendInsideMessage) reads it + // synchronously; do not retain pkt outside this call. 
If a future + // caller needs to keep the packet, use pkt.Clone() to detach it from + // the borrow. + // + // pkt.Bytes is either one IP datagram (GSO zero) or a TSO/USO + // superpacket. In both cases the L3+L4 headers at the start describe + // the same 5-tuple every segment will share, so a single newPacket / + // firewall check covers the whole superpacket. + packet := pkt.Bytes err := newPacket(packet, false, fwPacket) if err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { @@ -38,7 +51,14 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { - _, err := f.readers[q].Write(packet) + // Write copies into the kernel queue synchronously, so seg's lifetime ends at return. + // A self-forwarded superpacket would be re-handed to the + // kernel as one giant blob; segment first so the loopback + // path sees one IP datagram per Write. + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + _, werr := f.readers[q].Write(seg) + return werr + }) if err != nil { f.l.Error("Failed to forward to tun", "error", err) } @@ -54,7 +74,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) { - hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) + // borrowed: SegmentSuperpacket builds each segment in the kernel-supplied pkt + // bytes underneath. cachePacket explicitly copies its argument (handshake_manager.go cachePacket), + // so retaining segments past the loop is safe. + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + hh.cachePacket(f.l, header.Message, 0, seg, f.sendMessageNow, f.cachedPacketMetrics) + return nil + }) + if err != nil && f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Failed to segment superpacket for handshake cache", + "error", err, + "vpnAddr", fwPacket.RemoteAddr, + ) + } }) if hostinfo == nil { @@ -74,7 +106,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { - f.sendInsideMessage(hostinfo, packet, nb, sendBatch, rejectBuf, q) + f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q) } else { f.rejectInside(packet, rejectBuf, q) if f.l.Enabled(context.Background(), slog.LevelDebug) { @@ -86,11 +118,23 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } -// sendInsideMessage encrypts a firewall-approved inside packet into the -// caller's batch slot for later sendmmsg flush. When hostinfo.remote is not -// valid we fall through to the relay slow path via the unbatched sendNoMetrics -// so relay behavior is unchanged. -func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) { +// sendInsideMessage encrypts a firewall-approved inside packet (or every +// segment of a TSO/USO superpacket) into the caller's batch slot for +// later sendmmsg flush. Segmentation is fused with encryption here so the +// kernel-supplied superpacket bytes never get written into a separate +// scratch arena: SegmentSuperpacket builds each segment's plaintext in +// segScratch[:segLen] in turn, and we encrypt directly into a fresh +// SendBatch slot. 
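As a quick illustration of the segment arithmetic the comments above lean on: a superpacket's per-segment payload lengths fall out of its gso_size, with only the final segment allowed to be short. This is a sketch with hypothetical numbers; the real splitter (tio.SegmentSuperpacket, defined elsewhere in this patch) additionally rewrites each segment's IP/TCP headers, sequence numbers, and checksums.

```go
// Sketch only: how per-segment payload sizes fall out of a superpacket's
// gso_size. segmentLengths(4096, 1448) => [1448 1448 1200], for example.
func segmentLengths(payloadLen, gsoSize int) []int {
	var segs []int
	for off := 0; off < payloadLen; off += gsoSize {
		n := gsoSize
		if payloadLen-off < n {
			n = payloadLen - off // final, possibly short, segment
		}
		segs = append(segs, n)
	}
	return segs
}
```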
+// +// When hostinfo.remote is not valid we fall through to the relay slow +// path via the unbatched sendNoMetrics so relay behavior is unchanged; +// each segment of a superpacket goes through that path independently. +// sendInsideMessage takes a borrowed pkt: pkt.Bytes is only valid until the +// next Read on the originating tio.Queue. Each segment is encrypted into a +// fresh sendBatch slot (Reserve returns owned scratch), so the borrow ends +// inside the SegmentSuperpacket callback below. Do not retain pkt or any +// seg slice past the callback's return. +func (f *Interface) sendInsideMessage(hostinfo *HostInfo, pkt tio.Packet, nb []byte, sendBatch batch.TxBatcher, rejectBuf []byte, q int) { ci := hostinfo.ConnectionState if ci.eKey == nil { return @@ -99,26 +143,20 @@ func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatc if !hostinfo.remote.IsValid() { // Slow path: relay fallback. Reuse rejectBuf as the ciphertext // scratch; sendNoMetrics arranges header space for SendVia. - f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) + // Segment any superpacket so each segment is sized to fit a + // single relay encap. + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, seg, nb, rejectBuf, q) + return nil + }) + if err != nil { + hostinfo.logger(f.l).Error("Failed to segment superpacket for relay send", + "error", err, + ) + } return } - scratch := sendBatch.Next() - if scratch == nil { - // Batch full: bypass batching and send this packet directly so we - // never drop traffic on over-subscribed iterations. - f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, p, nb, rejectBuf, q) - return - } - - if noiseutil.EncryptLockNeeded { - ci.writeLock.Lock() - } - c := ci.messageCounter.Add(1) - - out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) - f.connectionManager.Out(hostinfo) - if hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. 
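The hunk below swaps the old Next/Commit/Get batching for the Reserve/Commit/Flush contract that overlay/batch/batch.go defines later in this patch. As a standalone sketch of that contract (simplified: no header encode or AEAD seal, hypothetical destination), a caller looks like this:

```go
import (
	"net/netip"

	"github.com/slackhq/nebula/overlay/batch"
)

// sendOne is a minimal sketch of the TxBatcher contract: Reserve hands back
// owned scratch sized up front, Commit lends the finished packet to the
// batcher until Flush, and Flush emits everything in arrival order.
func sendOne(sb batch.TxBatcher, dst netip.AddrPort, payload []byte) error {
	buf := sb.Reserve(len(payload)) // real callers size this header.Len+len(seg)+16
	copy(buf, payload)              // real callers header.Encode then EncryptDanger here
	sb.Commit(buf, dst, 0)          // 0 = Not-ECT on the outer header
	return sb.Flush()               // borrowed slices may be reused afterwards
}
```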
@@ -131,20 +169,63 @@ func (f *Interface) sendInsideMessage(hostinfo *HostInfo, p, nb []byte, sendBatc } } - out, err := ci.eKey.EncryptDanger(out, out, p, c, nb) - if noiseutil.EncryptLockNeeded { - ci.writeLock.Unlock() - } - if err != nil { - hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", - "error", err, - "udpAddr", hostinfo.remote, - "counter", c, - ) - return - } + ecnEnabled := f.ecnEnabled.Load() - sendBatch.Commit(len(out), hostinfo.remote) + err := tio.SegmentSuperpacket(pkt, func(seg []byte) error { + // header + plaintext + AEAD tag (16 bytes for both AES-GCM and ChaCha20-Poly1305) + scratch := sendBatch.Reserve(header.Len + len(seg) + 16) + + if noiseutil.EncryptLockNeeded { + ci.writeLock.Lock() + } + c := ci.messageCounter.Add(1) + + out := header.Encode(scratch, header.Version, header.Message, 0, hostinfo.remoteIndexId, c) + f.connectionManager.Out(hostinfo) + + out, encErr := ci.eKey.EncryptDanger(out, out, seg, c, nb) + if noiseutil.EncryptLockNeeded { + ci.writeLock.Unlock() + } + if encErr != nil { + hostinfo.logger(f.l).Error("Failed to encrypt outgoing packet", + "error", encErr, + "udpAddr", hostinfo.remote, + "counter", c, + ) + // Skip this segment; the rest of the superpacket can still + // go out — TCP will retransmit anything we drop here. + return nil + } + + var ecn byte + if ecnEnabled { + ecn = innerECN(seg) + } + sendBatch.Commit(out, hostinfo.remote, ecn) + return nil + }) + if err != nil { + hostinfo.logger(f.l).Error("Failed to segment superpacket for send", + "error", err, + ) + } +} + +// innerECN returns the 2-bit IP-level ECN codepoint of an inner IPv4 or IPv6 +// packet, or 0 if pkt is too short or its IP version is unrecognized. Used at +// encap to copy the inner codepoint onto the outer carrier per RFC 6040. +func innerECN(pkt []byte) byte { + if len(pkt) < 2 { + return 0 + } + switch pkt[0] >> 4 { + case 4: + return pkt[1] & 0x03 + case 6: + return (pkt[1] >> 4) & 0x03 + } + return 0 } func (f *Interface) rejectInside(packet []byte, out []byte, q int) { diff --git a/interface.go b/interface.go index bc7e24d1..e29861b3 100644 --- a/interface.go +++ b/interface.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/netip" + "runtime" "sync" "sync/atomic" "time" @@ -48,7 +49,14 @@ type InterfaceConfig struct { reQueryWait time.Duration ConntrackCacheTimeout time.Duration - l *slog.Logger + + // CpuAffinity, when non-empty, names the CPUs each TUN reader goroutine + // should pin to. Queue i pins to CpuAffinity[i % len(CpuAffinity)] — + // shorter lists than `routines` cycle. Empty list keeps the default + // pin-to-(i % NumCPU) behavior. + CpuAffinity []int + + l *slog.Logger } type Interface struct { @@ -72,7 +80,16 @@ type Interface struct { routines int disconnectInvalid atomic.Bool closed atomic.Bool - relayManager *relayManager + // cpuAffinity, when non-empty, names the CPUs each TUN reader goroutine + // should pin to. Queue i pins to cpuAffinity[i % len(cpuAffinity)]. + // Empty falls back to the default pin-to-(i % NumCPU) behavior. + cpuAffinity []int + // ecnEnabled gates RFC 6040 underlay ECN propagation. When true, + // inside.go copies the inner ECN onto the outer carrier on encap and + // decryptToTun folds outer CE into the inner header on decap. Toggle + // via tunnels.ecn (default true). 
+ ecnEnabled atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -202,6 +219,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { relayManager: c.relayManager, connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, + cpuAffinity: c.CpuAffinity, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), messageMetrics: c.MessageMetrics, @@ -260,7 +278,16 @@ func (f *Interface) activate() error { } f.readers = f.inside.Readers() for i := range f.readers { - f.batchers[i] = batch.NewPassthrough(f.readers[i]) + caps := tio.QueueCapabilities(f.readers[i]) + if caps.TSO || caps.USO { + // Multi-lane: TCP gets coalesced when TSO is on, UDP when USO + // is on, everything else (and either lane disabled) falls + // through to passthrough so non-IP / non-TCP-UDP traffic still + // reaches the TUN. + f.batchers[i] = batch.NewMultiCoalescer(f.readers[i], caps.TSO, caps.USO) + } else { + f.batchers[i] = batch.NewPassthrough(f.readers[i]) + } } f.wg.Add(1) // for us to wait on Close() to return @@ -322,15 +349,13 @@ func (f *Interface) listenOut(i int) { fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) - coalescer := f.batchers[i] - - listener := func(fromUdpAddr netip.AddrPort, payload []byte) { + listener := func(fromUdpAddr netip.AddrPort, payload []byte, meta udp.RxMeta) { plaintext := f.batchers[i].Reserve(len(payload)) - f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get()) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(), meta) } flusher := func() { - if err := coalescer.Flush(); err != nil { + if err := f.batchers[i].Flush(); err != nil { f.l.Error("Failed to flush tun coalescer", "error", err) } } @@ -346,8 +371,27 @@ func (f *Interface) listenOut(i int) { } func (f *Interface) listenIn(reader tio.Queue, i int) { + // Pin this goroutine to one CPU. LockOSThread alone keeps the goroutine + // on a single OS thread but the kernel can still migrate that thread + // across CPUs — XPS reads smp_processor_id() at sendmmsg time and picks + // the TX ring from the current CPU's xps_cpus map, so an unpinned + // thread bouncing between CPUs spreads one nebula flow's packets across + // multiple TX rings, which the rings then drain at independent rates + // and the wire delivers reordered. + // + // Pinning keeps every sendmmsg from this goroutine going through the + // same TX ring, so the wire sees per-flow order. Cost: less scheduler + // flexibility — if i % NumCPU collides between two TUN reader + // goroutines they share a CPU. 
+ cpu := i % runtime.NumCPU() + if n := len(f.cpuAffinity); n > 0 { + cpu = f.cpuAffinity[i%n] + } + if err := pinThreadToCPU(cpu); err != nil { + f.l.Warn("failed to pin tun reader to CPU", "queue", i, "cpu", cpu, "err", err) + } rejectBuf := make([]byte, mtu) - sb := batch.NewSendBatch(batch.SendBatchCap, udp.MTU+32) + sb := batch.NewSendBatch(f.writers[i], batch.SendBatchCap, udp.MTU+32) fwPacket := &firewall.Packet{} nb := make([]byte, 12, 12) @@ -363,35 +407,24 @@ func (f *Interface) listenIn(reader tio.Queue, i int) { break } - sb.Reset() for _, pkt := range pkts { - if sb.Len() >= sb.Cap() { - f.flushBatch(sb, i) - sb.Reset() - } f.consumeInsidePacket(pkt, fwPacket, nb, sb, rejectBuf, i, conntrackCache.Get()) } - if sb.Len() > 0 { - f.flushBatch(sb, i) + if err := sb.Flush(); err != nil { + f.l.Error("Failed to write outgoing batch", "error", err, "writer", i) } } f.l.Debug("overlay reader is done", "reader", i) } -func (f *Interface) flushBatch(sb batch.TxBatcher, q int) { - bufs, dsts := sb.Get() - if err := f.writers[q].WriteBatch(bufs, dsts); err != nil { - f.l.Error("Failed to write outgoing batch", "error", err, "writer", q) - } -} - func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) { c.RegisterReloadCallback(f.reloadFirewall) c.RegisterReloadCallback(f.reloadSendRecvError) c.RegisterReloadCallback(f.reloadAcceptRecvError) c.RegisterReloadCallback(f.reloadDisconnectInvalid) c.RegisterReloadCallback(f.reloadMisc) + c.RegisterReloadCallback(f.reloadEcn) for _, udpConn := range f.writers { c.RegisterReloadCallback(udpConn.ReloadConfig) @@ -515,6 +548,20 @@ func (f *Interface) reloadMisc(c *config.C) { } } +// reloadEcn syncs Interface.ecnEnabled with the tunnels.ecn config knob. +// Default is enabled (RFC 6040 normal mode); set false on the rare path +// where an underlay middlebox rewrites or drops ECN bits unpredictably. +func (f *Interface) reloadEcn(c *config.C) { + initial := c.InitialLoad() + if initial || c.HasChanged("tunnels.ecn") { + v := c.GetBool("tunnels.ecn", true) + f.ecnEnabled.Store(v) + if !initial { + f.l.Info("tunnels.ecn changed", "enabled", v) + } + } +} + func (f *Interface) emitStats(ctx context.Context, i time.Duration) { ticker := time.NewTicker(i) defer ticker.Stop() diff --git a/main.go b/main.go index 37aa24d1..2af6840a 100644 --- a/main.go +++ b/main.go @@ -220,6 +220,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev relayManager: NewRelayManager(ctx, l, hostMap, c), punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, + CpuAffinity: parseCpuAffinity(c, l, routines), l: l, } @@ -237,6 +238,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev ifce.reloadDisconnectInvalid(c) ifce.reloadSendRecvError(c) ifce.reloadAcceptRecvError(c) + ifce.reloadEcn(c) handshakeManager.f = ifce go handshakeManager.Run(ctx) @@ -271,6 +273,53 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev }, nil } +// parseCpuAffinity reads `tun.cpu_affinity` from the config — a list of +// integer CPU IDs, one per TUN reader goroutine. Empty / unset returns nil +// (listenIn falls back to its default `i % NumCPU` pinning). Length +// mismatch with `routines` is a warning, not an error: shorter lists are +// modulo-cycled across queues, longer lists' tail is ignored. Invalid +// entries (non-integer, out of range) are also a warning and disable the +// override entirely so we don't silently pin to the wrong CPU. 
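A worked example of the modulo-cycling rule described above, with hypothetical values (4 reader routines, tun.cpu_affinity of [2, 6]); the helper below is illustrative only, not part of this patch:

```go
import "runtime"

// With affinity = [2, 6] and 4 queues: queue 0 -> CPU 2, queue 1 -> CPU 6,
// queue 2 -> CPU 2, queue 3 -> CPU 6. An empty list keeps the default
// i % runtime.NumCPU() pinning used by listenIn.
func affinityForQueue(affinity []int, i int) int {
	if len(affinity) == 0 {
		return i % runtime.NumCPU()
	}
	return affinity[i%len(affinity)]
}
```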
+func parseCpuAffinity(c *config.C, l *slog.Logger, routines int) []int { + raw := c.Get("tun.cpu_affinity") + if raw == nil { + return nil + } + rv, ok := raw.([]any) + if !ok { + l.Warn("tun.cpu_affinity must be a list of integers; ignoring", "value", raw) + return nil + } + nCPU := runtime.NumCPU() + cpus := make([]int, 0, len(rv)) + for i, e := range rv { + var cpu int + switch v := e.(type) { + case int: + cpu = v + case int64: + cpu = int(v) + case float64: + cpu = int(v) + default: + l.Warn("tun.cpu_affinity entry not an integer; ignoring affinity", + "index", i, "value", e) + return nil + } + if cpu < 0 || cpu >= nCPU { + l.Warn("tun.cpu_affinity entry out of range; ignoring affinity", + "index", i, "cpu", cpu, "num_cpu", nCPU) + return nil + } + cpus = append(cpus, cpu) + } + if len(cpus) != routines { + l.Warn("tun.cpu_affinity length doesn't match routines; queues will modulo-cycle through the list", + "affinity_len", len(cpus), "routines", routines) + } + return cpus +} + func moduleVersion() string { info, ok := debug.ReadBuildInfo() if !ok { diff --git a/outside.go b/outside.go index 9bf64ed6..7d2b6661 100644 --- a/outside.go +++ b/outside.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -22,7 +23,7 @@ const ( var ErrOutOfWindow = errors.New("out of window packet") -func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors @@ -135,7 +136,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, case header.Message: switch h.Subtype { case header.MessageNone: - f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache, meta) default: hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) return @@ -168,7 +169,7 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } -func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { // The entire body is sent as AD, not encrypted. // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. 
If it's @@ -211,7 +212,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, relay: relay, IsRelayed: true, } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache, meta) case ForwardingType: // Find the target HostInfo relay object targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) @@ -512,7 +513,61 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] return out, nil } -func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { +// 2-bit IP-level ECN codepoints (lower bits of IPv4 ToS / IPv6 TC). +const ( + ecnNotECT = 0x00 + ecnECT1 = 0x01 + ecnECT0 = 0x02 + ecnCE = 0x03 +) + +// applyOuterECN folds an outer CE mark from the underlay into the inner +// IP header per RFC 6040 normal mode. It mutates pkt[1] in place. Other +// codepoints are advisory only and leave the inner unchanged. +// +// Merge cases (outer × inner → action): +// +// outer != CE : no-op (inner is authoritative) +// outer == CE, inner Not-ECT : log; cannot propagate to a non-ECN host +// outer == CE, inner ECT/CE : rewrite inner ECN to CE +func applyOuterECN(pkt []byte, outerECN byte, hostinfo *HostInfo, l *slog.Logger) { + if outerECN&ecnCE != ecnCE || len(pkt) < 2 { + return + } + switch pkt[0] >> 4 { + case 4: + switch pkt[1] & 0x03 { + case ecnNotECT: + if l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(l).Debug("RFC 6040: outer CE on inner Not-ECT, leaving inner unchanged") + } + case ecnCE: + // Already CE. + default: + pkt[1] = (pkt[1] &^ 0x03) | ecnCE + } + case 6: + switch (pkt[1] >> 4) & 0x03 { + case ecnNotECT: + if l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(l).Debug("RFC 6040: outer CE on inner Not-ECT, leaving inner unchanged") + } + case ecnCE: + // Already CE. + default: + pkt[1] = (pkt[1] &^ 0x30) | (ecnCE << 4) + } + } +} + +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, meta udp.RxMeta) { + // RFC 6040 normal-mode combine: fold any outer CE mark stamped by the + // underlay into the inner header before firewall + TUN write. Other + // outer codepoints are advisory only — we keep the inner unchanged. + if f.ecnEnabled.Load() { + applyOuterECN(out, meta.OuterECN, hostinfo, f.l) + } + err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", diff --git a/overlay/batch/batch.go b/overlay/batch/batch.go index 925d6044..338c6008 100644 --- a/overlay/batch/batch.go +++ b/overlay/batch/batch.go @@ -14,20 +14,15 @@ type RxBatcher interface { } type TxBatcher interface { - // Next returns a zero-length slice with slotCap capacity over the next unused - // slot's backing bytes. The caller writes into the returned slice and then - // calls Commit with the final length and destination. Next returns nil when - // the batch is full. - Next() []byte - // Commit records the slot just returned by Next as a packet of length n - // destined for dst. - Commit(n int, dst netip.AddrPort) - // Reset clears committed slots; backing storage is retained for reuse. - Reset() - // Len returns the number of committed packets. 
- Len() int - // Cap returns the maximum number of slots in the batch. - Cap() int - // Get returns the buffers needed to send the batch - Get() ([][]byte, []netip.AddrPort) + // Reserve creates a pkt to borrow + Reserve(sz int) []byte + // Commit borrows pkt and records its destination plus the 2-bit + // IP-level ECN codepoint to set on the outer (carrier) header. The + // caller must keep pkt valid until the next Flush. Pass 0 (Not-ECT) + // to leave the outer ECN field unset. + Commit(pkt []byte, dst netip.AddrPort, outerECN byte) + // Flush emits every queued packet via the underlying batch writer in + // arrival order. Returns the first error observed. After Flush returns, + // borrowed payload slices may be recycled. + Flush() error } diff --git a/overlay/batch/coalesce_core.go b/overlay/batch/coalesce_core.go new file mode 100644 index 00000000..377881b2 --- /dev/null +++ b/overlay/batch/coalesce_core.go @@ -0,0 +1,163 @@ +package batch + +import ( + "bytes" + "encoding/binary" +) + +// flowKey identifies a transport flow by {src, dst, sport, dport, family}. +// Comparable, so map lookups and linear scans over the slot list stay tight. +// Shared by the TCP and UDP coalescers; each coalescer keeps its own +// openSlots map, so a TCP and UDP flow on the same 5-tuple-without-proto +// never alias. +type flowKey struct { + src, dst [16]byte + sport, dport uint16 + isV6 bool +} + +// initialSlots is the starting capacity of the slot pool. One flow per +// packet is the worst case so this matches a typical carrier-side +// recvmmsg batch on the encrypted UDP socket. +const initialSlots = 64 + +// parsedIP is the IP-level result of parseIPPrologue. The caller layers +// L4-specific parsing (TCP / UDP) on top. +type parsedIP struct { + fk flowKey + ipHdrLen int + // pkt is the original buffer trimmed to the IP-declared total length. + // Anything below the IP layer (transport parsers) should slice into + // pkt rather than the unbounded original. + pkt []byte +} + +// parseIPPrologue extracts the IP-level fields the coalescers care about: +// IHL/payload length, version, src/dst addresses, and the L4 protocol byte. +// Returns ok=false for malformed input, IPv4 with options or fragmentation, +// or IPv6 with extension headers (all rejected by both coalescers in +// identical ways before this refactor). +// +// On success, p.pkt is len-trimmed to the IP-declared length so callers +// don't have to repeat the trim. wantProto is the IANA protocol number to +// require (6 for TCP, 17 for UDP); ok=false for any other value. +func parseIPPrologue(pkt []byte, wantProto byte) (parsedIP, bool) { + var p parsedIP + if len(pkt) < 20 { + return p, false + } + v := pkt[0] >> 4 + switch v { + case 4: + ihl := int(pkt[0]&0x0f) * 4 + if ihl != 20 { + return p, false + } + if pkt[9] != wantProto { + return p, false + } + // Reject actual fragmentation (MF or non-zero frag offset). 
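The mask in the check on the next line follows the standard IPv4 layout: bytes 6 and 7 hold the three flag bits plus the 13-bit fragment offset, and 0x3fff keeps MF and the offset while deliberately ignoring DF (0x4000), so DF-marked but unfragmented packets still coalesce. A small reference snippet (constant names are local to this illustration):

```go
// IPv4 bytes 6-7: [ flags(3 bits) | fragment offset(13 bits) ].
const (
	ipv4FlagDF     = 0x4000 // Don't Fragment: irrelevant to coalescing
	ipv4FlagMF     = 0x2000 // More Fragments: marks a fragment
	ipv4FragOffset = 0x1fff // non-zero offset: also a fragment
)

// fragmented mirrors the 0x3fff test: true when MF is set or the fragment
// offset is non-zero, regardless of DF. 0x2000 | 0x1fff == 0x3fff.
func fragmented(flagsFrag uint16) bool {
	return flagsFrag&(ipv4FlagMF|ipv4FragOffset) != 0
}
```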
+ if binary.BigEndian.Uint16(pkt[6:8])&0x3fff != 0 { + return p, false + } + totalLen := int(binary.BigEndian.Uint16(pkt[2:4])) + if totalLen > len(pkt) || totalLen < ihl { + return p, false + } + p.ipHdrLen = 20 + p.fk.isV6 = false + copy(p.fk.src[:4], pkt[12:16]) + copy(p.fk.dst[:4], pkt[16:20]) + p.pkt = pkt[:totalLen] + case 6: + if len(pkt) < 40 { + return p, false + } + if pkt[6] != wantProto { + return p, false + } + payloadLen := int(binary.BigEndian.Uint16(pkt[4:6])) + if 40+payloadLen > len(pkt) { + return p, false + } + p.ipHdrLen = 40 + p.fk.isV6 = true + copy(p.fk.src[:], pkt[8:24]) + copy(p.fk.dst[:], pkt[24:40]) + p.pkt = pkt[:40+payloadLen] + default: + return p, false + } + return p, true +} + +// ipHeadersMatch compares the IP portion of two packet header prefixes for +// byte-for-byte equality on every field that must be identical across +// coalesced segments. Size/IPID/IPCsum and the 2-bit IP-level ECN field are +// masked out — the appendPayload step merges CE into the seed. +// +// The transport (L4) portion of the header is checked separately by the +// per-protocol matcher. +func ipHeadersMatch(a, b []byte, isV6 bool) bool { + if isV6 { + // IPv6: byte 0 = version/TC[7:4], byte 1 = TC[3:0]/flow[19:16], + // bytes [2:4] = flow[15:0], [6:8] = next_hdr/hop, [8:40] = src+dst. + // ECN lives in TC[1:0] = byte 1 mask 0x30. Skip [4:6] payload_len. + if a[0] != b[0] { + return false + } + if a[1]&^0x30 != b[1]&^0x30 { + return false + } + if !bytes.Equal(a[2:4], b[2:4]) { + return false + } + if !bytes.Equal(a[6:40], b[6:40]) { + return false + } + return true + } + // IPv4: byte 0 = version/IHL, byte 1 = DSCP(6)|ECN(2), + // [6:10] flags/fragoff/TTL/proto, [12:20] src+dst. + // Skip [2:4] total len, [4:6] id, [10:12] csum. + if a[0] != b[0] { + return false + } + if a[1]&^0x03 != b[1]&^0x03 { + return false + } + if !bytes.Equal(a[6:10], b[6:10]) { + return false + } + if !bytes.Equal(a[12:20], b[12:20]) { + return false + } + return true +} + +// mergeECNIntoSeed ORs the 2-bit IP-level ECN field of pkt's IP header +// onto the seed's IP header, so a CE mark on any coalesced segment +// propagates to the final superpacket. (CE is 0b11; ORing yields CE if +// any segment carried it.) Used by both TCP and UDP coalescers, so the +// invariant lives in one place. +func mergeECNIntoSeed(seedHdr, pktHdr []byte, isV6 bool) { + if isV6 { + seedHdr[1] |= pktHdr[1] & 0x30 + } else { + seedHdr[1] |= pktHdr[1] & 0x03 + } +} + +// reserveFromBacking implements the Reserve half of the RxBatcher contract +// shared by TCP and UDP coalescers. The backing slice grows on demand; +// already-committed slices reference the old array and remain valid until +// Flush resets backing. +func reserveFromBacking(backing *[]byte, sz int) []byte { + if len(*backing)+sz > cap(*backing) { + newCap := max(cap(*backing)*2, sz) + *backing = make([]byte, 0, newCap) + } + start := len(*backing) + *backing = (*backing)[:start+sz] + return (*backing)[start : start+sz : start+sz] +} diff --git a/overlay/batch/multi_coalesce.go b/overlay/batch/multi_coalesce.go new file mode 100644 index 00000000..fbe59ccc --- /dev/null +++ b/overlay/batch/multi_coalesce.go @@ -0,0 +1,133 @@ +package batch + +import ( + "io" +) + +// MultiCoalescer fans plaintext packets out to lane-specific batchers based +// on the IP/L4 protocol of the packet, sharing a single Reserve arena +// across lanes so the caller's allocation pattern is unchanged. 
+// +// Lanes are processed independently: the TCP coalescer only sees TCP, the +// UDP coalescer only sees UDP, and the passthrough lane handles everything +// else. Per-flow arrival order is preserved because a single 5-tuple only +// ever lands in one lane and each lane preserves its own slot order. +// +// Cross-lane order is NOT preserved across the TCP/UDP/passthrough split. +// This is acceptable because the carrier-side recvmmsg path already +// stable-sorts by (peer, message counter) before delivering plaintext +// here, so replay-window invariants are unaffected, and apps observe +// correct per-flow ordering — which is all the IP layer guarantees anyway. +// Do not "fix" this by interleaving lane outputs at flush time; that +// negates the entire point of coalescing (each lane needs to see runs of +// adjacent same-flow packets to coalesce them). +type MultiCoalescer struct { + tcp *TCPCoalescer + udp *UDPCoalescer + pt *Passthrough + + // arena shared across all lanes so a single Reserve grows one backing + // slice; lane Commit calls borrow into this same arena. + backing []byte +} + +// NewMultiCoalescer builds a multi-lane batcher. tcpEnabled lets the caller +// opt out of TCP coalescing (e.g. when the queue can't do TSO); udpEnabled +// likewise gates UDP coalescing (only enable when USO was negotiated). +// Either lane disabled redirects its traffic into the passthrough lane. +func NewMultiCoalescer(w io.Writer, tcpEnabled, udpEnabled bool) *MultiCoalescer { + m := &MultiCoalescer{ + pt: NewPassthrough(w), + backing: make([]byte, 0, initialSlots*65535), + } + if tcpEnabled { + m.tcp = NewTCPCoalescer(w) + } + if udpEnabled { + m.udp = NewUDPCoalescer(w) + } + return m +} + +func (m *MultiCoalescer) Reserve(sz int) []byte { + if len(m.backing)+sz > cap(m.backing) { + newCap := max(cap(m.backing)*2, sz) + m.backing = make([]byte, 0, newCap) + } + start := len(m.backing) + m.backing = m.backing[:start+sz] + return m.backing[start : start+sz : start+sz] +} + +// Commit dispatches pkt to the appropriate lane based on IP version + L4 +// proto. Borrowed slice contract is identical to the single-lane batchers +// — pkt must remain valid until the next Flush. +// +// On the success path the IP/TCP-or-UDP parse happens here once and the +// parsed struct is handed to the lane via commitParsed so the lane doesn't +// re-walk the header. On a parse failure we fall through to the lane's +// public Commit, which re-runs the parse before passthrough — that path +// only fires for malformed/unsupported packets so the duplicated parse is +// not on the hot path. The lane's public Commit still works for direct +// callers. +func (m *MultiCoalescer) Commit(pkt []byte) error { + if len(pkt) < 20 { + return m.pt.Commit(pkt) + } + v := pkt[0] >> 4 + var proto byte + switch v { + case 4: + proto = pkt[9] + case 6: + if len(pkt) < 40 { + return m.pt.Commit(pkt) + } + proto = pkt[6] + default: + return m.pt.Commit(pkt) + } + switch proto { + case ipProtoTCP: + if m.tcp != nil { + info, ok := parseTCPBase(pkt) + if !ok { + // Malformed/unsupported TCP shape (IP options, fragments, ...) + // — the TCP lane handles this as passthrough. + return m.tcp.Commit(pkt) + } + return m.tcp.commitParsed(pkt, info) + } + case ipProtoUDP: + if m.udp != nil { + info, ok := parseUDP(pkt) + if !ok { + return m.udp.Commit(pkt) + } + return m.udp.commitParsed(pkt, info) + } + } + return m.pt.Commit(pkt) +} + +// Flush drains every lane in a fixed order: TCP, UDP, passthrough. 
Errors +// from a lane do not stop subsequent lanes from flushing — we keep +// draining and return the first observed error so a single bad packet +// doesn't strand the others. +func (m *MultiCoalescer) Flush() error { + var first error + keep := func(err error) { + if err != nil && first == nil { + first = err + } + } + if m.tcp != nil { + keep(m.tcp.Flush()) + } + if m.udp != nil { + keep(m.udp.Flush()) + } + keep(m.pt.Flush()) + m.backing = m.backing[:0] + return first +} diff --git a/overlay/batch/multi_coalesce_test.go b/overlay/batch/multi_coalesce_test.go new file mode 100644 index 00000000..9d718ecf --- /dev/null +++ b/overlay/batch/multi_coalesce_test.go @@ -0,0 +1,94 @@ +package batch + +import ( + "testing" +) + +// TestMultiCoalescerRoutesByProto confirms TCP/UDP/other land in the right +// lane: TCP and UDP get coalesced when their lanes are enabled, anything +// else (ICMP here) falls through to plain Write. +func TestMultiCoalescerRoutesByProto(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + m := NewMultiCoalescer(w, true, true) + + tcpPay := make([]byte, 1200) + udpPay := make([]byte, 1200) + icmp := make([]byte, 28) + icmp[0] = 0x45 + icmp[2] = 0 + icmp[3] = 28 + icmp[9] = 1 + + if err := m.Commit(buildTCPv4(1000, tcpAck, tcpPay)); err != nil { + t.Fatal(err) + } + if err := m.Commit(buildTCPv4(2200, tcpAck, tcpPay)); err != nil { + t.Fatal(err) + } + if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil { + t.Fatal(err) + } + if err := m.Commit(buildUDPv4(2000, 53, udpPay)); err != nil { + t.Fatal(err) + } + if err := m.Commit(icmp); err != nil { + t.Fatal(err) + } + if err := m.Flush(); err != nil { + t.Fatal(err) + } + // 1 TCP super (2 segments) + 1 UDP super (2 segments) = 2 gso writes. + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (one TCP + one UDP), got %d", len(w.gsoWrites)) + } + if len(w.writes) != 1 { + t.Fatalf("want 1 plain write (ICMP), got %d", len(w.writes)) + } +} + +// TestMultiCoalescerDisabledUDPFallsThrough verifies that when the UDP lane +// is disabled (e.g. kernel doesn't support USO), UDP packets still reach +// the kernel via the passthrough lane rather than being lost. +func TestMultiCoalescerDisabledUDPFallsThrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + m := NewMultiCoalescer(w, true, false) // TSO on, USO off + + if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil { + t.Fatal(err) + } + if err := m.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil { + t.Fatal(err) + } + if err := m.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 0 { + t.Errorf("UDP must NOT be coalesced when USO disabled, got %d gso writes", len(w.gsoWrites)) + } + if len(w.writes) != 2 { + t.Errorf("UDP must pass through as 2 plain writes, got %d", len(w.writes)) + } +} + +// TestMultiCoalescerDisabledTCPFallsThrough mirrors the TSO=off case. 
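Stepping back to the ordering argument in the MultiCoalescer doc comment above: the claim that plaintext arrives here already stable-sorted by (peer, message counter) is satisfied by the carrier-side rx path (udp/rx_reorder_linux.go, not shown in this hunk). The sketch below is a conceptual stand-in for that kind of per-peer reorder, with hypothetical types; it is not the patch's actual implementation:

```go
import (
	"net/netip"
	"sort"
)

// rxPkt is a hypothetical stand-in for one decrypted packet in a recvmmsg batch.
type rxPkt struct {
	peer    netip.AddrPort
	counter uint64 // Nebula header message counter
	payload []byte
}

// reorderBatch puts each peer's packets back into counter order while leaving
// every other peer's slot positions untouched, so cross-peer arrival order
// within the batch is preserved.
func reorderBatch(pkts []rxPkt) {
	byPeer := make(map[netip.AddrPort][]int)
	for i := range pkts {
		byPeer[pkts[i].peer] = append(byPeer[pkts[i].peer], i)
	}
	for _, idxs := range byPeer {
		run := make([]rxPkt, len(idxs))
		for k, i := range idxs {
			run[k] = pkts[i]
		}
		sort.Slice(run, func(a, b int) bool { return run[a].counter < run[b].counter })
		for k, i := range idxs {
			pkts[i] = run[k]
		}
	}
}
```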
+func TestMultiCoalescerDisabledTCPFallsThrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + m := NewMultiCoalescer(w, false, true) // TSO off, USO on + + pay := make([]byte, 1200) + if err := m.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := m.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := m.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 0 { + t.Errorf("TCP must NOT be coalesced when TSO disabled, got %d gso writes", len(w.gsoWrites)) + } + if len(w.writes) != 2 { + t.Errorf("TCP must pass through as 2 plain writes, got %d", len(w.writes)) + } +} diff --git a/overlay/batch/tcp_coalesce.go b/overlay/batch/tcp_coalesce.go new file mode 100644 index 00000000..c957a138 --- /dev/null +++ b/overlay/batch/tcp_coalesce.go @@ -0,0 +1,722 @@ +package batch + +import ( + "bytes" + "encoding/binary" + "io" + "log/slog" + "net/netip" + "sort" + + "github.com/slackhq/nebula/overlay/tio" +) + +// ipProtoTCP is the IANA protocol number for TCP. Hardcoded instead of +// reaching for golang.org/x/sys/unix — that package doesn't define the +// constant on Windows, which would break cross-compiles even though this +// file runs unchanged on every platform. +const ipProtoTCP = 6 + +// tcpCoalesceBufSize caps total bytes per superpacket. Mirrors the kernel's +// sk_gso_max_size of ~64KiB; anything beyond this would be rejected anyway. +const tcpCoalesceBufSize = 65535 + +// tcpCoalesceMaxSegs caps how many segments we'll coalesce into a single +// superpacket. Keeping this well below the kernel's TSO ceiling bounds +// latency. +const tcpCoalesceMaxSegs = 64 + +// tcpCoalesceHdrCap is the scratch space we copy a seed's IP+TCP header +// into. IPv6 (40) + TCP with full options (60) = 100 bytes. +const tcpCoalesceHdrCap = 100 + +// coalesceSlot is one entry in the coalescer's ordered event queue. When +// passthrough is true the slot holds a single borrowed packet that must be +// emitted verbatim (non-TCP, non-admissible TCP, or oversize seed). When +// passthrough is false the slot is an in-progress coalesced superpacket: +// hdrBuf is a mutable copy of the seed's IP+TCP header (we patch total +// length and pseudo-header partial at flush), and payIovs are *borrowed* +// slices from the caller's plaintext buffers — no payload is ever copied. +// The caller (listenOut) must keep those buffers alive until Flush. +type coalesceSlot struct { + passthrough bool + rawPkt []byte // borrowed when passthrough + + fk flowKey + hdrBuf [tcpCoalesceHdrCap]byte + hdrLen int + ipHdrLen int + isV6 bool + gsoSize int + numSeg int + totalPay int + nextSeq uint32 + // psh closes the chain: set when the last-accepted segment had PSH or + // was sub-gsoSize. No further appends after that. + psh bool + payIovs [][]byte +} + +// TCPCoalescer accumulates adjacent in-flow TCP data segments across +// multiple concurrent flows and emits each flow's run as a single TSO +// superpacket via tio.GSOWriter. All output — coalesced or not — is +// deferred until Flush so arrival order is preserved on the wire. Owns +// no locks; one coalescer per TUN write queue. +type TCPCoalescer struct { + plainW io.Writer + gsoW tio.GSOWriter // nil when the queue doesn't support TSO + + // slots is the ordered event queue. Flush walks it once and emits each + // entry as either a WriteGSO (coalesced) or a plainW.Write (passthrough). 
+ slots []*coalesceSlot + // openSlots maps a flow key to its most recent non-sealed slot, so new + // segments can extend an in-progress superpacket in O(1). Slots are + // removed from this map when they close (PSH or short-last-segment), + // when a non-admissible packet for that flow arrives, or in Flush. + openSlots map[flowKey]*coalesceSlot + // lastSlot caches the most recently touched open slot. Steady-state + // bulk traffic is dominated by a single flow, so comparing the + // incoming key against the cached slot's own fk lets the hot path + // skip the map lookup (and the aeshash of a 38-byte key) entirely. + // Kept in lockstep with openSlots: nil whenever the slot it pointed + // at is removed/sealed. + lastSlot *coalesceSlot + pool []*coalesceSlot // free list for reuse + + backing []byte +} + +func NewTCPCoalescer(w io.Writer) *TCPCoalescer { + c := &TCPCoalescer{ + plainW: w, + slots: make([]*coalesceSlot, 0, initialSlots), + openSlots: make(map[flowKey]*coalesceSlot, initialSlots), + pool: make([]*coalesceSlot, 0, initialSlots), + backing: make([]byte, 0, initialSlots*65535), + } + if gw, ok := tio.SupportsGSO(w, tio.GSOProtoTCP); ok { + c.gsoW = gw + } + return c +} + +// parsedTCP holds the fields extracted from a single parse so later steps +// (admission, slot lookup, canAppend) don't re-walk the header. +type parsedTCP struct { + fk flowKey + ipHdrLen int + tcpHdrLen int + hdrLen int + payLen int + seq uint32 + flags byte +} + +// parseTCPBase extracts the flow key and IP/TCP offsets for any TCP packet, +// regardless of whether it's admissible for coalescing. Returns ok=false +// for non-TCP or malformed input. Accepts IPv4 (no options, no fragmentation) +// and IPv6 (no extension headers). +func parseTCPBase(pkt []byte) (parsedTCP, bool) { + var p parsedTCP + ip, ok := parseIPPrologue(pkt, ipProtoTCP) + if !ok { + return p, false + } + pkt = ip.pkt + p.fk = ip.fk + p.ipHdrLen = ip.ipHdrLen + + if len(pkt) < p.ipHdrLen+20 { + return p, false + } + tcpOff := int(pkt[p.ipHdrLen+12]>>4) * 4 + if tcpOff < 20 || tcpOff > 60 { + return p, false + } + if len(pkt) < p.ipHdrLen+tcpOff { + return p, false + } + p.tcpHdrLen = tcpOff + p.hdrLen = p.ipHdrLen + tcpOff + p.payLen = len(pkt) - p.hdrLen + p.seq = binary.BigEndian.Uint32(pkt[p.ipHdrLen+4 : p.ipHdrLen+8]) + p.flags = pkt[p.ipHdrLen+13] + p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2]) + p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4]) + return p, true +} + +// TCP flag bits (byte 13 of the TCP header). Only the bits actually consulted +// by the coalescer are named; FIN/SYN/RST/URG/CWR are rejected via the +// negative mask in coalesceable, not by name. +const ( + tcpFlagPsh = 0x08 + tcpFlagAck = 0x10 + tcpFlagEce = 0x40 +) + +// coalesceable reports whether a parsed TCP segment is eligible for +// coalescing. Accepts ACK, ACK|PSH, ACK|ECE, ACK|PSH|ECE with a +// non-empty payload. CWR is excluded because it marks a one-shot +// congestion-window-reduced transition the receiver must observe at a +// segment boundary. +func (p parsedTCP) coalesceable() bool { + if p.flags&tcpFlagAck == 0 { + return false + } + if p.flags&^(tcpFlagAck|tcpFlagPsh|tcpFlagEce) != 0 { + return false + } + return p.payLen > 0 +} + +func (c *TCPCoalescer) Reserve(sz int) []byte { + return reserveFromBacking(&c.backing, sz) +} + +// Commit borrows pkt. 
The caller must keep pkt valid until the next Flush, +// whether or not the packet was coalesced — passthrough (non-admissible) +// packets are queued and written at Flush time, not synchronously. +func (c *TCPCoalescer) Commit(pkt []byte) error { + if c.gsoW == nil { + c.addPassthrough(pkt) + return nil + } + info, ok := parseTCPBase(pkt) + if !ok { + c.addPassthrough(pkt) + return nil + } + return c.commitParsed(pkt, info) +} + +// commitParsed is the post-parse half of Commit. The caller must have +// already verified parseTCPBase succeeded (info is a valid TCP parse). +// Used by MultiCoalescer.Commit to avoid re-walking the IP/TCP header +// after the dispatcher has already done so. +func (c *TCPCoalescer) commitParsed(pkt []byte, info parsedTCP) error { + if c.gsoW == nil { + c.addPassthrough(pkt) + return nil + } + if !info.coalesceable() { + // TCP but not admissible (SYN/FIN/RST/URG/CWR or zero-payload). + // Seal this flow's open slot so later in-flow packets don't extend + // it and accidentally reorder past this passthrough. + if last := c.lastSlot; last != nil && last.fk == info.fk { + c.lastSlot = nil + } + delete(c.openSlots, info.fk) + c.addPassthrough(pkt) + return nil + } + + // Single-flow fast path: with only one open flow the cache hits every + // packet, and len(openSlots)==1 lets us skip the 38-byte fk compare + // when there are multiple flows in flight (where the hit rate would + // be ~0 and the compare is pure overhead). + var open *coalesceSlot + if last := c.lastSlot; last != nil && len(c.openSlots) == 1 && last.fk == info.fk { + open = last + } else { + open = c.openSlots[info.fk] + } + if open != nil { + if c.canAppend(open, pkt, info) { + c.appendPayload(open, pkt, info) + if open.psh { + delete(c.openSlots, info.fk) + c.lastSlot = nil + } else { + c.lastSlot = open + } + return nil + } + // Can't extend — seal it and fall through to seed a fresh slot. + delete(c.openSlots, info.fk) + if c.lastSlot == open { + c.lastSlot = nil + } + } + c.seed(pkt, info) + return nil +} + +// Flush emits every queued event in (per-flow) seq order. Coalesced slots +// go out via WriteGSO; passthrough slots go out via plainW.Write. +// reorderForFlush first sorts each flow's slots into TCP-seq order within +// passthrough-bounded segments and merges contiguous adjacent slots, so +// any wire-side reorder that crossed an rxOrder batch boundary doesn't +// get amplified into kernel-visible reorder by the slot machinery. +// Returns the first error observed; keeps draining so one bad packet +// doesn't hold up the rest. After Flush returns, borrowed payload slices +// may be recycled. +func (c *TCPCoalescer) Flush() error { + c.reorderForFlush() + var first error + for _, s := range c.slots { + var err error + if s.passthrough { + _, err = c.plainW.Write(s.rawPkt) + } else { + err = c.flushSlot(s) + } + if err != nil && first == nil { + first = err + } + c.release(s) + } + for i := range c.slots { + c.slots[i] = nil + } + c.slots = c.slots[:0] + for k := range c.openSlots { + delete(c.openSlots, k) + } + c.lastSlot = nil + + c.backing = c.backing[:0] + return first +} + +func (c *TCPCoalescer) addPassthrough(pkt []byte) { + s := c.take() + s.passthrough = true + s.rawPkt = pkt + c.slots = append(c.slots, s) +} + +func (c *TCPCoalescer) seed(pkt []byte, info parsedTCP) { + if info.hdrLen > tcpCoalesceHdrCap || info.hdrLen+info.payLen > tcpCoalesceBufSize { + // Pathological shape — can't fit our scratch, emit as-is. 
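To make the slot lifecycle concrete, here is how a single-flow burst would move through the coalescer, written in the style of the tests earlier in this patch (buildTCPv4 and fakeTunWriter are this patch's test helpers; their shapes are inferred from multi_coalesce_test.go above, so treat this as an illustration rather than a test that ships with the change):

```go
// Hypothetical walkthrough: three adjacent segments from one flow become one
// TSO superpacket. Sequence numbers chain (1000 -> 2448 -> 3896) and the
// short PSH tail seals the slot, so Flush issues exactly one WriteGSO.
func exampleBurst() error {
	w := &fakeTunWriter{gsoEnabled: true}
	c := NewTCPCoalescer(w)

	_ = c.Commit(buildTCPv4(1000, tcpFlagAck, make([]byte, 1448)))           // seed: gsoSize=1448
	_ = c.Commit(buildTCPv4(2448, tcpFlagAck, make([]byte, 1448)))           // seq-adjacent: appended
	_ = c.Commit(buildTCPv4(3896, tcpFlagAck|tcpFlagPsh, make([]byte, 700))) // short + PSH: seals
	return c.Flush()                                                         // one WriteGSO, no plain writes
}
```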
+ c.addPassthrough(pkt) + return + } + s := c.take() + s.passthrough = false + s.rawPkt = nil + copy(s.hdrBuf[:], pkt[:info.hdrLen]) + s.hdrLen = info.hdrLen + s.ipHdrLen = info.ipHdrLen + s.isV6 = info.fk.isV6 + s.fk = info.fk + s.gsoSize = info.payLen + s.numSeg = 1 + s.totalPay = info.payLen + s.nextSeq = info.seq + uint32(info.payLen) + s.psh = info.flags&tcpFlagPsh != 0 + s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen]) + c.slots = append(c.slots, s) + if !s.psh { + c.openSlots[info.fk] = s + c.lastSlot = s + } else if last := c.lastSlot; last != nil && last.fk == info.fk { + // PSH-on-seed seals the slot immediately. Any prior cached open + // slot for this flow has just been sealed-and-replaced by this + // passthrough-shaped seed, so drop the cache too. + c.lastSlot = nil + } +} + +// canAppend reports whether info's packet extends the slot's seed: same +// header shape and stable contents, adjacent seq, not oversized, chain not +// closed. +func (c *TCPCoalescer) canAppend(s *coalesceSlot, pkt []byte, info parsedTCP) bool { + if s.psh { + return false + } + if info.hdrLen != s.hdrLen { + return false + } + if info.seq != s.nextSeq { + return false + } + if s.numSeg >= tcpCoalesceMaxSegs { + return false + } + if info.payLen > s.gsoSize { + return false + } + if s.hdrLen+s.totalPay+info.payLen > tcpCoalesceBufSize { + return false + } + // ECE state must be stable across a burst — receivers expect the + // flag set on every segment of a CE-echoing window or none. + seedFlags := s.hdrBuf[s.ipHdrLen+13] + if (seedFlags^info.flags)&tcpFlagEce != 0 { + return false + } + if !headersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) { + return false + } + return true +} + +func (c *TCPCoalescer) appendPayload(s *coalesceSlot, pkt []byte, info parsedTCP) { + s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen]) + s.numSeg++ + s.totalPay += info.payLen + s.nextSeq = info.seq + uint32(info.payLen) + if info.flags&tcpFlagPsh != 0 { + // Propagate PSH into the seed header so kernel TSO sets it on the + // last segment. Without this the sender's push signal is dropped. + s.hdrBuf[s.ipHdrLen+13] |= tcpFlagPsh + } + // Merge IP-level CE marks into the seed: headersMatch ignores ECN, so + // this is the one place the signal is preserved. + mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6) + if info.payLen < s.gsoSize || info.flags&tcpFlagPsh != 0 { + s.psh = true + } +} + +func (c *TCPCoalescer) take() *coalesceSlot { + if n := len(c.pool); n > 0 { + s := c.pool[n-1] + c.pool[n-1] = nil + c.pool = c.pool[:n-1] + return s + } + return &coalesceSlot{} +} + +func (c *TCPCoalescer) release(s *coalesceSlot) { + s.passthrough = false + s.rawPkt = nil + for i := range s.payIovs { + s.payIovs[i] = nil + } + s.payIovs = s.payIovs[:0] + s.numSeg = 0 + s.totalPay = 0 + s.psh = false + c.pool = append(c.pool, s) +} + +// flushSlot patches the header and calls WriteGSO. Does not remove the +// slot from c.slots. 
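+// For instance (illustrative numbers): an IPv4 slot holding three
+// 1200-byte segments has total = 40 + 3*1200 = 3640 and l4Len = 3620, so
+// the IPv4 total-length field is rewritten to 3640 and the TCP checksum
+// field is seeded with the folded, uninverted pseudo-header sum over
+// those 3620 bytes.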
+func (c *TCPCoalescer) flushSlot(s *coalesceSlot) error { + total := s.hdrLen + s.totalPay + l4Len := total - s.ipHdrLen + hdr := s.hdrBuf[:s.hdrLen] + + if s.isV6 { + binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len)) + } else { + binary.BigEndian.PutUint16(hdr[2:4], uint16(total)) + hdr[10] = 0 + hdr[11] = 0 + binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen])) + } + + var psum uint32 + if s.isV6 { + psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoTCP, l4Len) + } else { + psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoTCP, l4Len) + } + tcsum := s.ipHdrLen + 16 + binary.BigEndian.PutUint16(hdr[tcsum:tcsum+2], foldOnceNoInvert(psum)) + + return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoTCP) +} + +// headersMatch compares two IP+TCP header prefixes for byte-for-byte +// equality on every field that must be identical across coalesced +// segments. Size/IPID/IPCsum/seq/flags/tcpCsum are masked out, as is the +// 2-bit IP-level ECN field — appendPayload merges CE into the seed. +func headersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool { + if len(a) != len(b) { + return false + } + if !ipHeadersMatch(a, b, isV6) { + return false + } + // TCP: compare [0:4] ports, [8:13] ack+dataoff, [14:16] window, + // [18:tcpHdrLen] options (incl. urgent). + tcp := ipHdrLen + if !bytes.Equal(a[tcp:tcp+4], b[tcp:tcp+4]) { + return false + } + if !bytes.Equal(a[tcp+8:tcp+13], b[tcp+8:tcp+13]) { + return false + } + if !bytes.Equal(a[tcp+14:tcp+16], b[tcp+14:tcp+16]) { + return false + } + if !bytes.Equal(a[tcp+18:], b[tcp+18:]) { + return false + } + return true +} + +// reorderForFlush neutralizes wire-side reorder that the rxOrder buffer +// couldn't catch (anything crossing a recvmmsg batch boundary). Without +// this pass a small wire reorder — counter 250 arriving in batch K when +// 200..249 are coming in batch K+1 — would seed an out-of-seq slot first +// and emit it ahead of the lower-seq slot, manifesting at the inner TCP +// receiver as a much larger reorder than the wire actually had. +// +// Two phases: +// 1. Sort each passthrough-bounded segment of c.slots by (flow, seq). +// Cross-flow ordering inside a segment isn't preserved (it never was +// and doesn't matter for any single flow's TCP correctness). +// 2. Sweep once and merge adjacent same-flow slots whose ranges are now +// contiguous AND whose tail is gsoSize-aligned. The tail constraint +// matters because the kernel TSO splitter chops at gsoSize from the +// start of the merged payload — a short segment in the middle would +// desynchronize every later segment. +// +// Passthrough slots act as barriers: the merge check skips them on either +// side, so a SYN/FIN/RST/CWR is never reordered relative to its flow's +// data. +func (c *TCPCoalescer) reorderForFlush() { + if len(c.slots) <= 1 { + return + } + runStart := 0 + for i := 0; i <= len(c.slots); i++ { + if i < len(c.slots) && !c.slots[i].passthrough { + continue + } + c.sortRun(c.slots[runStart:i]) + runStart = i + 1 + } + out := c.slots[:0] + logged := false + for _, s := range c.slots { + if n := len(out); n > 0 { + prev := out[n-1] + if !prev.passthrough && !s.passthrough && prev.fk == s.fk { + // Same-flow neighbors after sort. If they aren't seq- + // contiguous it's a real gap — packets the wire reordered + // across batches, or actual loss before nebula. 
Log it so + // the operator can quantify how often it happens; the data + // itself still emits in seq order, kernel TCP handles the + // gap via its OOO queue. + if prev.nextSeq != slotSeedSeq(s) { + logged = true + gap := int64(slotSeedSeq(s)) - int64(prev.nextSeq) + slog.Default().Warn("tcp coalesce: cross-slot seq gap", + "src", flowKeyAddr(s.fk, false), + "dst", flowKeyAddr(s.fk, true), + "sport", s.fk.sport, + "dport", s.fk.dport, + "prev_seed_seq", slotSeedSeq(prev), + "prev_next_seq", prev.nextSeq, + "this_seed_seq", slotSeedSeq(s), + "gap_bytes", gap, + "prev_seg_count", prev.numSeg, + "prev_total_pay", prev.totalPay, + ) + } + if canMergeSlots(prev, s) { + mergeSlots(prev, s) + c.release(s) + continue + } + } + } + out = append(out, s) + } + if logged { + slog.Default().Warn("==== end of batch ====") + } + c.slots = out +} + +// flowKeyAddr returns the src or dst address from fk as a netip.Addr for +// logging. Only used on the cold gap-log path so the netip allocation +// doesn't matter. +func flowKeyAddr(fk flowKey, dst bool) netip.Addr { + src := fk.src + if dst { + src = fk.dst + } + if fk.isV6 { + return netip.AddrFrom16(src) + } + var v4 [4]byte + copy(v4[:], src[:4]) + return netip.AddrFrom4(v4) +} + +// sortRun stable-sorts run by (flowKey, seedSeq) so each flow's slots +// cluster together in seq order, ready for the merge sweep. Stable so +// equal-key slots keep their original relative position (defensive — a +// duplicate seedSeq would already mean something's wrong upstream). +func (c *TCPCoalescer) sortRun(run []*coalesceSlot) { + if len(run) <= 1 { + return + } + sort.SliceStable(run, func(i, j int) bool { + a, b := run[i], run[j] + if cmp := flowKeyCompare(a.fk, b.fk); cmp != 0 { + return cmp < 0 + } + return tcpSeqLess(slotSeedSeq(a), slotSeedSeq(b)) + }) +} + +// slotSeedSeq returns the TCP seq of the slot's seed (first segment). +// nextSeq tracks the seq just past the last appended byte; subtracting +// totalPay walks back to the seed. uint32 wraparound is the right TCP +// arithmetic so no special-casing is needed. +func slotSeedSeq(s *coalesceSlot) uint32 { + return s.nextSeq - uint32(s.totalPay) +} + +// tcpSeqLess reports whether a precedes b in TCP serial-number arithmetic +// (RFC 1323 §2.3). The signed int32 cast turns the modular subtraction +// into the right comparison even across the 2^32 wrap. +func tcpSeqLess(a, b uint32) bool { + return int32(a-b) < 0 +} + +// flowKeyCompare orders flowKeys deterministically. The exact ordering +// is irrelevant — only that same-flow slots cluster together so the +// post-sort sweep can merge contiguous pairs. +func flowKeyCompare(a, b flowKey) int { + if c := bytes.Compare(a.src[:], b.src[:]); c != 0 { + return c + } + if c := bytes.Compare(a.dst[:], b.dst[:]); c != 0 { + return c + } + if a.sport != b.sport { + if a.sport < b.sport { + return -1 + } + return 1 + } + if a.dport != b.dport { + if a.dport < b.dport { + return -1 + } + return 1 + } + if a.isV6 != b.isV6 { + if !a.isV6 { + return -1 + } + return 1 + } + return 0 +} + +// canMergeSlots reports whether s can fold into prev as one merged TSO +// superpacket. Same flow, contiguous TCP byte range, equal gsoSize, and +// fits within the kernel TSO limits. The tail-of-prev check rejects any +// merge whose first slot ended on a sub-gsoSize segment — kernel TSO +// would split the merged skb at gsoSize boundaries from the start, so a +// short segment in the middle would corrupt every later segment. 
PSH and +// ECE state must agree across both slots: PSH is a semantic delimiter +// (preserving the sender's push boundary) and ECE state must be uniform +// across a window (the same rule canAppend enforces for in-flow appends). +// +// Note: a slot sealed by reorder (canAppend returned false on seq +// mismatch) keeps psh=false, so this restriction does not block the +// reorder-fix merge — only legitimate PSH-set seals. +func canMergeSlots(prev, s *coalesceSlot) bool { + if prev.psh { + return false + } + if prev.fk != s.fk { + return false + } + if prev.gsoSize != s.gsoSize { + return false + } + if prev.nextSeq != slotSeedSeq(s) { + return false + } + if prev.numSeg+s.numSeg > tcpCoalesceMaxSegs { + return false + } + if prev.hdrLen+prev.totalPay+s.totalPay > tcpCoalesceBufSize { + return false + } + if len(prev.payIovs[len(prev.payIovs)-1]) != prev.gsoSize { + return false + } + prevFlags := prev.hdrBuf[prev.ipHdrLen+13] + sFlags := s.hdrBuf[s.ipHdrLen+13] + if (prevFlags^sFlags)&tcpFlagEce != 0 { + return false + } + if !headersMatch(prev.hdrBuf[:prev.hdrLen], s.hdrBuf[:s.hdrLen], prev.isV6, prev.ipHdrLen) { + return false + } + return true +} + +// mergeSlots folds src into dst in place: payIovs concatenated, counters +// and totals updated, PSH and IP-level CE bits OR'd into the seed header +// so neither the push signal nor a CE mark is lost. The seed header's +// seq, gsoSize, and fk are unchanged. Caller is responsible for releasing +// src (it's no longer in c.slots after this call). +func mergeSlots(dst, src *coalesceSlot) { + dst.payIovs = append(dst.payIovs, src.payIovs...) + dst.numSeg += src.numSeg + dst.totalPay += src.totalPay + dst.nextSeq = src.nextSeq + if src.psh { + dst.psh = true + dst.hdrBuf[dst.ipHdrLen+13] |= tcpFlagPsh + } + mergeECNIntoSeed(dst.hdrBuf[:dst.ipHdrLen], src.hdrBuf[:src.ipHdrLen], dst.isV6) +} + +// ipv4HdrChecksum computes the IPv4 header checksum over hdr (which must +// already have its checksum field zeroed) and returns the folded/inverted +// 16-bit value to store. +func ipv4HdrChecksum(hdr []byte) uint16 { + var sum uint32 + for i := 0; i+1 < len(hdr); i += 2 { + sum += uint32(binary.BigEndian.Uint16(hdr[i : i+2])) + } + if len(hdr)%2 == 1 { + sum += uint32(hdr[len(hdr)-1]) << 8 + } + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return ^uint16(sum) +} + +// pseudoSumIPv4 / pseudoSumIPv6 build the L4 pseudo-header partial sum +// expected by the virtio NEEDS_CSUM kernel path: the 32-bit accumulator +// before folding. proto selects the L4 (TCP or UDP); the UDP coalescer +// reuses these helpers. +func pseudoSumIPv4(src, dst []byte, proto byte, l4Len int) uint32 { + var sum uint32 + sum += uint32(binary.BigEndian.Uint16(src[0:2])) + sum += uint32(binary.BigEndian.Uint16(src[2:4])) + sum += uint32(binary.BigEndian.Uint16(dst[0:2])) + sum += uint32(binary.BigEndian.Uint16(dst[2:4])) + sum += uint32(proto) + sum += uint32(l4Len) + return sum +} + +func pseudoSumIPv6(src, dst []byte, proto byte, l4Len int) uint32 { + var sum uint32 + for i := 0; i < 16; i += 2 { + sum += uint32(binary.BigEndian.Uint16(src[i : i+2])) + sum += uint32(binary.BigEndian.Uint16(dst[i : i+2])) + } + sum += uint32(l4Len >> 16) + sum += uint32(l4Len & 0xffff) + sum += uint32(proto) + return sum +} + +// foldOnceNoInvert folds the 32-bit accumulator to 16 bits and returns it +// unchanged (no one's complement). This is what virtio NEEDS_CSUM wants in +// the L4 checksum field — the kernel will add the payload sum and invert. 
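+// For example, an accumulator of 0x0001fffe folds to 0x0001 + 0xfffe =
+// 0xffff and is returned as-is; a full Internet checksum would invert it
+// to 0x0000 before storing.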
+func foldOnceNoInvert(sum uint32) uint16 { + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + return uint16(sum) +} diff --git a/overlay/batch/tcp_coalesce_bench_test.go b/overlay/batch/tcp_coalesce_bench_test.go new file mode 100644 index 00000000..54da995a --- /dev/null +++ b/overlay/batch/tcp_coalesce_bench_test.go @@ -0,0 +1,173 @@ +package batch + +import ( + "encoding/binary" + "testing" + + "github.com/slackhq/nebula/overlay/tio" +) + +// nopTunWriter is a zero-alloc tio.GSOWriter for benchmarks. Discards +// everything but satisfies the interface the coalescer detects. +type nopTunWriter struct{} + +func (nopTunWriter) Write(p []byte) (int, error) { return len(p), nil } +func (nopTunWriter) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, _ tio.GSOProto) error { + return nil +} +func (nopTunWriter) Capabilities() tio.Capabilities { + return tio.Capabilities{TSO: true, USO: true} +} + +// buildTCPv4BulkFlow returns a slice of N adjacent ACK-only TCP segments +// on a single 5-tuple, each carrying payloadLen bytes. Seq numbers are +// contiguous so every packet is coalesceable onto the previous one. +func buildTCPv4BulkFlow(n, payloadLen int) [][]byte { + pkts := make([][]byte, n) + pay := make([]byte, payloadLen) + seq := uint32(1000) + for i := range n { + pkts[i] = buildTCPv4(seq, tcpAck, pay) + seq += uint32(payloadLen) + } + return pkts +} + +// buildTCPv4Interleaved returns nFlows * perFlow packets with per-flow +// seq continuity but round-robin across flows — worst case for any +// "last-slot" cache. +func buildTCPv4Interleaved(nFlows, perFlow, payloadLen int) [][]byte { + pay := make([]byte, payloadLen) + seqs := make([]uint32, nFlows) + for i := range seqs { + seqs[i] = uint32(1000 + i*1000000) + } + pkts := make([][]byte, 0, nFlows*perFlow) + for range perFlow { + for f := range nFlows { + sport := uint16(10000 + f) + pkts = append(pkts, buildTCPv4Ports(sport, 2000, seqs[f], tcpAck, pay)) + seqs[f] += uint32(payloadLen) + } + } + return pkts +} + +// buildICMPv4 returns a minimal non-TCP packet that takes the passthrough +// branch in Commit. +func buildICMPv4() []byte { + pkt := make([]byte, 28) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], 28) + pkt[9] = 1 // ICMP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + return pkt +} + +// runCommitBench drives Commit over pkts batchSize at a time, flushing +// between batches, and reports per-packet cost. +func runCommitBench(b *testing.B, pkts [][]byte, batchSize int) { + b.Helper() + c := NewTCPCoalescer(nopTunWriter{}) + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + if err := c.Commit(pkt); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := c.Flush(); err != nil { + b.Fatal(err) + } + } + } + // Drain any trailing partial batch so slot state doesn't leak across runs. + _ = c.Flush() +} + +// BenchmarkCommitSingleFlow is the bulk-TCP steady state: one flow, +// contiguous seq, 1200-byte payloads. Every packet past the seed should +// append onto the open slot. This is the case we most care about. +func BenchmarkCommitSingleFlow(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runCommitBench(b, pkts, tcpCoalesceMaxSegs) +} + +// BenchmarkCommitInterleaved4 has 4 concurrent bulk flows round-robined. +// A single-entry fast-path cache will miss on every packet; an N-way +// cache or map lookup carries the weight. 
+func BenchmarkCommitInterleaved4(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runCommitBench(b, pkts, len(pkts)) +} + +// BenchmarkCommitInterleaved16 stresses the map at higher flow counts. +func BenchmarkCommitInterleaved16(b *testing.B) { + pkts := buildTCPv4Interleaved(16, tcpCoalesceMaxSegs, 1200) + runCommitBench(b, pkts, len(pkts)) +} + +// BenchmarkCommitPassthrough exercises the non-TCP branch: parseTCPBase +// bails early and addPassthrough is the only work. +func BenchmarkCommitPassthrough(b *testing.B) { + pkt := buildICMPv4() + pkts := make([][]byte, 64) + for i := range pkts { + pkts[i] = pkt + } + runCommitBench(b, pkts, 64) +} + +// BenchmarkCommitNonCoalesceableTCP sends SYN|ACK packets on one flow. +// Each packet takes the "TCP but not admissible" branch which does a +// map delete + passthrough. Measures the seal-without-slot cost. +func BenchmarkCommitNonCoalesceableTCP(b *testing.B) { + pay := make([]byte, 0) + pkts := make([][]byte, 64) + for i := range pkts { + pkts[i] = buildTCPv4(uint32(1000+i), tcpSyn|tcpAck, pay) + } + runCommitBench(b, pkts, 64) +} + +// runMultiCommitBench drives MultiCoalescer.Commit. The dispatcher does +// the IP/L4 parse once and passes the parsed struct to the lane, so this +// is the bench that shows the savings of skipping the lane's re-parse. +func runMultiCommitBench(b *testing.B, pkts [][]byte, batchSize int) { + b.Helper() + m := NewMultiCoalescer(nopTunWriter{}, true, true) + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + if err := m.Commit(pkt); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := m.Flush(); err != nil { + b.Fatal(err) + } + } + } + _ = m.Flush() +} + +// BenchmarkMultiCommitSingleFlow is the multi-lane analogue of +// BenchmarkCommitSingleFlow — same workload but routed through the +// dispatcher. The delta vs the single-lane bench measures dispatcher +// overhead. +func BenchmarkMultiCommitSingleFlow(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runMultiCommitBench(b, pkts, tcpCoalesceMaxSegs) +} + +// BenchmarkMultiCommitInterleaved4 mirrors BenchmarkCommitInterleaved4 +// through the dispatcher. +func BenchmarkMultiCommitInterleaved4(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runMultiCommitBench(b, pkts, len(pkts)) +} diff --git a/overlay/batch/tcp_coalesce_test.go b/overlay/batch/tcp_coalesce_test.go new file mode 100644 index 00000000..84e78cdd --- /dev/null +++ b/overlay/batch/tcp_coalesce_test.go @@ -0,0 +1,1022 @@ +package batch + +import ( + "encoding/binary" + "testing" + + "github.com/slackhq/nebula/overlay/tio" +) + +// fakeTunWriter records plain Writes and WriteGSO calls without touching a +// real TUN fd. WriteGSO records the IP header, transport header, and +// borrowed payload fragments separately so tests can inspect each. +type fakeTunWriter struct { + gsoEnabled bool + writes [][]byte + gsoWrites []fakeGSOWrite +} + +// fakeGSOWrite captures one WriteGSO call. hdr is the concatenation of the +// IP and transport headers (in that order), gsoSize / isV6 / csumStart are +// derived from the call so existing assertions keep working unchanged. +type fakeGSOWrite struct { + hdr []byte + pays [][]byte + gsoSize uint16 + isV6 bool + csumStart uint16 +} + +// total returns hdrLen + sum of pay lens. 
+func (g fakeGSOWrite) total() int { + n := len(g.hdr) + for _, p := range g.pays { + n += len(p) + } + return n +} + +// payLen sums the pays. +func (g fakeGSOWrite) payLen() int { + var n int + for _, p := range g.pays { + n += len(p) + } + return n +} + +func (w *fakeTunWriter) Write(p []byte) (int, error) { + buf := make([]byte, len(p)) + copy(buf, p) + w.writes = append(w.writes, buf) + return len(p), nil +} + +func (w *fakeTunWriter) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, _ tio.GSOProto) error { + hcopy := make([]byte, len(hdr)+len(transportHdr)) + copy(hcopy, hdr) + copy(hcopy[len(hdr):], transportHdr) + paysCopy := make([][]byte, len(pays)) + for i, p := range pays { + pc := make([]byte, len(p)) + copy(pc, p) + paysCopy[i] = pc + } + var gsoSize uint16 + if len(pays) > 1 { + gsoSize = uint16(len(pays[0])) + } + isV6 := len(hdr) > 0 && hdr[0]>>4 == 6 + w.gsoWrites = append(w.gsoWrites, fakeGSOWrite{ + hdr: hcopy, + pays: paysCopy, + gsoSize: gsoSize, + isV6: isV6, + csumStart: uint16(len(hdr)), + }) + return nil +} + +func (w *fakeTunWriter) Capabilities() tio.Capabilities { + return tio.Capabilities{TSO: w.gsoEnabled, USO: w.gsoEnabled} +} + +// buildTCPv4 constructs a minimal IPv4+TCP packet with the given payload, +// seq, and flags. Assumes no IP options and a 20-byte TCP header. +func buildTCPv4(seq uint32, flags byte, payload []byte) []byte { + return buildTCPv4Ports(1000, 2000, seq, flags, payload) +} + +// buildTCPv4Ports is buildTCPv4 with caller-specified ports so tests can +// build distinct flows. +func buildTCPv4Ports(sport, dport uint16, seq uint32, flags byte, payload []byte) []byte { + const ipHdrLen = 20 + const tcpHdrLen = 20 + total := ipHdrLen + tcpHdrLen + len(payload) + pkt := make([]byte, total) + + pkt[0] = 0x45 + pkt[1] = 0x00 + binary.BigEndian.PutUint16(pkt[2:4], uint16(total)) + binary.BigEndian.PutUint16(pkt[4:6], 0) + binary.BigEndian.PutUint16(pkt[6:8], 0x4000) + pkt[8] = 64 + pkt[9] = ipProtoTCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + + binary.BigEndian.PutUint16(pkt[20:22], sport) + binary.BigEndian.PutUint16(pkt[22:24], dport) + binary.BigEndian.PutUint32(pkt[24:28], seq) + binary.BigEndian.PutUint32(pkt[28:32], 12345) + pkt[32] = 0x50 + pkt[33] = flags + binary.BigEndian.PutUint16(pkt[34:36], 0xffff) + + copy(pkt[40:], payload) + return pkt +} + +const ( + tcpAck = 0x10 + tcpPsh = 0x08 + tcpSyn = 0x02 + tcpFin = 0x01 + tcpAckPsh = tcpAck | tcpPsh +) + +func TestCoalescerPassthroughWhenGSOUnavailable(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: false} + c := NewTCPCoalescer(w) + pkt := buildTCPv4(1000, tcpAck, []byte("hello")) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + // No sync write — passthrough is deferred to Flush. 
+ if len(w.writes) != 0 || len(w.gsoWrites) != 0 { + t.Fatalf("no Add-time writes: got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerNonTCPPassthrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pkt := make([]byte, 28) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], 28) + pkt[9] = 1 + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("ICMP should pass through unchanged") + } +} + +func TestCoalescerSeedThenFlushAlone(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pkt := buildTCPv4(1000, tcpAck, make([]byte, 1000)) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if len(w.writes) != 0 || len(w.gsoWrites) != 0 { + t.Fatalf("unexpected output before flush") + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Single-segment flush goes through WriteGSO with GSO_NONE + // (virtio NEEDS_CSUM lets the kernel fill in the L4 csum). + if len(w.gsoWrites) != 1 || len(w.writes) != 0 { + t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } + g := w.gsoWrites[0] + if g.total() != 40+1000 { + t.Errorf("super total=%d want %d", g.total(), 40+1000) + } + if g.payLen() != 1000 { + t.Errorf("payLen=%d want 1000", g.payLen()) + } +} + +func TestCoalescerCoalescesAdjacentACKs(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + g := w.gsoWrites[0] + if g.gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", g.gsoSize) + } + if len(g.hdr) != 40 { + t.Errorf("hdrLen=%d want 40", len(g.hdr)) + } + if g.csumStart != 20 { + t.Errorf("csumStart=%d want 20", g.csumStart) + } + if len(g.pays) != 3 { + t.Errorf("pay count=%d want 3", len(g.pays)) + } + if g.total() != 40+3*1200 { + t.Errorf("superpacket len=%d want %d", g.total(), 40+3*1200) + } + if tot := binary.BigEndian.Uint16(g.hdr[2:4]); int(tot) != g.total() { + t.Errorf("ip total_length=%d want %d", tot, g.total()) + } +} + +func TestCoalescerRejectsSeqGap(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Each packet flushes as its own single-segment WriteGSO. 
+ if len(w.gsoWrites) != 2 || len(w.writes) != 0 { + t.Fatalf("seq gap: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFlagMismatch(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // SYN|ACK is non-admissible. Must flush matching flow's slot (gso) + // and then plain-write the SYN packet itself. + syn := buildTCPv4(2200, tcpSyn|tcpAck, pay) + if err := c.Commit(syn); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 1 { + t.Fatalf("flag mismatch: want 1 plain + 1 gso, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsFIN(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + fin := buildTCPv4(1000, tcpAck|tcpFin, []byte("x")) + if err := c.Commit(fin); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // FIN isn't admissible — passthrough as plain, no slot, no gso. + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("FIN should be passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerShortLastSegmentClosesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + full := make([]byte, 1200) + half := make([]byte, 500) + if err := c.Commit(buildTCPv4(1000, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAck, half)); err != nil { + t.Fatal(err) + } + // Chain now closed; next packet seeds a new slot on the same flow + // after flushing the old one. + if err := c.Commit(buildTCPv4(2700, tcpAck, full)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Expect two gso writes: the first two packets coalesced, then the + // third flushed alone (single-seg via GSO_NONE). + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want 0 plain writes got %d", len(w.writes)) + } + if w.gsoWrites[0].gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", w.gsoWrites[0].gsoSize) + } + if got, want := w.gsoWrites[0].total(), 40+1200+500; got != want { + t.Errorf("super len=%d want %d", got, want) + } +} + +func TestCoalescerPSHFinalizesChain(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAckPsh, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // First two coalesce; the third seeds a fresh slot that flushes alone. + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want 0 plain writes got %d", len(w.writes)) + } +} + +// TestCoalescerPropagatesPSHFromAppended ensures that when an appended +// segment carries PSH (or is short, sealing the chain), the PSH bit ends +// up in the emitted superpacket's TCP flags. The kernel TSO path keeps +// PSH only on the last segment iff the input header has it set; if the +// coalescer drops it the sender's push signal never reaches the receiver. 
+func TestCoalescerPropagatesPSHFromAppended(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // Seed has no PSH; second segment carries PSH and seals the chain. + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAckPsh, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write got %d", len(w.gsoWrites)) + } + g := w.gsoWrites[0] + const ipHdrLen = 20 + flags := g.hdr[ipHdrLen+13] + if flags&tcpPsh == 0 { + t.Fatalf("PSH lost from coalesced superpacket: flags=0x%02x", flags) + } + if flags&tcpAck == 0 { + t.Fatalf("ACK missing from coalesced superpacket: flags=0x%02x", flags) + } +} + +func TestCoalescerRejectsDifferentFlow(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + p1 := buildTCPv4(1000, tcpAck, pay) + p2 := buildTCPv4(2200, tcpAck, pay) + binary.BigEndian.PutUint16(p2[20:22], 9999) + if err := c.Commit(p1); err != nil { + t.Fatal(err) + } + if err := c.Commit(p2); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Two independent flows, each flushes its own single-segment WriteGSO. + if len(w.gsoWrites) != 2 || len(w.writes) != 0 { + t.Fatalf("diff flow: want 2 gso writes got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerRejectsIPOptions(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 500) + pkt := buildTCPv4(1000, tcpAck, pay) + // Bump IHL to 6 to simulate 4 bytes of IP options. Don't actually add + // bytes — parser should bail before it matters. + pkt[0] = 0x46 + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Non-admissible parse → passthrough as plain. + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("IP options should passthrough, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestCoalescerCapBySegments(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 512) + seq := uint32(1000) + for i := 0; i < tcpCoalesceMaxSegs+5; i++ { + if err := c.Commit(buildTCPv4(seq, tcpAck, pay)); err != nil { + t.Fatal(err) + } + seq += uint32(len(pay)) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + for _, g := range w.gsoWrites { + segs := len(g.pays) + if segs > tcpCoalesceMaxSegs { + t.Fatalf("super exceeded seg cap: %d > %d", segs, tcpCoalesceMaxSegs) + } + } +} + +// TestCoalescerMultipleFlowsInSameBatch proves two interleaved bulk TCP +// flows coalesce independently in a single Flush. +func TestCoalescerMultipleFlowsInSameBatch(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + + // Flow A: sport 1000. Flow B: sport 3000. 
+ if err := c.Commit(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(1000, 2000, 2500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (one per flow), got %d", len(w.gsoWrites)) + } + if len(w.writes) != 0 { + t.Fatalf("want no plain writes, got %d", len(w.writes)) + } + // Each superpacket should carry 3 segments. + for i, g := range w.gsoWrites { + if len(g.pays) != 3 { + t.Errorf("gso[%d]: segs=%d want 3", i, len(g.pays)) + } + if g.gsoSize != 1200 { + t.Errorf("gso[%d]: gsoSize=%d want 1200", i, g.gsoSize) + } + } + // Verify each superpacket carries the source port it was seeded with. + seenSports := map[uint16]bool{} + for _, g := range w.gsoWrites { + sp := binary.BigEndian.Uint16(g.hdr[20:22]) + seenSports[sp] = true + } + if !seenSports[1000] || !seenSports[3000] { + t.Errorf("expected superpackets for sports 1000 and 3000, got %v", seenSports) + } +} + +// TestCoalescerPreservesArrivalOrder confirms that with passthrough and +// coalesced events both queued, Flush emits them in Add order rather than +// writing passthrough packets synchronously. +func TestCoalescerPreservesArrivalOrder(t *testing.T) { + w := &orderedFakeWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + // Sequence: coalesceable TCP, ICMP (passthrough), coalesceable TCP on + // a different flow. Expected emit order: gso(X), plain(ICMP), gso(Y). + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + icmp := make([]byte, 28) + icmp[0] = 0x45 + binary.BigEndian.PutUint16(icmp[2:4], 28) + icmp[9] = 1 + copy(icmp[12:16], []byte{10, 0, 0, 1}) + copy(icmp[16:20], []byte{10, 0, 0, 3}) + if err := c.Commit(icmp); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Nothing should have hit the writer synchronously. + if len(w.events) != 0 { + t.Fatalf("Add emitted events synchronously: %v", w.events) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if got, want := w.events, []string{"gso", "plain", "gso"}; !stringSliceEq(got, want) { + t.Fatalf("flush order=%v want %v", got, want) + } +} + +// orderedFakeWriter records only the sequence of call types so tests can +// assert arrival order without inspecting bytes. 
+type orderedFakeWriter struct { + gsoEnabled bool + events []string +} + +func (w *orderedFakeWriter) Write(p []byte) (int, error) { + w.events = append(w.events, "plain") + return len(p), nil +} + +func (w *orderedFakeWriter) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, _ tio.GSOProto) error { + w.events = append(w.events, "gso") + return nil +} + +func (w *orderedFakeWriter) Capabilities() tio.Capabilities { + return tio.Capabilities{TSO: w.gsoEnabled, USO: w.gsoEnabled} +} + +func stringSliceEq(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// TestCoalescerInterleavedFlowsPreserveOrdering checks that a non-admissible +// packet (SYN) mid-flow only flushes its own flow, not others. +func TestCoalescerInterleavedFlowsPreserveOrdering(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + + // Flow A two segments. + if err := c.Commit(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Flow B two segments. + if err := c.Commit(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Flow A SYN (non-admissible) — must flush only flow A's slot. + syn := buildTCPv4Ports(1000, 2000, 9999, tcpSyn|tcpAck, pay) + if err := c.Commit(syn); err != nil { + t.Fatal(err) + } + // Flow B continues — should still be coalesced with its seed. + if err := c.Commit(buildTCPv4Ports(3000, 2000, 2900, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + // Expected: + // - 1 gso for flow A (first 2 segments) + // - 1 plain for flow A SYN + // - 1 gso for flow B (3 segments) + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes, got %d", len(w.gsoWrites)) + } + if len(w.writes) != 1 { + t.Fatalf("want 1 plain write (SYN), got %d", len(w.writes)) + } + // Find the 3-segment gso (flow B) and the 2-segment gso (flow A). + var segCounts []int + for _, g := range w.gsoWrites { + segCounts = append(segCounts, len(g.pays)) + } + if !(segCounts[0] == 2 && segCounts[1] == 3) && !(segCounts[0] == 3 && segCounts[1] == 2) { + t.Errorf("unexpected segment counts: %v (want 2 and 3)", segCounts) + } +} + +// ECN test helpers and constants. + +const ( + tcpEce = 0x40 + tcpCwr = 0x80 + + // 2-bit IP-level ECN codepoints (lower 2 bits of IPv4 ToS / IPv6 TC). + ecnNotECT = 0x00 + ecnECT1 = 0x01 + ecnECT0 = 0x02 + ecnCE = 0x03 +) + +// buildTCPv4WithToS is buildTCPv4 with caller-specified IPv4 ToS so tests can +// drive DSCP and ECN bits. +func buildTCPv4WithToS(tos byte, seq uint32, flags byte, payload []byte) []byte { + pkt := buildTCPv4(seq, flags, payload) + pkt[1] = tos + return pkt +} + +// buildTCPv6 mirrors buildTCPv4 for IPv6. tcLow is the low 4 bits of Traffic +// Class, which carries the ECN codepoint (mask 0x03) and the bottom 2 DSCP +// bits — enough to drive the ECN paths under test. 
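+// For example, tcLow = ecnCE (0x03) lands in the high nibble of byte 1,
+// so pkt[1] = 0x30, and reading back (pkt[1]>>4)&0x03 recovers CE.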
+func buildTCPv6(tcLow byte, seq uint32, flags byte, payload []byte) []byte { + const ipHdrLen = 40 + const tcpHdrLen = 20 + pkt := make([]byte, ipHdrLen+tcpHdrLen+len(payload)) + + pkt[0] = 0x60 // version=6, TC[7:4]=0 + pkt[1] = (tcLow & 0x0f) << 4 // TC[3:0] in high nibble; flow=0 + binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpHdrLen+len(payload))) + pkt[6] = ipProtoTCP + pkt[7] = 64 + copy(pkt[8:24], []byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + copy(pkt[24:40], []byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}) + + binary.BigEndian.PutUint16(pkt[40:42], 1000) + binary.BigEndian.PutUint16(pkt[42:44], 2000) + binary.BigEndian.PutUint32(pkt[44:48], seq) + binary.BigEndian.PutUint32(pkt[48:52], 12345) + pkt[52] = 0x50 + pkt[53] = flags + binary.BigEndian.PutUint16(pkt[54:56], 0xffff) + + copy(pkt[60:], payload) + return pkt +} + +// TestCoalescerCoalescesEceFlow confirms that ECN-Echo-marked ACKs (an +// ECN-aware flow under congestion) keep getting coalesced into a TSO +// superpacket instead of falling out to passthrough, and that the seed +// retains ECE on the wire. +func TestCoalescerCoalescesEceFlow(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + flags := byte(tcpAck | tcpEce) + if err := c.Commit(buildTCPv4(1000, flags, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, flags, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + g := w.gsoWrites[0] + if len(g.pays) != 2 { + t.Errorf("pay count=%d want 2", len(g.pays)) + } + if seedFlags := g.hdr[20+13]; seedFlags&tcpEce == 0 { + t.Errorf("seed flags=0x%02x want ECE preserved", seedFlags) + } +} + +// TestCoalescerCwrSealsFlow confirms that a CWR-bearing segment in the +// middle of a flow goes to passthrough and seals the open slot, so a later +// in-flow segment seeds a new slot rather than extending the prior burst. +func TestCoalescerCwrSealsFlow(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAck|tcpCwr, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 { + t.Fatalf("want 1 plain write (CWR), got %d", len(w.writes)) + } + // Two GSO writes: the first seed before CWR, and a fresh seed after. + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes, got %d", len(w.gsoWrites)) + } + for i, g := range w.gsoWrites { + if len(g.pays) != 1 { + t.Errorf("gso %d pay count=%d want 1", i, len(g.pays)) + } + } +} + +// TestCoalescerEceMismatchReseeds confirms that toggling ECE mid-flow does +// not silently merge — receivers expect ECE either set on every segment of +// a CE-echoing window or none. 
+func TestCoalescerEceMismatchReseeds(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4(1000, tcpAck|tcpEce, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 separate seeds, got %d gso writes", len(w.gsoWrites)) + } + for i, g := range w.gsoWrites { + if len(g.pays) != 1 { + t.Errorf("gso %d pay count=%d want 1", i, len(g.pays)) + } + } +} + +// TestCoalescerMergesCEMark confirms that an ECT(0) burst with a single +// CE-marked packet still coalesces, and the merged superpacket carries CE. +func TestCoalescerMergesCEMark(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + if err := c.Commit(buildTCPv4WithToS(ecnECT0, 1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Router along the path stamped CE on this one. + if err := c.Commit(buildTCPv4WithToS(ecnCE, 2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4WithToS(ecnECT0, 3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 merged gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + g := w.gsoWrites[0] + if len(g.pays) != 3 { + t.Errorf("pay count=%d want 3", len(g.pays)) + } + if got := g.hdr[1] & 0x03; got != ecnCE { + t.Errorf("seed ECN=0x%02x want CE 0x%02x", got, ecnCE) + } +} + +// TestCoalescerDscpMismatchReseeds confirms that the new ECN-mask in +// headersMatch did not also relax DSCP — different DSCP must still split. +func TestCoalescerDscpMismatchReseeds(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // Same ECN (Not-ECT), different DSCP (0x10 vs 0x20 in upper 6 bits). + tosA := byte(0x10<<2) | ecnNotECT + tosB := byte(0x20<<2) | ecnNotECT + if err := c.Commit(buildTCPv4WithToS(tosA, 1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4WithToS(tosB, 2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 separate seeds (different DSCP), got %d", len(w.gsoWrites)) + } +} + +// TestCoalescerIPv6CoalescesEceFlow is the IPv6 analogue of +// TestCoalescerCoalescesEceFlow. +func TestCoalescerIPv6CoalescesEceFlow(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + flags := byte(tcpAck | tcpEce) + if err := c.Commit(buildTCPv6(0, 1000, flags, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv6(0, 2200, flags, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d", len(w.gsoWrites)) + } + g := w.gsoWrites[0] + if seedFlags := g.hdr[40+13]; seedFlags&tcpEce == 0 { + t.Errorf("seed flags=0x%02x want ECE preserved", seedFlags) + } +} + +// TestCoalescerSortsReorderedSeedsAndMerges feeds three same-flow MSS +// segments out of TCP-seq order (mimicking a wire reorder that escaped +// the rxOrder per-batch sort). 
Without the reorderForFlush sort+merge, +// each out-of-seq arrival would seed its own slot and the slots would +// emit in arrival order, producing a kernel-visible TCP reorder. With +// the sort+merge, the three slots are sorted by seq and folded back into +// one in-order TSO superpacket — same shape the receiver TCP would have +// seen had the wire never reordered. +func TestCoalescerSortsReorderedSeedsAndMerges(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // Arrival order: seq 1000, 3400, 2200. The 3400 seeds a separate slot + // because 3400 != nextSeq=2200, then 2200 fails to extend the 3400 slot + // and seeds its own. Three slots end up in c.slots; reorderForFlush + // should sort them into [1000,2200,3400] and merge them back into one. + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 merged gso write got %d", len(w.gsoWrites)) + } + g := w.gsoWrites[0] + if len(g.pays) != 3 { + t.Fatalf("merged segs=%d want 3", len(g.pays)) + } + const ipHdrLen = 20 + if seedSeq := binary.BigEndian.Uint32(g.hdr[ipHdrLen+4 : ipHdrLen+8]); seedSeq != 1000 { + t.Errorf("merged seed seq=%d want 1000 (lowest)", seedSeq) + } +} + +// TestCoalescerSortAcrossFlowsMergesEachIndependently checks that two +// flows interleaved with reorder are each sorted-and-merged in isolation +// without any cross-flow contamination. +func TestCoalescerSortAcrossFlowsMergesEachIndependently(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // Flow A (sport 1000) seq 100, 1300; flow B (sport 3000) seq 500, 1700. + // Arrival: A.1300, B.1700, A.100, B.500 — every flow reordered. + if err := c.Commit(buildTCPv4Ports(1000, 2000, 1300, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 1700, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(1000, 2000, 100, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4Ports(3000, 2000, 500, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (one per flow merged), got %d", len(w.gsoWrites)) + } + for i, g := range w.gsoWrites { + if len(g.pays) != 2 { + t.Errorf("gso[%d] segs=%d want 2", i, len(g.pays)) + } + const ipHdrLen = 20 + seedSeq := binary.BigEndian.Uint32(g.hdr[ipHdrLen+4 : ipHdrLen+8]) + sport := binary.BigEndian.Uint16(g.hdr[ipHdrLen : ipHdrLen+2]) + // Each flow's merged seed should be the LOWER of its two seqs. + switch sport { + case 1000: + if seedSeq != 100 { + t.Errorf("flow A seed seq=%d want 100", seedSeq) + } + case 3000: + if seedSeq != 500 { + t.Errorf("flow B seed seq=%d want 500", seedSeq) + } + default: + t.Errorf("unexpected sport %d", sport) + } + } +} + +// TestCoalescerSortKeepsPSHBoundary verifies that a PSH-sealed slot is +// not folded into a later seq-contiguous slot — PSH placement is part of +// the wire signal and merging across it would shift the receiver's push +// boundary by an arbitrary number of segments. 
+func TestCoalescerSortKeepsPSHBoundary(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // Seq 1000 (no PSH) + 2200 (PSH) → seal one slot with PSH set. + // Seq 3400 (no PSH) is contiguous to 3400 from seq 2200+1200; without + // the PSH check it would merge in. + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(2200, tcpAckPsh, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (PSH-sealed and fresh seed), got %d", len(w.gsoWrites)) + } +} + +// TestCoalescerSortKeepsPassthroughBarrier confirms a passthrough slot in +// the middle of the queue prevents the post-sort merge from folding +// across it. Reordered same-flow data on either side of the passthrough +// is sorted/merged independently. +func TestCoalescerSortKeepsPassthroughBarrier(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // First two segments seed S1 (then a 3400 reorder seeds S2). + if err := c.Commit(buildTCPv4(1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv4(3400, tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Non-coalesceable packet (SYN+ACK) flushes S1's openSlots entry and + // becomes a passthrough barrier in c.slots. + if err := c.Commit(buildTCPv4(9999, tcpSyn|tcpAck, pay)); err != nil { + t.Fatal(err) + } + // Post-barrier same-flow data: should never end up before the SYN. + if err := c.Commit(buildTCPv4(2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // We expect: gso(merged 1000+3400 ranges sorted but not contiguous so 2 + // gso writes), plain(SYN), gso(2200 alone). The pre-barrier sort should + // land 1000 before 3400, and the post-barrier 2200 stays after the SYN. + if len(w.writes) != 1 { + t.Fatalf("want 1 plain SYN passthrough, got %d", len(w.writes)) + } +} + +// TestCoalescerIPv6MergesCEMark is the IPv6 analogue of +// TestCoalescerMergesCEMark. ECN bits live in TC[1:0] = byte 1 mask 0x30. +func TestCoalescerIPv6MergesCEMark(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewTCPCoalescer(w) + pay := make([]byte, 1200) + // tcLow is the low 4 bits of TC; ECN occupies the bottom 2 of those. + if err := c.Commit(buildTCPv6(ecnECT0, 1000, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildTCPv6(ecnCE, 2200, tcpAck, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 merged gso write, got %d", len(w.gsoWrites)) + } + g := w.gsoWrites[0] + // Byte 1 high nibble holds TC[3:0]; ECN is the low 2 bits of that nibble, + // which appears in byte 1 mask 0x30 (>>4 to read the codepoint value). + if got := (g.hdr[1] >> 4) & 0x03; got != ecnCE { + t.Errorf("seed v6 ECN=0x%02x want CE 0x%02x", got, ecnCE) + } +} diff --git a/overlay/batch/tx_batch.go b/overlay/batch/tx_batch.go index cac441d9..b7f219a5 100644 --- a/overlay/batch/tx_batch.go +++ b/overlay/batch/tx_batch.go @@ -4,58 +4,63 @@ import "net/netip" const SendBatchCap = 128 -// SendBatch accumulates encrypted UDP packets for potential TX offloading. +// batchWriter is the minimal subset of udp.Conn needed by SendBatch to flush. 
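+// The outerECNs slice is parallel to bufs and addrs: one ECN codepoint
+// per committed packet, as recorded by SendBatch.Commit.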
+type batchWriter interface { + WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error +} + +// SendBatch accumulates encrypted UDP packets and flushes them via WriteBatch. // One SendBatch is owned by each listenIn goroutine; no locking is needed. -// The backing storage holds up to batchCap packets of slotCap bytes each; -// bufs and dsts are parallel slices of committed slots. +// The backing arena grows on demand: when there isn't room for the next slot +// we allocate a fresh backing array. Already-committed slices keep referencing +// the old array and remain valid until Flush drops them. type SendBatch struct { - bufs [][]byte - dsts []netip.AddrPort - backing []byte - slotCap int - batchCap int - nextSlot int + out batchWriter + bufs [][]byte + dsts []netip.AddrPort + ecns []byte + backing []byte } -func NewSendBatch(batchCap, slotCap int) *SendBatch { +func NewSendBatch(out batchWriter, batchCap, slotCap int) *SendBatch { return &SendBatch{ - bufs: make([][]byte, 0, batchCap), - dsts: make([]netip.AddrPort, 0, batchCap), - backing: make([]byte, batchCap*slotCap), - slotCap: slotCap, - batchCap: batchCap, + out: out, + bufs: make([][]byte, 0, batchCap), + dsts: make([]netip.AddrPort, 0, batchCap), + ecns: make([]byte, 0, batchCap), + backing: make([]byte, 0, batchCap*slotCap), } } -func (b *SendBatch) Next() []byte { - if b.nextSlot >= b.batchCap { - return nil +func (b *SendBatch) Reserve(sz int) []byte { + if len(b.backing)+sz > cap(b.backing) { + // Grow: allocate a fresh backing. Already-committed slices still + // reference the old array and remain valid until Flush drops them. + newCap := max(cap(b.backing)*2, sz) + b.backing = make([]byte, 0, newCap) } - start := b.nextSlot * b.slotCap - return b.backing[start : start : start+b.slotCap] //set len to 0 but cap to slotCap + start := len(b.backing) + b.backing = b.backing[:start+sz] + return b.backing[start : start+sz : start+sz] } -func (b *SendBatch) Commit(n int, dst netip.AddrPort) { - start := b.nextSlot * b.slotCap - b.bufs = append(b.bufs, b.backing[start:start+n]) +func (b *SendBatch) Commit(pkt []byte, dst netip.AddrPort, outerECN byte) { + b.bufs = append(b.bufs, pkt) b.dsts = append(b.dsts, dst) - b.nextSlot++ + b.ecns = append(b.ecns, outerECN) } -func (b *SendBatch) Reset() { +func (b *SendBatch) Flush() error { + var err error + if len(b.bufs) > 0 { + err = b.out.WriteBatch(b.bufs, b.dsts, b.ecns) + } + for i := range b.bufs { + b.bufs[i] = nil + } b.bufs = b.bufs[:0] b.dsts = b.dsts[:0] - b.nextSlot = 0 -} - -func (b *SendBatch) Len() int { - return len(b.bufs) -} - -func (b *SendBatch) Cap() int { - return b.batchCap -} - -func (b *SendBatch) Get() ([][]byte, []netip.AddrPort) { - return b.bufs, b.dsts + b.ecns = b.ecns[:0] + b.backing = b.backing[:0] + return err } diff --git a/overlay/batch/tx_batch_test.go b/overlay/batch/tx_batch_test.go index 32412492..454011dc 100644 --- a/overlay/batch/tx_batch_test.go +++ b/overlay/batch/tx_batch_test.go @@ -5,65 +5,120 @@ import ( "testing" ) -func TestSendBatchBookkeeping(t *testing.T) { - b := NewSendBatch(4, 32) - if b.Len() != 0 || b.Cap() != 4 { - t.Fatalf("fresh batch: len=%d cap=%d", b.Len(), b.Cap()) +type fakeBatchWriter struct { + bufs [][]byte + addrs []netip.AddrPort + ecns []byte +} + +func (w *fakeBatchWriter) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error { + // Snapshot — SendBatch.Flush nils its slot pointers right after WriteBatch + // returns, so tests must capture data before that happens. 
+ w.bufs = make([][]byte, len(bufs)) + for i, b := range bufs { + cp := make([]byte, len(b)) + copy(cp, b) + w.bufs[i] = cp } + w.addrs = append(w.addrs[:0], addrs...) + w.ecns = append(w.ecns[:0], ecns...) + return nil +} + +func TestSendBatchReserveCommitFlush(t *testing.T) { + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 4, 32) ap := netip.MustParseAddrPort("10.0.0.1:4242") for i := 0; i < 4; i++ { - slot := b.Next() - if slot == nil { - t.Fatalf("slot %d: Next returned nil before cap", i) + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("slot %d: cap=%d want 32", i, cap(slot)) } - if cap(slot) != 32 || len(slot) != 0 { - t.Fatalf("slot %d: got len=%d cap=%d want len=0 cap=32", i, len(slot), cap(slot)) - } - // Write a marker byte. - slot = append(slot, byte(i), byte(i+1), byte(i+2)) - b.Commit(len(slot), ap) + pkt := append(slot[:0], byte(i), byte(i+1), byte(i+2)) + b.Commit(pkt, ap, 0) } - if b.Next() != nil { - t.Fatalf("Next should return nil when full") + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) } - if b.Len() != 4 { - t.Fatalf("Len=%d want 4", b.Len()) + if len(fw.bufs) != 4 { + t.Fatalf("WriteBatch got %d bufs want 4", len(fw.bufs)) } - for i, buf := range b.bufs { + for i, buf := range fw.bufs { if len(buf) != 3 || buf[0] != byte(i) { t.Errorf("buf %d: %x", i, buf) } - if b.dsts[i] != ap { - t.Errorf("dst %d: got %v want %v", i, b.dsts[i], ap) + if fw.addrs[i] != ap { + t.Errorf("addr %d: got %v want %v", i, fw.addrs[i], ap) } } - // Reset returns empty and Next works again. - b.Reset() - if b.Len() != 0 { - t.Fatalf("after Reset Len=%d want 0", b.Len()) + // Flush again with nothing committed — should be a no-op. + fw.bufs = nil + if err := b.Flush(); err != nil { + t.Fatalf("empty Flush: %v", err) } - slot := b.Next() - if slot == nil || cap(slot) != 32 { - t.Fatalf("after Reset Next nil or wrong cap: %v cap=%d", slot == nil, cap(slot)) + if fw.bufs != nil { + t.Fatalf("empty Flush triggered WriteBatch") + } + + // Reuse after Flush. + slot := b.Reserve(32) + if cap(slot) != 32 { + t.Fatalf("after Flush Reserve wrong cap: %d", cap(slot)) } } func TestSendBatchSlotsDoNotOverlap(t *testing.T) { - b := NewSendBatch(3, 8) + fw := &fakeBatchWriter{} + b := NewSendBatch(fw, 3, 8) ap := netip.MustParseAddrPort("10.0.0.1:80") - // Fill three slots, each with its own sentinel byte. for i := 0; i < 3; i++ { - s := b.Next() - s = append(s, byte(0xA0+i), byte(0xB0+i)) - b.Commit(len(s), ap) + s := b.Reserve(8) + pkt := append(s[:0], byte(0xA0+i), byte(0xB0+i)) + b.Commit(pkt, ap, 0) + } + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) } - for i, buf := range b.bufs { + for i, buf := range fw.bufs { if buf[0] != byte(0xA0+i) || buf[1] != byte(0xB0+i) { t.Errorf("slot %d corrupted: %x", i, buf) } } } + +func TestSendBatchGrowPreservesCommitted(t *testing.T) { + fw := &fakeBatchWriter{} + // Tiny initial backing forces a grow on the second Reserve. + b := NewSendBatch(fw, 1, 4) + ap := netip.MustParseAddrPort("10.0.0.1:80") + + s1 := b.Reserve(4) + pkt1 := append(s1[:0], 0x11, 0x22, 0x33, 0x44) + b.Commit(pkt1, ap, 0) + + s2 := b.Reserve(8) // exceeds remaining cap, triggers grow + pkt2 := append(s2[:0], 0xA, 0xB, 0xC, 0xD, 0xE) + b.Commit(pkt2, ap, 0) + + // pkt1 must still be intact even though backing reallocated. 
+ if pkt1[0] != 0x11 || pkt1[3] != 0x44 { + t.Fatalf("first packet corrupted by grow: %x", pkt1) + } + + if err := b.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + if len(fw.bufs) != 2 { + t.Fatalf("got %d bufs want 2", len(fw.bufs)) + } + if fw.bufs[0][0] != 0x11 || fw.bufs[0][3] != 0x44 { + t.Errorf("first packet on the wire: %x", fw.bufs[0]) + } + if fw.bufs[1][0] != 0xA || fw.bufs[1][4] != 0xE { + t.Errorf("second packet on the wire: %x", fw.bufs[1]) + } +} diff --git a/overlay/batch/udp_coalesce.go b/overlay/batch/udp_coalesce.go new file mode 100644 index 00000000..29e79677 --- /dev/null +++ b/overlay/batch/udp_coalesce.go @@ -0,0 +1,342 @@ +package batch + +import ( + "encoding/binary" + "io" + + "github.com/slackhq/nebula/overlay/tio" +) + +// ipProtoUDP is the IANA protocol number for UDP. +const ipProtoUDP = 17 + +// udpCoalesceBufSize caps total bytes per UDP superpacket. Mirrors the +// kernel's gso_max_size; payloads beyond this are emitted as-is. +const udpCoalesceBufSize = 65535 + +// udpCoalesceMaxSegs caps how many segments we'll coalesce. Kernel UDP-GSO +// accepts up to 64 segments per skb (UDP_MAX_SEGMENTS); stay under that. +const udpCoalesceMaxSegs = 64 + +// udpCoalesceHdrCap is the scratch space we copy a seed's IP+UDP header +// into. IPv6 (40) + UDP (8) = 48; round up for safety. +const udpCoalesceHdrCap = 64 + +// udpSlot is one entry in the UDPCoalescer's ordered event queue. Same +// passthrough-vs-coalesced shape as the TCP coalescer's slot, but no +// seq/PSH/CWR bookkeeping — UDP segments only need 5-tuple + length +// matching to coalesce. +type udpSlot struct { + passthrough bool + rawPkt []byte // borrowed when passthrough + + fk flowKey + hdrBuf [udpCoalesceHdrCap]byte + hdrLen int + ipHdrLen int + isV6 bool + gsoSize int // per-segment UDP payload length + numSeg int + totalPay int + // sealed closes the chain: set when a sub-gsoSize segment is appended + // (kernel UDP-GSO requires every segment but the last to be exactly + // gsoSize) or when limits are hit. No further appends after. + sealed bool + payIovs [][]byte +} + +// UDPCoalescer accumulates adjacent in-flow UDP datagrams across multiple +// concurrent flows and emits each flow's run as a single GSO_UDP_L4 +// superpacket via tio.GSOWriter. Falls back to per-packet writes when the +// underlying writer doesn't support USO. +// +// All output — coalesced or not — is deferred until Flush so per-flow +// arrival order is preserved on the wire. Cross-flow order is NOT preserved +// across the TCP/UDP/passthrough split when this coalescer runs alongside +// others — see multi_coalesce.go. Per-flow order is preserved because a +// single 5-tuple only ever lands in one lane and each lane preserves its +// own slot order. +// +// Owns no locks; one coalescer per TUN write queue. +type UDPCoalescer struct { + plainW io.Writer + gsoW tio.GSOWriter // nil when the queue can't accept GSO_UDP_L4 + + slots []*udpSlot + openSlots map[flowKey]*udpSlot + pool []*udpSlot + + backing []byte +} + +// NewUDPCoalescer wraps w. The caller is responsible for only constructing +// this when the underlying Queue's Capabilities advertise USO; otherwise +// the kernel may reject GSO_UDP_L4 writes. If w does not implement +// tio.GSOWriter at all (single-packet Queue), the coalescer degrades to +// plain Writes — same defensive shape as the TCP coalescer. 
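A sketch of the wiring decision this comment leaves to the caller: gate construction on the negotiated capability so a queue without USO never pays the coalescer's bookkeeping cost. newUDPWriter is a hypothetical helper, and NewUDPCoalescer's own tio.SupportsGSO check still applies either way.

// newUDPWriter returns a coalescer only when the queue negotiated USO.
func newUDPWriter(q tio.Queue) (*UDPCoalescer, bool) {
    if !tio.QueueCapabilities(q).USO {
        return nil, false // caller sticks to plain per-packet q.Write
    }
    return NewUDPCoalescer(q), true // tio.Queue satisfies io.Writer via Write
}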
+func NewUDPCoalescer(w io.Writer) *UDPCoalescer { + c := &UDPCoalescer{ + plainW: w, + slots: make([]*udpSlot, 0, initialSlots), + openSlots: make(map[flowKey]*udpSlot, initialSlots), + pool: make([]*udpSlot, 0, initialSlots), + backing: make([]byte, 0, initialSlots*udpCoalesceBufSize), + } + if gw, ok := tio.SupportsGSO(w, tio.GSOProtoUDP); ok { + c.gsoW = gw + } + return c +} + +// parsedUDP holds the fields extracted from a single parse so later steps +// (admission, slot lookup, canAppend) don't re-walk the header. +type parsedUDP struct { + fk flowKey + ipHdrLen int + hdrLen int // ipHdrLen + 8 + payLen int +} + +// parseUDP extracts the flow key and IP/UDP offsets for a UDP packet. +// Returns ok=false for non-UDP, malformed, or unsupported header shapes +// (IPv4 with options/fragmentation, IPv6 with extension headers). +func parseUDP(pkt []byte) (parsedUDP, bool) { + var p parsedUDP + ip, ok := parseIPPrologue(pkt, ipProtoUDP) + if !ok { + return p, false + } + pkt = ip.pkt + p.fk = ip.fk + p.ipHdrLen = ip.ipHdrLen + + if len(pkt) < p.ipHdrLen+8 { + return p, false + } + p.hdrLen = p.ipHdrLen + 8 + // UDP `length` field: must equal IP-derived length-of-UDP-header-plus-payload. + udpLen := int(binary.BigEndian.Uint16(pkt[p.ipHdrLen+4 : p.ipHdrLen+6])) + if udpLen < 8 || udpLen > len(pkt)-p.ipHdrLen { + return p, false + } + p.payLen = udpLen - 8 + p.fk.sport = binary.BigEndian.Uint16(pkt[p.ipHdrLen : p.ipHdrLen+2]) + p.fk.dport = binary.BigEndian.Uint16(pkt[p.ipHdrLen+2 : p.ipHdrLen+4]) + return p, true +} + +func (c *UDPCoalescer) Reserve(sz int) []byte { + return reserveFromBacking(&c.backing, sz) +} + +// Commit borrows pkt. The caller must keep pkt valid until the next Flush. +func (c *UDPCoalescer) Commit(pkt []byte) error { + if c.gsoW == nil { + c.addPassthrough(pkt) + return nil + } + info, ok := parseUDP(pkt) + if !ok { + c.addPassthrough(pkt) + return nil + } + return c.commitParsed(pkt, info) +} + +// commitParsed is the post-parse half of Commit. The caller must have +// already verified parseUDP succeeded. Used by MultiCoalescer.Commit to +// avoid re-walking the IP/UDP header. +func (c *UDPCoalescer) commitParsed(pkt []byte, info parsedUDP) error { + if c.gsoW == nil { + c.addPassthrough(pkt) + return nil + } + if open := c.openSlots[info.fk]; open != nil { + if c.canAppend(open, pkt, info) { + c.appendPayload(open, pkt, info) + if open.sealed { + delete(c.openSlots, info.fk) + } + return nil + } + // Can't extend — seal it and fall through to seed a fresh slot. 
+ delete(c.openSlots, info.fk) + } + c.seed(pkt, info) + return nil +} + +func (c *UDPCoalescer) Flush() error { + var first error + for _, s := range c.slots { + var err error + if s.passthrough { + _, err = c.plainW.Write(s.rawPkt) + } else { + err = c.flushSlot(s) + } + if err != nil && first == nil { + first = err + } + c.release(s) + } + for i := range c.slots { + c.slots[i] = nil + } + c.slots = c.slots[:0] + for k := range c.openSlots { + delete(c.openSlots, k) + } + c.backing = c.backing[:0] + return first +} + +func (c *UDPCoalescer) addPassthrough(pkt []byte) { + s := c.take() + s.passthrough = true + s.rawPkt = pkt + c.slots = append(c.slots, s) +} + +func (c *UDPCoalescer) seed(pkt []byte, info parsedUDP) { + if info.hdrLen > udpCoalesceHdrCap || info.hdrLen+info.payLen > udpCoalesceBufSize { + c.addPassthrough(pkt) + return + } + s := c.take() + s.passthrough = false + s.rawPkt = nil + copy(s.hdrBuf[:], pkt[:info.hdrLen]) + s.hdrLen = info.hdrLen + s.ipHdrLen = info.ipHdrLen + s.isV6 = info.fk.isV6 + s.fk = info.fk + s.gsoSize = info.payLen + s.numSeg = 1 + s.totalPay = info.payLen + s.sealed = false + s.payIovs = append(s.payIovs[:0], pkt[info.hdrLen:info.hdrLen+info.payLen]) + c.slots = append(c.slots, s) + c.openSlots[info.fk] = s +} + +// canAppend reports whether info's packet extends the slot's seed. +// Kernel UDP-GSO requires every segment except possibly the last to be +// exactly gsoSize, and the last may be shorter (≤ gsoSize). +func (c *UDPCoalescer) canAppend(s *udpSlot, pkt []byte, info parsedUDP) bool { + if s.sealed { + return false + } + if info.hdrLen != s.hdrLen { + return false + } + if s.numSeg >= udpCoalesceMaxSegs { + return false + } + if info.payLen > s.gsoSize { + return false + } + if s.hdrLen+s.totalPay+info.payLen > udpCoalesceBufSize { + return false + } + if !udpHeadersMatch(s.hdrBuf[:s.hdrLen], pkt[:info.hdrLen], s.isV6, s.ipHdrLen) { + return false + } + return true +} + +func (c *UDPCoalescer) appendPayload(s *udpSlot, pkt []byte, info parsedUDP) { + s.payIovs = append(s.payIovs, pkt[info.hdrLen:info.hdrLen+info.payLen]) + s.numSeg++ + s.totalPay += info.payLen + // Merge IP-level CE marks into the seed (same trick TCP coalescer uses). + mergeECNIntoSeed(s.hdrBuf[:s.ipHdrLen], pkt[:s.ipHdrLen], s.isV6) + if info.payLen < s.gsoSize { + // Last-segment-can-be-shorter: this seals the chain. + s.sealed = true + } +} + +func (c *UDPCoalescer) take() *udpSlot { + if n := len(c.pool); n > 0 { + s := c.pool[n-1] + c.pool[n-1] = nil + c.pool = c.pool[:n-1] + return s + } + return &udpSlot{} +} + +func (c *UDPCoalescer) release(s *udpSlot) { + s.passthrough = false + s.rawPkt = nil + for i := range s.payIovs { + s.payIovs[i] = nil + } + s.payIovs = s.payIovs[:0] + s.numSeg = 0 + s.totalPay = 0 + s.sealed = false + c.pool = append(c.pool, s) +} + +// flushSlot patches the IP header total length / IPv6 payload length and +// the UDP length to the *total* across all coalesced segments, then seeds +// the UDP checksum field with the pseudo-header partial (single-fold, not +// inverted) per virtio NEEDS_CSUM. The kernel's ip_rcv_core (v4) and +// ip6_rcv_core (v6) trim the skb to those length fields, so per-segment +// values would silently drop everything but the first segment. 
The kernel +// then walks each segment in __udp_gso_segment, recomputing per-segment +// uh->len / iph->tot_len / IPv6 plen and adjusting the checksum via +// `check = csum16_add(csum16_sub(uh->check, uh->len), newlen)` — meaning +// our seed's uh->check must be consistent with the seed's uh->len, which +// is what passing the total to both pseudoSum and the UDP length field +// guarantees. +func (c *UDPCoalescer) flushSlot(s *udpSlot) error { + hdr := s.hdrBuf[:s.hdrLen] + total := s.hdrLen + s.totalPay // full IP+UDP+all_payloads bytes + l4Len := total - s.ipHdrLen // total UDP (8 + sum of payloads) + + if s.isV6 { + binary.BigEndian.PutUint16(hdr[4:6], uint16(l4Len)) + } else { + binary.BigEndian.PutUint16(hdr[2:4], uint16(total)) + hdr[10] = 0 + hdr[11] = 0 + binary.BigEndian.PutUint16(hdr[10:12], ipv4HdrChecksum(hdr[:s.ipHdrLen])) + } + + // UDP length field (offset 4 inside the UDP header) = total UDP size. + binary.BigEndian.PutUint16(hdr[s.ipHdrLen+4:s.ipHdrLen+6], uint16(l4Len)) + + var psum uint32 + if s.isV6 { + psum = pseudoSumIPv6(hdr[8:24], hdr[24:40], ipProtoUDP, l4Len) + } else { + psum = pseudoSumIPv4(hdr[12:16], hdr[16:20], ipProtoUDP, l4Len) + } + udpCsumOff := s.ipHdrLen + 6 + binary.BigEndian.PutUint16(hdr[udpCsumOff:udpCsumOff+2], foldOnceNoInvert(psum)) + + return c.gsoW.WriteGSO(hdr[:s.ipHdrLen], hdr[s.ipHdrLen:], s.payIovs, tio.GSOProtoUDP) +} + +// udpHeadersMatch compares two IP+UDP header prefixes for byte-equality on +// every field that must be identical across coalesced segments. Length +// fields and the ECN bits in IP TOS/TC are masked out — appendPayload +// merges CE into the seed; flushSlot rewrites lengths. +func udpHeadersMatch(a, b []byte, isV6 bool, ipHdrLen int) bool { + if len(a) != len(b) { + return false + } + if !ipHeadersMatch(a, b, isV6) { + return false + } + // UDP: compare sport+dport ([0:4]). Skip length [4:6] and checksum [6:8] — + // length varies (we rewrite at flush) and the checksum will be redone. + udp := ipHdrLen + if a[udp] != b[udp] || a[udp+1] != b[udp+1] || a[udp+2] != b[udp+2] || a[udp+3] != b[udp+3] { + return false + } + return true +} diff --git a/overlay/batch/udp_coalesce_test.go b/overlay/batch/udp_coalesce_test.go new file mode 100644 index 00000000..7eefc41a --- /dev/null +++ b/overlay/batch/udp_coalesce_test.go @@ -0,0 +1,383 @@ +package batch + +import ( + "encoding/binary" + "testing" +) + +// buildUDPv4 builds a minimal IPv4+UDP packet with the given payload and ports. +func buildUDPv4(sport, dport uint16, payload []byte) []byte { + const ipHdrLen = 20 + const udpHdrLen = 8 + total := ipHdrLen + udpHdrLen + len(payload) + pkt := make([]byte, total) + + pkt[0] = 0x45 + pkt[1] = 0x00 + binary.BigEndian.PutUint16(pkt[2:4], uint16(total)) + binary.BigEndian.PutUint16(pkt[4:6], 0) + binary.BigEndian.PutUint16(pkt[6:8], 0x4000) + pkt[8] = 64 + pkt[9] = ipProtoUDP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + + binary.BigEndian.PutUint16(pkt[20:22], sport) + binary.BigEndian.PutUint16(pkt[22:24], dport) + binary.BigEndian.PutUint16(pkt[24:26], uint16(udpHdrLen+len(payload))) + binary.BigEndian.PutUint16(pkt[26:28], 0) + + copy(pkt[28:], payload) + return pkt +} + +// buildUDPv6 builds a minimal IPv6+UDP packet. 
+func buildUDPv6(sport, dport uint16, payload []byte) []byte { + const ipHdrLen = 40 + const udpHdrLen = 8 + total := ipHdrLen + udpHdrLen + len(payload) + pkt := make([]byte, total) + + pkt[0] = 0x60 + binary.BigEndian.PutUint16(pkt[4:6], uint16(udpHdrLen+len(payload))) + pkt[6] = ipProtoUDP + pkt[7] = 64 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + binary.BigEndian.PutUint16(pkt[40:42], sport) + binary.BigEndian.PutUint16(pkt[42:44], dport) + binary.BigEndian.PutUint16(pkt[44:46], uint16(udpHdrLen+len(payload))) + binary.BigEndian.PutUint16(pkt[46:48], 0) + + copy(pkt[48:], payload) + return pkt +} + +func TestUDPCoalescerPassthroughWhenGSOUnavailable(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: false} + c := NewUDPCoalescer(w) + pkt := buildUDPv4(1000, 53, make([]byte, 100)) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if len(w.writes) != 0 || len(w.gsoWrites) != 0 { + t.Fatalf("no Add-time writes: writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("want single plain write, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestUDPCoalescerNonUDPPassthrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + // ICMP packet + pkt := make([]byte, 28) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], 28) + pkt[9] = 1 + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("ICMP must pass through unchanged: writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestUDPCoalescerSeedThenFlushAlone(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pkt := buildUDPv4(1000, 53, make([]byte, 800)) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Single-segment flush goes through WriteGSO; the writer infers GSO_NONE + // from len(pays)==1 and the kernel fills in the UDP csum (NEEDS_CSUM). + if len(w.gsoWrites) != 1 || len(w.writes) != 0 { + t.Fatalf("single-seg flush: writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +func TestUDPCoalescerCoalescesEqualSized(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 1200) + for i := 0; i < 3; i++ { + if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil { + t.Fatal(err) + } + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + g := w.gsoWrites[0] + if g.gsoSize != 1200 { + t.Errorf("gsoSize=%d want 1200", g.gsoSize) + } + if len(g.pays) != 3 { + t.Errorf("pay count=%d want 3", len(g.pays)) + } + if g.csumStart != 20 { + t.Errorf("csumStart=%d want 20", g.csumStart) + } + // IP totalLen and UDP length must be the TOTAL across all segments — + // the kernel's ip_rcv_core trims skbs to iph->tot_len, so a per-segment + // value would silently drop everything but the first segment. Total = + // IP(20) + UDP(8) + 3*1200 = 3628. 
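The same arithmetic drives the checksum seed flushSlot writes: the UDP length here is 8 + 3*1200 = 3608, and the UDP checksum field is pre-loaded with the pseudo-header partial over that length (folded, not inverted) so the kernel's per-segment adjustment stays consistent. A small illustrative version of that seed, not the patch's pseudoSumIPv4/foldOnceNoInvert helpers verbatim:

func udpPseudoSeed(src, dst [4]byte, udpLen int) uint16 {
    sum := uint32(17) + uint32(udpLen) // protocol number + UDP length
    sum += uint32(src[0])<<8 | uint32(src[1])
    sum += uint32(src[2])<<8 | uint32(src[3])
    sum += uint32(dst[0])<<8 | uint32(dst[1])
    sum += uint32(dst[2])<<8 | uint32(dst[3])
    for sum>>16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16) // fold carries; no final inversion
    }
    return uint16(sum)
}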
+ gotTotalLen := binary.BigEndian.Uint16(g.hdr[2:4]) + if gotTotalLen != 3628 { + t.Errorf("ipv4 total_len=%d want 3628 (must be total across segments)", gotTotalLen) + } + gotUDPLen := binary.BigEndian.Uint16(g.hdr[20+4 : 20+6]) + if gotUDPLen != 8+3*1200 { + t.Errorf("udp len=%d want %d", gotUDPLen, 8+3*1200) + } +} + +// Last segment may be shorter, sealing the chain. +func TestUDPCoalescerShortLastSegmentSeals(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + full := make([]byte, 1200) + tail := make([]byte, 600) + if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(1000, 53, tail)); err != nil { + t.Fatal(err) + } + // A 4th packet, even same-sized, must NOT join — chain is sealed. + if err := c.Commit(buildUDPv4(1000, 53, full)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (sealed + new seed), got %d", len(w.gsoWrites)) + } + if len(w.gsoWrites[0].pays) != 3 { + t.Errorf("first super: want 3 pays, got %d", len(w.gsoWrites[0].pays)) + } + if len(w.gsoWrites[1].pays) != 1 { + t.Errorf("second super: want 1 pay (re-seed), got %d", len(w.gsoWrites[1].pays)) + } +} + +// A larger-than-gsoSize packet cannot extend the slot — it reseeds. +func TestUDPCoalescerLargerThanSeedReseeds(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + if err := c.Commit(buildUDPv4(1000, 53, make([]byte, 800))); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(1000, 53, make([]byte, 1200))); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 separate seeds, got %d", len(w.gsoWrites)) + } +} + +// Different 5-tuples must not coalesce. +func TestUDPCoalescerDifferentFlowsKeepSeparate(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 800) + if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(2000, 53, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil { + t.Fatal(err) + } + if err := c.Commit(buildUDPv4(2000, 53, pay)); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // Two flows × 2 datagrams each = 2 superpackets of 2 segments. + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (one per flow), got %d", len(w.gsoWrites)) + } + for i, g := range w.gsoWrites { + if len(g.pays) != 2 { + t.Errorf("super %d: want 2 pays, got %d", i, len(g.pays)) + } + } +} + +// Caps at udpCoalesceMaxSegs. +func TestUDPCoalescerCapsAtMaxSegs(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 100) + for i := 0; i < udpCoalesceMaxSegs+5; i++ { + if err := c.Commit(buildUDPv4(1000, 53, pay)); err != nil { + t.Fatal(err) + } + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + // First superpacket holds udpCoalesceMaxSegs segments; the spillover + // reseeds a new one. 
+ if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 gso writes (cap then reseed), got %d", len(w.gsoWrites)) + } + if len(w.gsoWrites[0].pays) != udpCoalesceMaxSegs { + t.Errorf("first super: pays=%d want %d", len(w.gsoWrites[0].pays), udpCoalesceMaxSegs) + } + if len(w.gsoWrites[1].pays) != 5 { + t.Errorf("second super: pays=%d want 5", len(w.gsoWrites[1].pays)) + } +} + +// CE marks on appended segments must be merged into the seed's IP TOS. +func TestUDPCoalescerMergesCEMark(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 800) + pkt0 := buildUDPv4(1000, 53, pay) // ECN=00 + pkt1 := buildUDPv4(1000, 53, pay) + pkt1[1] = 0x03 // CE + pkt2 := buildUDPv4(1000, 53, pay) + if err := c.Commit(pkt0); err != nil { + t.Fatal(err) + } + if err := c.Commit(pkt1); err != nil { + t.Fatal(err) + } + if err := c.Commit(pkt2); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 merged gso write, got %d (plain=%d)", len(w.gsoWrites), len(w.writes)) + } + if w.gsoWrites[0].hdr[1]&0x03 != 0x03 { + t.Errorf("CE not merged into seed (tos=%#x)", w.gsoWrites[0].hdr[1]) + } +} + +// IPv6 path: same flow, equal-sized → coalesced. +func TestUDPCoalescerIPv6Coalesces(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 1200) + for i := 0; i < 3; i++ { + if err := c.Commit(buildUDPv6(1000, 53, pay)); err != nil { + t.Fatal(err) + } + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 1 { + t.Fatalf("want 1 gso write, got %d", len(w.gsoWrites)) + } + g := w.gsoWrites[0] + if !g.isV6 { + t.Errorf("expected v6 write") + } + if g.csumStart != 40 { + t.Errorf("csumStart=%d want 40", g.csumStart) + } + // IPv6 payload_len and UDP length must be TOTAL — kernel's + // ip6_rcv_core trims to payload_len + ipv6 hdr size. Total UDP = 8 + + // 3*1200 = 3608. + gotPlen := binary.BigEndian.Uint16(g.hdr[4:6]) + if gotPlen != 8+3*1200 { + t.Errorf("ipv6 payload_len=%d want %d (must be total)", gotPlen, 8+3*1200) + } + gotUDPLen := binary.BigEndian.Uint16(g.hdr[40+4 : 40+6]) + if gotUDPLen != 8+3*1200 { + t.Errorf("udp len=%d want %d", gotUDPLen, 8+3*1200) + } +} + +// DSCP differences must reseed (headers don't match outside ECN). +func TestUDPCoalescerDSCPMismatchReseeds(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pay := make([]byte, 800) + pkt0 := buildUDPv4(1000, 53, pay) + pkt1 := buildUDPv4(1000, 53, pay) + pkt1[1] = 0xb8 // EF DSCP, ECN=0 + if err := c.Commit(pkt0); err != nil { + t.Fatal(err) + } + if err := c.Commit(pkt1); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.gsoWrites) != 2 { + t.Fatalf("want 2 separate seeds (different DSCP), got %d", len(w.gsoWrites)) + } +} + +// Fragmented IPv4 must not be coalesced. +func TestUDPCoalescerFragmentedIPv4PassesThrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pkt := buildUDPv4(1000, 53, make([]byte, 200)) + binary.BigEndian.PutUint16(pkt[6:8], 0x2000) // MF=1 + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("frag must pass through plain, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} + +// IPv4 with options is not admissible (we require IHL=5). 
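A compact sketch of the IPv4 admission rule the fragment test above and the options test that follows pin down: IHL must be exactly 5 and the packet must not be a fragment (DF alone is fine). The real check lives in parseIPPrologue; this helper is illustrative only.

func admitIPv4(pkt []byte) bool {
    if len(pkt) < 20 || pkt[0] != 0x45 { // version 4, IHL 5, no options
        return false
    }
    frag := binary.BigEndian.Uint16(pkt[6:8])
    return frag&0x3fff == 0 // MF clear and fragment offset zero; DF (0x4000) ignored
}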
+func TestUDPCoalescerIPv4WithOptionsPassesThrough(t *testing.T) { + w := &fakeTunWriter{gsoEnabled: true} + c := NewUDPCoalescer(w) + pkt := buildUDPv4(1000, 53, make([]byte, 200)) + pkt[0] = 0x46 // IHL = 6 (24-byte IPv4 header — has options) + if err := c.Commit(pkt); err != nil { + t.Fatal(err) + } + if err := c.Flush(); err != nil { + t.Fatal(err) + } + if len(w.writes) != 1 || len(w.gsoWrites) != 0 { + t.Fatalf("ipv4-with-options must pass through plain, got writes=%d gso=%d", len(w.writes), len(w.gsoWrites)) + } +} diff --git a/overlay/device.go b/overlay/device.go index f8181421..8044ee75 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -18,7 +18,7 @@ type Device interface { Networks() []netip.Prefix Name() string RoutesFor(netip.Addr) routing.Gateways - SupportsMultiqueue() bool //todo remove? + SupportsMultiqueue() bool NewMultiQueueReader() error Readers() []tio.Queue } diff --git a/overlay/overlaytest/noop.go b/overlay/overlaytest/noop.go index 10886511..6a39ab43 100644 --- a/overlay/overlaytest/noop.go +++ b/overlay/overlaytest/noop.go @@ -31,7 +31,7 @@ func (NoopTun) Name() string { return "noop" } -func (NoopTun) Read() ([][]byte, error) { +func (NoopTun) Read() ([]tio.Packet, error) { return nil, nil } diff --git a/overlay/tio/queueset_gso_linux.go b/overlay/tio/queueset_gso_linux.go new file mode 100644 index 00000000..4914df88 --- /dev/null +++ b/overlay/tio/queueset_gso_linux.go @@ -0,0 +1,79 @@ +package tio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" +) + +type offloadQueueSet struct { + pq []*Offload + // pqi is exactly the same as pq, but stored as the interface type + pqi []Queue + shutdownFd int + // usoEnabled is true when newTun successfully negotiated TUN_F_USO4|6 + // with the kernel. Queues created by Add inherit this and surface it + // via Offload.USOSupported so coalescers can gate USO emission. + usoEnabled bool +} + +// NewOffloadQueueSet creates a QueueSet that uses virtio_net_hdr to do +// TSO segmentation in userspace. usoEnabled tells downstream queues whether +// the kernel agreed to deliver/accept GSO_UDP_L4 superpackets — coalescers +// should fall back to per-packet writes when this is false. +func NewOffloadQueueSet(usoEnabled bool) (QueueSet, error) { + shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) + if err != nil { + return nil, fmt.Errorf("failed to create eventfd: %w", err) + } + + out := &offloadQueueSet{ + pq: []*Offload{}, + pqi: []Queue{}, + shutdownFd: shutdownFd, + usoEnabled: usoEnabled, + } + + return out, nil +} + +func (c *offloadQueueSet) Queues() []Queue { + return c.pqi +} + +func (c *offloadQueueSet) Add(fd int) error { + x, err := newOffload(fd, c.shutdownFd, c.usoEnabled) + if err != nil { + return err + } + c.pq = append(c.pq, x) + c.pqi = append(c.pqi, x) + + return nil +} + +func (c *offloadQueueSet) wakeForShutdown() error { + var buf [8]byte + binary.NativeEndian.PutUint64(buf[:], 1) + _, err := unix.Write(c.shutdownFd, buf[:]) + return err +} + +func (c *offloadQueueSet) Close() error { + errs := []error{} + + // Signal all readers blocked in poll to wake up and exit + if err := c.wakeForShutdown(); err != nil { + errs = append(errs, err) + } + + for _, x := range c.pq { + if err := x.Close(); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) 
+} diff --git a/overlay/tio/container_poll_linux.go b/overlay/tio/queueset_poll_linux.go similarity index 77% rename from overlay/tio/container_poll_linux.go rename to overlay/tio/queueset_poll_linux.go index fa6367e7..ab967df4 100644 --- a/overlay/tio/container_poll_linux.go +++ b/overlay/tio/queueset_poll_linux.go @@ -8,20 +8,20 @@ import ( "golang.org/x/sys/unix" ) -type pollContainer struct { +type pollQueueSet struct { pq []*Poll // pqi is exactly the same as pq, but stored as the interface type pqi []Queue shutdownFd int } -func NewPollContainer() (Container, error) { +func NewPollQueueSet() (QueueSet, error) { shutdownFd, err := unix.Eventfd(0, unix.EFD_NONBLOCK|unix.EFD_CLOEXEC) if err != nil { return nil, fmt.Errorf("failed to create eventfd: %w", err) } - out := &pollContainer{ + out := &pollQueueSet{ pq: []*Poll{}, pqi: []Queue{}, shutdownFd: shutdownFd, @@ -30,11 +30,11 @@ func NewPollContainer() (Container, error) { return out, nil } -func (c *pollContainer) Queues() []Queue { +func (c *pollQueueSet) Queues() []Queue { return c.pqi } -func (c *pollContainer) Add(fd int) error { +func (c *pollQueueSet) Add(fd int) error { x, err := newPoll(fd, c.shutdownFd) if err != nil { return err @@ -45,14 +45,14 @@ func (c *pollContainer) Add(fd int) error { return nil } -func (c *pollContainer) wakeForShutdown() error { +func (c *pollQueueSet) wakeForShutdown() error { var buf [8]byte binary.NativeEndian.PutUint64(buf[:], 1) _, err := unix.Write(int(c.shutdownFd), buf[:]) return err } -func (c *pollContainer) Close() error { +func (c *pollQueueSet) Close() error { errs := []error{} if err := c.wakeForShutdown(); err != nil { diff --git a/overlay/tio/segment_bench_test.go b/overlay/tio/segment_bench_test.go new file mode 100644 index 00000000..13713010 --- /dev/null +++ b/overlay/tio/segment_bench_test.go @@ -0,0 +1,65 @@ +//go:build linux && !android && !e2e_testing + +package tio + +import "testing" + +// fakeBatch stands in for batch.TxBatcher inside the bench — same shape +// of pointer-capturing closure that sendInsideMessage builds. +type fakeBatch struct{ buf [65536]byte } + +func (b *fakeBatch) Reserve(sz int) []byte { return b.buf[:sz] } +func (b *fakeBatch) Commit([]byte) {} + +type fakeHostInfo struct { + remoteIndexId uint32 + counter uint64 +} +type fakeIface struct { + rebindCount uint8 + hi *fakeHostInfo +} + +// BenchmarkSegmentSuperpacketAllocsTSO measures allocation per +// SegmentSuperpacket call when a closure captures pointer-bearing +// receivers — the realistic shape of sendInsideMessage's closure. +func BenchmarkSegmentSuperpacketAllocsTSO(b *testing.B) { + const mss = 1400 + const numSeg = 32 + pkt := buildTSOv6(mss*numSeg, mss) + gso := GSOInfo{ + Size: mss, + HdrLen: 60, // 40 (IPv6) + 20 (TCP) + CsumStart: 40, + Proto: GSOProtoTCP, + } + p := Packet{Bytes: pkt, GSO: gso} + + hi := &fakeHostInfo{remoteIndexId: 0xdeadbeef} + f := &fakeIface{rebindCount: 7, hi: hi} + fb := &fakeBatch{} + + // SegmentSuperpacket consumes pkt destructively; refresh from a master + // copy each iter (matches the production pattern where every TUN read + // hands the segmenter a fresh kernel-supplied buffer). + master := append([]byte(nil), pkt...) 
+ work := make([]byte, len(pkt)) + p.Bytes = work + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + copy(work, master) + err := SegmentSuperpacket(p, func(seg []byte) error { + out := fb.Reserve(16 + len(seg) + 16) + out[0] = byte(f.rebindCount) + out[1] = byte(hi.counter) + hi.counter++ + fb.Commit(out) + return nil + }) + if err != nil { + b.Fatalf("SegmentSuperpacket: %v", err) + } + } +} diff --git a/overlay/tio/segment_other.go b/overlay/tio/segment_other.go new file mode 100644 index 00000000..24e73fdd --- /dev/null +++ b/overlay/tio/segment_other.go @@ -0,0 +1,18 @@ +//go:build !linux || android || e2e_testing + +package tio + +import "fmt" + +// SegmentSuperpacket invokes fn once per segment of pkt. On non-Linux +// builds (and Android/e2e_testing) this package does not provide a Queue +// implementation, so any caller that does construct a Packet here can only +// be operating on non-superpacket bytes and the stub forwards them +// directly. A non-zero GSO field is a programming error from the caller +// and returns an explicit error rather than silently misbehaving. +func SegmentSuperpacket(pkt Packet, scratch []byte, fn func(seg []byte) error) error { + if pkt.GSO.IsSuperpacket() { + return fmt.Errorf("tio: GSO superpacket on platform without segmentation support") + } + return fn(pkt.Bytes) +} diff --git a/overlay/tio/tio.go b/overlay/tio/tio.go index 94240d06..2d94c764 100644 --- a/overlay/tio/tio.go +++ b/overlay/tio/tio.go @@ -1,56 +1,170 @@ package tio -import "io" +import ( + "io" +) -// defaultBatchBufSize is the per-Queue scratch size for Read on backends -// that don't do TSO segmentation. 65535 covers any single IP packet. -const defaultBatchBufSize = 65535 - -// Container holds one or many Queue objects and helps close them in an orderly way -type Container interface { +// QueueSet holds one or many Queue objects and helps close them in an orderly way. +type QueueSet interface { io.Closer Queues() []Queue - // Add takes a tun fd, adds it to the container, and prepares it for use as a Queue + // Add takes a tun fd, adds it to the set, and prepares it for use as a Queue. Add(fd int) error } +// Capabilities advertises which kernel offload features a Queue +// successfully negotiated. Callers consult this to decide which coalescers +// to wire onto the write path — a Queue without TSO can't usefully accept a +// TCPCoalescer, and a Queue without USO can't accept a UDPCoalescer. +type Capabilities struct { + // TSO means the FD was opened with IFF_VNET_HDR and the kernel agreed + // to TUN_F_TSO4|TSO6 — i.e. WriteGSO with GSOProtoTCP is safe. + TSO bool + // USO means the kernel additionally agreed to TUN_F_USO4|USO6, so + // WriteGSO with GSOProtoUDP is safe. Linux ≥ 6.2. + USO bool +} + // Queue is a readable/writable Poll queue. One Queue is driven by a single -// read goroutine plus concurrent writers (see Write / WriteReject below). +// read goroutine plus a single writer (see Write below). type Queue interface { io.Closer - // Read returns one or more packets. The returned slices are borrowed - // from the Queue's internal buffer and are only valid until the next - // Read or Close on this Queue - callers must encrypt or copy each - // slice before the next call. Not safe for concurrent Reads. - Read() ([][]byte, error) + // Read returns one or more packets. 
The returned Packet.Bytes slices + // are borrowed from the Queue's internal buffer and are only valid + // until the next Read or Close on this Queue - callers must encrypt + // or copy each slice before the next call. A Packet may carry a + // GSO/USO superpacket (see GSOInfo); when GSO.IsSuperpacket() is + // true the caller must segment Bytes before treating it as a single + // IP datagram. Not safe for concurrent Reads. + Read() ([]Packet, error) // Write emits a single packet on the plaintext (outside→inside) // delivery path. Not safe for concurrent Writes. Write(p []byte) (int, error) } -// GSOWriter is implemented by Queues that can emit a TCP TSO superpacket +// Packet is the unit Queue.Read returns. Bytes points into the queue's +// internal buffer and is only valid until the next Read or Close on the +// queue that produced it. GSO is the zero value for an already-segmented +// IP datagram; when non-zero it describes a kernel-supplied TSO/USO +// superpacket the caller must segment before consuming. +type Packet struct { + Bytes []byte + GSO GSOInfo +} + +// GSOInfo describes a kernel-supplied superpacket sitting in Packet.Bytes. +// The zero value means "not a superpacket" — Bytes is one regular IP +// datagram and no segmentation is required. +type GSOInfo struct { + // Size is the GSO segment size: max payload bytes per segment + // (== TCP MSS for TSO, == UDP payload chunk for USO). Zero means + // not a superpacket. + Size uint16 + // HdrLen is the total L3+L4 header length within Bytes (already + // corrected via correctHdrLen, so safe to slice on). + HdrLen uint16 + // CsumStart is the L4 header offset inside Bytes (== L3 header + // length). + CsumStart uint16 + // Proto picks the L4 protocol (TCP or UDP) so the segmenter knows + // which checksum/header layout to apply. + Proto GSOProto +} + +// IsSuperpacket reports whether g describes a multi-segment GSO/USO +// superpacket that needs segmentation before its bytes can be encrypted +// and sent on the wire. +func (g GSOInfo) IsSuperpacket() bool { return g.Size > 0 } + +// Clone returns a Packet whose Bytes is a freshly allocated copy of p.Bytes, +// safe to retain past the next Read or Close on the originating Queue. +// GSO metadata is copied verbatim. Use this only when a caller genuinely +// needs to outlive the borrowed-slice contract — the hot path reads should +// continue to consume the borrow synchronously to avoid the allocation. +func (p Packet) Clone() Packet { + if p.Bytes == nil { + return p + } + cp := make([]byte, len(p.Bytes)) + copy(cp, p.Bytes) + return Packet{Bytes: cp, GSO: p.GSO} +} + +// CapsProvider is an optional interface implemented by Queues that +// successfully negotiated kernel offload features at open time. Callers +// pick a write-path coalescer based on the result. Queues that don't +// implement it are treated as having no offload capability — callers must +// fall back to plain per-packet writes. +type CapsProvider interface { + Capabilities() Capabilities +} + +// QueueCapabilities returns q's negotiated offload capabilities, or the +// zero value when q does not advertise any. +func QueueCapabilities(q Queue) Capabilities { + if cp, ok := q.(CapsProvider); ok { + return cp.Capabilities() + } + return Capabilities{} +} + +// GSOProto selects the L4 protocol for a GSO superpacket. Determines which +// VIRTIO_NET_HDR_GSO_* type the writer stamps and which checksum offset +// inside the transport header virtio NEEDS_CSUM expects. 
+type GSOProto uint8 + +const ( + GSOProtoTCP GSOProto = iota + GSOProtoUDP +) + +// GSOWriter is implemented by Queues that can emit a TCP or UDP superpacket // assembled from a header prefix plus one or more borrowed payload // fragments, in a single vectored write (writev with a leading // virtio_net_hdr). This lets the coalescer avoid copying payload bytes // between the caller's decrypt buffer and the TUN. Backends without GSO -// support return false from GSOSupported and coalescing is skipped. +// support do not implement this interface and coalescing is skipped. // -// hdr contains the IPv4/IPv6 + TCP header prefix (mutable - callers will -// have filled in total length and pseudo-header partial). pays are -// non-overlapping payload fragments whose concatenation is the full -// superpacket payload; they are read-only from the writer's perspective -// and must remain valid until the call returns. gsoSize is the MSS: -// every segment except possibly the last is exactly that many bytes. -// csumStart is the byte offset where the TCP header begins within hdr. +// hdr contains the IPv4/IPv6 header prefix (mutable - callers will have +// filled in total length and IP csum). transportHdr is the TCP or UDP +// header (mutable - the L4 checksum field must hold the pseudo-header +// partial, single-fold not inverted, per virtio NEEDS_CSUM semantics). +// pays are non-overlapping payload fragments whose concatenation is the +// full superpacket payload; they are read-only from the writer's +// perspective and must remain valid until the call returns. Every segment +// in pays except possibly the last is exactly the same size. proto picks +// the L4 protocol so the writer knows which GSOType / CsumOffset to set. // -// # TODO fold into Queue -// -// hdr's TCP checksum field must already hold the pseudo-header partial -// sum (single-fold, not inverted), per virtio NEEDS_CSUM semantics. +// Callers should also consult CapsProvider (via SupportsGSO or +// QueueCapabilities) for the per-protocol negotiated capability; an +// implementation of GSOWriter is necessary but not sufficient since USO +// may not have been negotiated even when TSO was. type GSOWriter interface { - WriteGSO(hdr []byte, pays [][]byte, gsoSize uint16, isV6 bool, csumStart uint16) error - GSOSupported() bool + WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error +} + +// SupportsGSO reports whether w implements GSOWriter and the underlying +// queue advertises the negotiated capability for `want`. A writer that +// implements GSOWriter but not CapsProvider is treated as permissive +// (used by tests and fakes that don't negotiate). +func SupportsGSO(w any, want GSOProto) (GSOWriter, bool) { + gw, ok := w.(GSOWriter) + if !ok { + return nil, false + } + cp, ok := w.(CapsProvider) + if !ok { + return gw, true + } + caps := cp.Capabilities() + switch want { + case GSOProtoTCP: + return gw, caps.TSO + case GSOProtoUDP: + return gw, caps.USO + } + return gw, false } diff --git a/overlay/tio/tio_gso_linux.go b/overlay/tio/tio_gso_linux.go new file mode 100644 index 00000000..583bad35 --- /dev/null +++ b/overlay/tio/tio_gso_linux.go @@ -0,0 +1,461 @@ +package tio + +import ( + "fmt" + "io" + "log/slog" + "os" + "sync" + "sync/atomic" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" + + "github.com/slackhq/nebula/overlay/tio/virtio" +) + +// tunRxBufSize is the per-Read worst-case footprint inside rxBuf: one +// kernel-supplied packet body, which is at most ~64 KiB (tunReadBufSize). 
+// Segmentation happens at encrypt time on a per-routine MTU-sized scratch +// (see SegmentSuperpacket), so rxBuf only holds raw kernel-supplied bytes. +// We round up to give comfortable margin for the drain headroom check +// below. +const tunRxBufSize = 64 * 1024 + +// tunRxBufCap is the total size we allocate for the per-reader rx +// buffer. With reads landing directly in rxBuf, each drain iteration +// consumes up to tunRxBufSize of headroom for the kernel-supplied bytes. +// Sized to two such iterations so the initial blocking read plus one +// drain read both fit without partial-drop. +const tunRxBufCap = tunRxBufSize * 2 + +// tunDrainCap caps how many packets a single Read will accumulate via +// the post-wake drain loop. Sized to soak up a burst of small ACKs while +// bounding how much work a single caller holds before handing off. +const tunDrainCap = 64 + +// gsoMaxIovs caps the iovec budget WriteGSO assembles per call: 3 fixed +// entries (virtio_net_hdr, IP hdr, transport hdr) plus up to gsoMaxIovs-3 +// payload fragments. Sized comfortably above the typical kernel GSO +// segment cap (Linux UDP_GRO is 64) so realistic coalesced bursts never +// touch the limit. iovecs are tiny (16 bytes), so the entire scratch is +// 4 KiB — fine to keep resident on every queue. WriteGSO returns an error +// rather than reallocating when a caller exceeds this budget. +const gsoMaxIovs = 256 + +// validVnetHdr is the 10-byte virtio_net_hdr we prepend to every non-GSO TUN +// write. Only flag set is VIRTIO_NET_HDR_F_DATA_VALID, which marks the skb +// CHECKSUM_UNNECESSARY so the receiving network stack skips L4 checksum +// verification. All packets that reach the plain Write paths already carry +// a valid L4 checksum (either supplied by a remote peer whose ciphertext we +// AEAD-authenticated, produced by segmentTCPYield/segmentUDPYield during +// superpacket segmentation, or built locally by CreateRejectPacket), so +// trusting them is safe. +var validVnetHdr = [virtio.Size]byte{unix.VIRTIO_NET_HDR_F_DATA_VALID} + +// Offload wraps a TUN file descriptor with poll-based reads. The FD provided will be changed to non-blocking. +// A shared eventfd allows Close to wake all readers blocked in poll. +type Offload struct { + fd int + shutdownFd int + readPoll [2]unix.PollFd + writePoll [2]unix.PollFd + // writeLock serializes blockOnWrite's read+clear of writePoll[*].Revents. + // Any goroutine that calls Write may end up parked in poll(2); without + // the lock concurrent waiters could race the Revents reset and lose + // events. + writeLock sync.Mutex + closed atomic.Bool + rxBuf []byte // backing store for kernel-handed packets read this drain + rxOff int // cursor into rxBuf for the current Read drain + pending []Packet // packets returned from the most recent Read + + // readVnetScratch holds the 10-byte virtio_net_hdr split off the front of + // every TUN read via readv(2). Decoupling the header from the packet body + // lets us read the body directly into rxBuf at the current rxOff with + // no userspace copy on the GSO_NONE fast path. + readVnetScratch [virtio.Size]byte + // readIovs is the readv(2) iovec scratch wired once at construction — + // iovec[0] points at readVnetScratch; iovec[1].Base/Len is updated per + // read to address the current rxBuf slot. + readIovs [2]unix.Iovec + + // usoEnabled records whether the kernel agreed to TUN_F_USO* on this FD, + // so writers can decide whether emitting GSO_UDP_L4 superpackets is safe. 
+ usoEnabled bool + + // gsoHdrBuf is a per-queue 10-byte scratch for the virtio_net_hdr emitted + // by WriteGSO. Kept separate from the read-only package-level validVnetHdr + // so non-GSO Writes can ship that constant directly while WriteGSO + // rewrites this scratch on every call. + gsoHdrBuf [virtio.Size]byte + // gsoIovs is the writev iovec scratch for WriteGSO. Pre-sized to + // gsoMaxIovs at construction; never grown. WriteGSO returns an error + // (and drops the call) if a caller hands it more fragments than fit. + gsoIovs []unix.Iovec +} + +func newOffload(fd int, shutdownFd int, usoEnabled bool) (*Offload, error) { + if err := unix.SetNonblock(fd, true); err != nil { + return nil, fmt.Errorf("failed to set tun fd non-blocking: %w", err) + } + + out := &Offload{ + fd: fd, + shutdownFd: shutdownFd, + usoEnabled: usoEnabled, + closed: atomic.Bool{}, + readPoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLIN}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writePoll: [2]unix.PollFd{ + {Fd: int32(fd), Events: unix.POLLOUT}, + {Fd: int32(shutdownFd), Events: unix.POLLIN}, + }, + writeLock: sync.Mutex{}, + + rxBuf: make([]byte, tunRxBufCap), + gsoIovs: make([]unix.Iovec, 2, gsoMaxIovs), + } + + out.gsoIovs[0].Base = &out.gsoHdrBuf[0] + out.gsoIovs[0].SetLen(virtio.Size) + + // readIovs[0] is wired once to the virtio_net_hdr scratch; per-read we + // only repoint readIovs[1] at the next rxBuf slot (see readPacket). + out.readIovs[0].Base = &out.readVnetScratch[0] + out.readIovs[0].SetLen(virtio.Size) + + return out, nil +} + +func (r *Offload) blockOnRead() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.readPoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + tunEvents := r.readPoll[0].Revents + shutdownEvents := r.readPoll[1].Revents + r.readPoll[0].Revents = 0 + r.readPoll[1].Revents = 0 + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +func (r *Offload) blockOnWrite() error { + const problemFlags = unix.POLLHUP | unix.POLLNVAL | unix.POLLERR + var err error + for { + _, err = unix.Poll(r.writePoll[:], -1) + if err != unix.EINTR { + break + } + } + //always reset these! + r.writeLock.Lock() + tunEvents := r.writePoll[0].Revents + shutdownEvents := r.writePoll[1].Revents + r.writePoll[0].Revents = 0 + r.writePoll[1].Revents = 0 + r.writeLock.Unlock() + //do the err check before trusting the potentially bogus bits we just got + if err != nil { + return err + } + if shutdownEvents&(unix.POLLIN|problemFlags) != 0 { + return os.ErrClosed + } else if tunEvents&problemFlags != 0 { + return os.ErrClosed + } + return nil +} + +// readPacket issues a single readv(2) splitting the virtio_net_hdr off +// into readVnetScratch and reading the packet body directly into rxBuf at +// the current rxOff. Returns the body length (zero virtio header bytes, +// just the IP packet/superpacket). block controls whether EAGAIN is +// retried via poll: the initial read of a drain blocks; subsequent drain +// reads do not. +// +// The body iovec capacity is always tunReadBufSize; callers (the Read +// drain loop) gate entry on tunRxBufCap-rxOff >= tunRxBufSize, sized to +// hold one worst-case kernel-supplied packet body. 
Without that gate the +// body iovec could be smaller than the next inbound packet and the +// kernel would truncate. +func (r *Offload) readPacket(block bool) (int, error) { + for { + r.readIovs[1].Base = &r.rxBuf[r.rxOff] + r.readIovs[1].SetLen(tunReadBufSize) + n, _, errno := syscall.Syscall(unix.SYS_READV, uintptr(r.fd), uintptr(unsafe.Pointer(&r.readIovs[0])), uintptr(len(r.readIovs))) + if errno == 0 { + if int(n) < virtio.Size { + return 0, io.ErrShortWrite + } + return int(n) - virtio.Size, nil + } + if errno == unix.EAGAIN { + if !block { + return 0, errno + } + if err := r.blockOnRead(); err != nil { + return 0, err + } + continue + } + if errno == unix.EINTR { + continue + } + if errno == unix.EBADF { + return 0, os.ErrClosed + } + return 0, errno + } +} + +// Read returns one or more packets from the tun. Each Packet either +// carries a single ready-to-use IP datagram (GSO zero) or a TSO/USO +// superpacket plus the GSOInfo a caller needs to segment it (see +// SegmentSuperpacket). The first read blocks via poll; once the fd is +// known readable we drain additional packets non-blocking until the +// kernel queue is empty (EAGAIN), we've collected tunDrainCap packets, +// or we're out of rxBuf headroom. This amortizes the poll wake over +// bursts of small packets (e.g. TCP ACKs). Packet.Bytes slices point +// into the Offload's internal buffer and are only valid until the next +// Read or Close on this Queue. +func (r *Offload) Read() ([]Packet, error) { + r.pending = r.pending[:0] + r.rxOff = 0 + + // Initial (blocking) read. Retry on decode errors so a single bad + // packet does not stall the reader. + for { + n, err := r.readPacket(true) + if err != nil { + return nil, err + } + if err := r.decodeRead(n); err != nil { + // Drop and read again — a bad packet should not kill the reader. + continue + } + break + } + + // Drain: non-blocking reads until the kernel queue is empty, the drain + // cap is reached, or rxBuf no longer has room for another worst-case + // kernel-supplied packet (tunRxBufSize). + for len(r.pending) < tunDrainCap && tunRxBufCap-r.rxOff >= tunRxBufSize { + n, err := r.readPacket(false) + if err != nil { + // EAGAIN / EINTR / anything else: stop draining. We already + // have a valid batch from the first read. + break + } + if n <= 0 { + break + } + if err := r.decodeRead(n); err != nil { + // Drop this packet and stop the drain; we'd rather hand off + // what we have than keep spinning here. + break + } + } + + return r.pending, nil +} + +// decodeRead processes the packet sitting in rxBuf at rxOff (length +// pktLen). The bytes stay in rxBuf — for GSO_NONE we slice them as a +// regular IP datagram (running finishChecksum if NEEDS_CSUM is set); +// for TSO/USO superpackets we attach the corrected GSO metadata so the +// caller can segment lazily at encrypt time. rxOff advances past the +// kernel-supplied body and nothing else, since segmentation no longer +// writes back into rxBuf. 
+func (r *Offload) decodeRead(pktLen int) error { + if pktLen <= 0 { + return fmt.Errorf("short tun read: %d", pktLen) + } + var hdr virtio.Hdr + hdr.Decode(r.readVnetScratch[:]) + + body := r.rxBuf[r.rxOff : r.rxOff+pktLen] + + if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + if err := virtio.FinishChecksum(body, hdr); err != nil { + return err + } + } + r.pending = append(r.pending, Packet{Bytes: body}) + r.rxOff += pktLen + return nil + } + + // GSO superpacket: validate, fix the kernel-supplied HdrLen on the + // FORWARD path (CorrectHdrLen), pick the L4 protocol, and attach + // the metadata. The bytes stay in rxBuf untouched, segmentation + // happens in SegmentSuperpacket at encrypt time. + if err := virtio.CheckValid(body, hdr); err != nil { + return err + } + if err := virtio.CorrectHdrLen(body, &hdr); err != nil { + return err + } + proto, err := protoFromGSOType(hdr.GSOType) + if err != nil { + return err + } + r.pending = append(r.pending, Packet{ + Bytes: body, + GSO: GSOInfo{ + Size: hdr.GSOSize, + HdrLen: hdr.HdrLen, + CsumStart: hdr.CsumStart, + Proto: proto, + }, + }) + r.rxOff += pktLen + return nil +} + +func (r *Offload) Write(buf []byte) (int, error) { + iovs := [2]unix.Iovec{ + {Base: &validVnetHdr[0]}, + {Base: &buf[0]}, + } + iovs[0].SetLen(virtio.Size) + iovs[1].SetLen(len(buf)) + return r.writeWithScratch(buf, &iovs) +} + +func (r *Offload) writeWithScratch(buf []byte, iovs *[2]unix.Iovec) (int, error) { + if len(buf) == 0 { + return 0, nil + } + iovs[1].Base = &buf[0] + iovs[1].SetLen(len(buf)) + return r.rawWrite(unsafe.Slice(&iovs[0], len(iovs))) +} + +func (r *Offload) rawWrite(iovs []unix.Iovec) (int, error) { + for { + n, _, errno := syscall.Syscall(unix.SYS_WRITEV, uintptr(r.fd), uintptr(unsafe.Pointer(&iovs[0])), uintptr(len(iovs))) + if errno == 0 { + if int(n) < virtio.Size { + return 0, io.ErrShortWrite + } + return int(n) - virtio.Size, nil + } + if errno == unix.EAGAIN { + if err := r.blockOnWrite(); err != nil { + return 0, err + } + continue + } + if errno == unix.EINTR { + continue + } + if errno == unix.EBADF { + return 0, os.ErrClosed + } + return 0, errno + } +} + +// Capabilities reports the offload features negotiated for this Queue. TSO +// is always true for Offload (we only construct it on IFF_VNET_HDR FDs); +// USO is true only when the kernel agreed to TUN_F_USO4|6 at open time +// (Linux ≥ 6.2). +func (r *Offload) Capabilities() Capabilities { + return Capabilities{TSO: true, USO: r.usoEnabled} +} + +func (r *Offload) WriteGSO(hdr []byte, transportHdr []byte, pays [][]byte, proto GSOProto) error { + if len(hdr) == 0 || len(pays) == 0 || len(transportHdr) == 0 { + return nil + } + // L4 checksum offset inside transportHdr: TCP=16 (the `check` field after + // seq/ack/dataoff/flags/window), UDP=6 (after sport/dport/length). 
+ var csumOff uint16 + switch proto { + case GSOProtoUDP: + csumOff = 6 + default: + csumOff = 16 + } + vhdr := virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + HdrLen: uint16(len(hdr) + len(transportHdr)), + GSOSize: uint16(len(pays[0])), + CsumStart: uint16(len(hdr)), + CsumOffset: csumOff, + } + if len(pays) > 1 { + ipVer := hdr[0] >> 4 + switch { + case proto == GSOProtoUDP && (ipVer == 4 || ipVer == 6): + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + case ipVer == 6: + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + case ipVer == 4: + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + default: + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + vhdr.GSOSize = 0 + } + } else { + vhdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + vhdr.GSOSize = 0 + } + vhdr.Encode(r.gsoHdrBuf[:]) + + // Build the iovec array: [virtio_hdr, hdr, transportHdr, pays...]. r.gsoIovs[0] is + // wired to gsoHdrBuf at construction and never changes. + need := 3 + len(pays) + if need > cap(r.gsoIovs) { + slog.Default().Warn("tio: WriteGSO iovec budget exceeded; dropping superpacket", + "need", need, "cap", cap(r.gsoIovs), "segments", len(pays)) + return fmt.Errorf("tio: WriteGSO needs %d iovecs but cap is %d", need, cap(r.gsoIovs)) + } + r.gsoIovs = r.gsoIovs[:need] + r.gsoIovs[1].Base = &hdr[0] + r.gsoIovs[1].SetLen(len(hdr)) + r.gsoIovs[2].Base = &transportHdr[0] + r.gsoIovs[2].SetLen(len(transportHdr)) + for i, p := range pays { + r.gsoIovs[3+i].Base = &p[0] + r.gsoIovs[3+i].SetLen(len(p)) + } + + _, err := r.rawWrite(r.gsoIovs) + return err +} + +func (r *Offload) Close() error { + if r.closed.Swap(true) { + return nil + } + + //shutdownFd is owned by the container, so we should not close it + var err error + if r.fd >= 0 { + err = unix.Close(r.fd) + r.fd = -1 + } + + return err +} diff --git a/overlay/tio/tio_poll_linux.go b/overlay/tio/tio_poll_linux.go index 21ae1336..1b8c1f21 100644 --- a/overlay/tio/tio_poll_linux.go +++ b/overlay/tio/tio_poll_linux.go @@ -21,7 +21,7 @@ type Poll struct { closed atomic.Bool readBuf []byte - batchRet [1][]byte + batchRet [1]Packet } func newPoll(fd int, shutdownFd int) (*Poll, error) { @@ -97,12 +97,12 @@ func (t *Poll) blockOnWrite() error { return nil } -func (t *Poll) Read() ([][]byte, error) { +func (t *Poll) Read() ([]Packet, error) { n, err := t.readOne(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tio/tun_file_linux_test.go b/overlay/tio/tun_file_linux_test.go index 8d66fb05..f92f58ec 100644 --- a/overlay/tio/tun_file_linux_test.go +++ b/overlay/tio/tun_file_linux_test.go @@ -15,7 +15,7 @@ import ( ) // newReadPipe returns a read fd. The matching write fd is registered for cleanup. -// The caller takes ownership of the read fd (pass it to newOffload / newFriend). +// The caller takes ownership of the read fd (pass it into a QueueSet). 
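For a concrete sense of what the Offload queue's WriteGSO above puts in front of a coalesced superpacket, this is the virtio_net_hdr it would build for an IPv4/UDP superpacket of three 1200-byte segments. The values follow the code; the literal is illustrative and assumes the virtio.Hdr field names used in this patch.

vhdr := virtio.Hdr{
    Flags:      unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
    GSOType:    unix.VIRTIO_NET_HDR_GSO_UDP_L4, // len(pays) > 1 and proto is UDP
    HdrLen:     28,   // 20-byte IPv4 header + 8-byte UDP header
    GSOSize:    1200, // len(pays[0])
    CsumStart:  20,   // UDP header begins right after the IPv4 header
    CsumOffset: 6,    // UDP checksum field offset within the UDP header
}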
func newReadPipe(t *testing.T) int { t.Helper() var fds [2]int @@ -29,7 +29,7 @@ func newReadPipe(t *testing.T) int { func TestPoll_WakeForShutdown_WakesFriends(t *testing.T) { pipe1 := newReadPipe(t) pipe2 := newReadPipe(t) - parent, err := NewPollContainer() + parent, err := NewPollQueueSet() require.NoError(t, err) require.NoError(t, parent.Add(pipe1)) require.NoError(t, parent.Add(pipe2)) diff --git a/overlay/tio/tun_linux_offload.go b/overlay/tio/tun_linux_offload.go new file mode 100644 index 00000000..9eb46729 --- /dev/null +++ b/overlay/tio/tun_linux_offload.go @@ -0,0 +1,51 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package tio + +import ( + "fmt" + + "golang.org/x/sys/unix" + + "github.com/slackhq/nebula/overlay/tio/virtio" +) + +// protoFromGSOType maps a virtio_net_hdr GSOType to the GSOProto value the +// segment-time helpers use. Returns an error for GSO_NONE or any unknown +// value — the caller should only invoke this on a confirmed superpacket. +func protoFromGSOType(t uint8) (GSOProto, error) { + switch t { + case unix.VIRTIO_NET_HDR_GSO_TCPV4, unix.VIRTIO_NET_HDR_GSO_TCPV6: + return GSOProtoTCP, nil + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + return GSOProtoUDP, nil + default: + return 0, fmt.Errorf("unsupported virtio gso type: %d", t) + } +} + +// SegmentSuperpacket invokes fn once per segment of pkt. For non-GSO pkts +// fn is called once with pkt.Bytes (no segmentation, no copy). For GSO/USO +// superpackets fn is called once per segment with a slice of pkt.Bytes +// holding that segment's plaintext (a freshly-patched L3+L4 header sliced +// in front of the original payload chunk). The slide is destructive: pkt is +// consumed by this call and its bytes are in an undefined state when +// SegmentSuperpacket returns. Callers must not retain pkt or any earlier +// seg slice past fn's return for that segment. The scratch parameter is +// unused on the destructive path and kept only for cross-platform +// signature compatibility. Aborts and returns the first error from fn or +// from per-segment construction. +func SegmentSuperpacket(pkt Packet, fn func(seg []byte) error) error { + if !pkt.GSO.IsSuperpacket() { + return fn(pkt.Bytes) + } + switch pkt.GSO.Proto { + case GSOProtoTCP: + return virtio.SegmentTCP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn) + case GSOProtoUDP: + return virtio.SegmentUDP(pkt.Bytes, pkt.GSO.HdrLen, pkt.GSO.CsumStart, pkt.GSO.Size, fn) + default: + return fmt.Errorf("unsupported gso proto: %d", pkt.GSO.Proto) + } +} diff --git a/overlay/tio/tun_linux_offload_test.go b/overlay/tio/tun_linux_offload_test.go new file mode 100644 index 00000000..1cf64925 --- /dev/null +++ b/overlay/tio/tun_linux_offload_test.go @@ -0,0 +1,794 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package tio + +import ( + "encoding/binary" + "os" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + + "github.com/slackhq/nebula/overlay/tio/virtio" +) + +// testSegScratchSize is a generous segmentation scratch sized to fit any +// of the synthetic TSO/USO superpackets these tests generate (one +// worst-case 64 KiB superpacket plus replicated per-segment headers). +const testSegScratchSize = 192 * 1024 + +// verifyChecksum confirms that the one's-complement sum across `b`, seeded +// with a folded pseudo-header sum, equals all-ones (valid). 
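A hypothetical encrypt-path sketch of how the Linux SegmentSuperpacket above gets driven: each Packet from Queue.Read is consumed before the next Read, and every yielded seg is one ready-to-encrypt plaintext that must not be retained past the callback. sealAndQueue is an illustrative stand-in for nebula's real send path.

func pumpTun(q Queue, sealAndQueue func([]byte) error) error {
    for {
        pkts, err := q.Read()
        if err != nil {
            return err
        }
        for _, p := range pkts {
            // Non-superpackets hit the callback once; TSO/USO
            // superpackets once per segment.
            if err := SegmentSuperpacket(p, sealAndQueue); err != nil {
                return err
            }
        }
    }
}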
+func verifyChecksum(b []byte, pseudo uint16) bool { + return checksum.Checksum(b, pseudo) == 0xffff +} + +// segmentForTest is the test-only counterpart to the production +// SegmentSuperpacket path. It handles GSO_NONE (with optional +// finishChecksum) inline and dispatches GSO superpackets through +// SegmentSuperpacket, draining each yielded segment into a +// freshly-copied [][]byte slot so callers can iterate after the call +// returns. Tests pre-set hdr.HdrLen correctly, so correctHdrLen is not +// invoked here. +func segmentForTest(pkt []byte, hdr virtio.Hdr, out *[][]byte, scratch []byte) error { + if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE { + cp := append([]byte(nil), pkt...) + if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + if err := virtio.FinishChecksum(cp, hdr); err != nil { + return err + } + } + *out = append(*out, cp) + return nil + } + proto, err := protoFromGSOType(hdr.GSOType) + if err != nil { + return err + } + gso := GSOInfo{ + Size: hdr.GSOSize, + HdrLen: hdr.HdrLen, + CsumStart: hdr.CsumStart, + Proto: proto, + } + return SegmentSuperpacket(Packet{Bytes: pkt, GSO: gso}, func(seg []byte) error { + *out = append(*out, append([]byte(nil), seg...)) + return nil + }) +} + +// pseudoHeaderIPv4 returns the folded pseudo-header sum used to verify a +// TCP/UDP segment's checksum in tests. src/dst are 4 bytes each. +func pseudoHeaderIPv4(src, dst []byte, proto byte, l4Len int) uint16 { + s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) + s += uint32(proto) + uint32(l4Len) + s = (s & 0xffff) + (s >> 16) + s = (s & 0xffff) + (s >> 16) + return uint16(s) +} + +// pseudoHeaderIPv6 returns the folded pseudo-header sum used to verify a +// TCP/UDP segment's checksum in tests. src/dst are 16 bytes each. +func pseudoHeaderIPv6(src, dst []byte, proto byte, l4Len int) uint16 { + s := uint32(checksum.Checksum(src, 0)) + uint32(checksum.Checksum(dst, 0)) + s += uint32(l4Len>>16) + uint32(l4Len&0xffff) + uint32(proto) + s = (s & 0xffff) + (s >> 16) + s = (s & 0xffff) + (s >> 16) + return uint16(s) +} + +// buildTSOv4 builds a synthetic IPv4/TCP TSO superpacket with a payload of +// `payLen` bytes split at `mss`. 
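+// The returned virtio.Hdr mirrors what the kernel attaches to an IFF_VNET_HDR
+// read: NEEDS_CSUM with CsumStart at the IP/TCP boundary (20), CsumOffset 16
+// (the TCP checksum field), and GSOSize equal to the MSS.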
+func buildTSOv4(t *testing.T, payLen, mss int) ([]byte, virtio.Hdr) { + t.Helper() + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv4 header + pkt[0] = 0x45 // version 4, IHL 5 + // total length is meaningless for TSO but set it anyway + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) // original ID + pkt[8] = 64 // TTL + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) // src + copy(pkt[16:20], []byte{10, 0, 0, 2}) // dst + + // TCP header + binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport + binary.BigEndian.PutUint16(pkt[22:24], 80) // dport + binary.BigEndian.PutUint32(pkt[24:28], 10000) // seq + binary.BigEndian.PutUint32(pkt[28:32], 20000) // ack + pkt[32] = 0x50 // data offset 5 words + pkt[33] = 0x18 // ACK | PSH + binary.BigEndian.PutUint16(pkt[34:36], 65535) // window + + // payload + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i & 0xff) + } + + return pkt, virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } +} + +func TestSegmentTCPv4(t *testing.T) { + const mss = 100 + const numSeg = 3 + pkt, hdr := buildTSOv4(t, mss*numSeg, mss) + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != numSeg { + t.Fatalf("expected %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != 40+mss { + t.Errorf("seg %d: unexpected len %d", i, len(seg)) + } + totalLen := binary.BigEndian.Uint16(seg[2:4]) + if totalLen != uint16(40+mss) { + t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 40+mss) + } + id := binary.BigEndian.Uint16(seg[4:6]) + if id != 0x4242+uint16(i) { + t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242+uint16(i)) + } + seq := binary.BigEndian.Uint32(seg[24:28]) + wantSeq := uint32(10000 + i*mss) + if seq != wantSeq { + t.Errorf("seg %d: seq=%d want %d", i, seq, wantSeq) + } + flags := seg[33] + wantFlags := byte(0x10) // ACK only, PSH cleared + if i == numSeg-1 { + wantFlags = 0x18 // ACK | PSH preserved on last + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + // IPv4 header checksum must verify against itself. + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + // TCP checksum must verify against the pseudo-header. + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv4OddTail(t *testing.T) { + // Payload of 250 bytes with MSS 100 → segments of 100, 100, 50. 
+ pkt, hdr := buildTSOv4(t, 250, 100) + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != 3 { + t.Fatalf("want 3 segments, got %d", len(out)) + } + wantPayLens := []int{100, 100, 50} + for i, seg := range out { + if len(seg)-40 != wantPayLens[i] { + t.Errorf("seg %d: pay len %d want %d", i, len(seg)-40, wantPayLens[i]) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+wantPayLens[i]) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentTCPv6(t *testing.T) { + const ipLen = 40 + const tcpLen = 20 + const mss = 120 + const numSeg = 2 + payLen := mss * numSeg + pkt := make([]byte, ipLen+tcpLen+payLen) + + // IPv6 header + pkt[0] = 0x60 // version 6 + binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen)) + pkt[6] = unix.IPPROTO_TCP + pkt[7] = 64 + // src/dst fe80::1 / fe80::2 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + // TCP header + binary.BigEndian.PutUint16(pkt[40:42], 12345) + binary.BigEndian.PutUint16(pkt[42:44], 80) + binary.BigEndian.PutUint32(pkt[44:48], 7) + binary.BigEndian.PutUint32(pkt[48:52], 99) + pkt[52] = 0x50 + pkt[53] = 0x19 // FIN | ACK | PSH — exercise FIN clearing too + binary.BigEndian.PutUint16(pkt[54:56], 65535) + + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + + hdr := virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != numSeg { + t.Fatalf("want %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != ipLen+tcpLen+mss { + t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+tcpLen+mss) + } + pl := binary.BigEndian.Uint16(seg[4:6]) + if pl != uint16(tcpLen+mss) { + t.Errorf("seg %d: payload_length=%d want %d", i, pl, tcpLen+mss) + } + seq := binary.BigEndian.Uint32(seg[44:48]) + if seq != uint32(7+i*mss) { + t.Errorf("seg %d: seq=%d want %d", i, seq, 7+i*mss) + } + flags := seg[53] + // Original flags = 0x19 (FIN|ACK|PSH). FIN(0x01)+PSH(0x08) should be + // cleared on all but the last; ACK(0x10) always preserved. 
+ wantFlags := byte(0x10) + if i == numSeg-1 { + wantFlags = 0x19 + } + if flags != wantFlags { + t.Errorf("seg %d: flags=%#x want %#x", i, flags, wantFlags) + } + psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpLen+mss) + if !verifyChecksum(seg[ipLen:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func TestSegmentGSONonePassesThrough(t *testing.T) { + pkt, hdr := buildTSOv4(t, 100, 100) + hdr.GSOType = unix.VIRTIO_NET_HDR_GSO_NONE + hdr.Flags = 0 // no NEEDS_CSUM, leave packet untouched + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != 1 { + t.Fatalf("want 1 segment, got %d", len(out)) + } + if len(out[0]) != len(pkt) { + t.Fatalf("unexpected length: %d vs %d", len(out[0]), len(pkt)) + } +} + +// TestSegmentRejectsLegacyUDPGSO ensures the legacy GSO_UDP (UFO) marker is +// still rejected; only modern GSO_UDP_L4 (USO) is supported. +func TestSegmentRejectsLegacyUDPGSO(t *testing.T) { + hdr := virtio.Hdr{GSOType: unix.VIRTIO_NET_HDR_GSO_UDP} + var out [][]byte + if err := segmentForTest(nil, hdr, &out, nil); err == nil { + t.Fatalf("expected rejection for legacy UDP GSO") + } +} + +// buildUSOv4 builds a synthetic IPv4/UDP USO superpacket with payload of +// payLen bytes, segmented at gsoSize. +func buildUSOv4(t *testing.T, payLen, gsoSize int) ([]byte, virtio.Hdr) { + t.Helper() + const ipLen = 20 + const udpLen = 8 + pkt := make([]byte, ipLen+udpLen+payLen) + + // IPv4 header + pkt[0] = 0x45 // version 4, IHL 5 + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+udpLen+payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) + pkt[8] = 64 + pkt[9] = unix.IPPROTO_UDP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + + // UDP header (length + checksum filled in per segment by segmentUDPYield) + binary.BigEndian.PutUint16(pkt[20:22], 12345) // sport + binary.BigEndian.PutUint16(pkt[22:24], 53) // dport + + for i := 0; i < payLen; i++ { + pkt[ipLen+udpLen+i] = byte(i & 0xff) + } + + return pkt, virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + HdrLen: uint16(ipLen + udpLen), + GSOSize: uint16(gsoSize), + CsumStart: uint16(ipLen), + CsumOffset: 6, + } +} + +func TestSegmentUDPv4(t *testing.T) { + const gso = 100 + const numSeg = 3 + pkt, hdr := buildUSOv4(t, gso*numSeg, gso) + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != numSeg { + t.Fatalf("expected %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != 28+gso { + t.Errorf("seg %d: len %d want %d", i, len(seg), 28+gso) + } + totalLen := binary.BigEndian.Uint16(seg[2:4]) + if totalLen != uint16(28+gso) { + t.Errorf("seg %d: total_len=%d want %d", i, totalLen, 28+gso) + } + // kernel UDP-GSO does NOT bump the IPv4 ID across segments; every + // segment carries the same ID as the seed. 
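+		// (Contrast with TestSegmentTCPv4 above, which expects the TSO path to
+		// bump the ID by one per segment.)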
+ id := binary.BigEndian.Uint16(seg[4:6]) + if id != 0x4242 { + t.Errorf("seg %d: ip id=%#x want %#x", i, id, 0x4242) + } + udpLen := binary.BigEndian.Uint16(seg[24:26]) + if udpLen != uint16(8+gso) { + t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+gso) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+gso) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad UDP checksum", i) + } + } +} + +func TestSegmentUDPv4OddTail(t *testing.T) { + // 250 bytes payload, gsoSize=100 → segments of 100, 100, 50. + pkt, hdr := buildUSOv4(t, 250, 100) + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != 3 { + t.Fatalf("want 3 segments, got %d", len(out)) + } + wantPay := []int{100, 100, 50} + for i, seg := range out { + if len(seg)-28 != wantPay[i] { + t.Errorf("seg %d: pay len %d want %d", i, len(seg)-28, wantPay[i]) + } + udpLen := binary.BigEndian.Uint16(seg[24:26]) + if udpLen != uint16(8+wantPay[i]) { + t.Errorf("seg %d: udp len=%d want %d", i, udpLen, 8+wantPay[i]) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_UDP, 8+wantPay[i]) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad UDP checksum", i) + } + } +} + +func TestSegmentUDPv6(t *testing.T) { + const ipLen = 40 + const udpLen = 8 + const gso = 120 + const numSeg = 2 + payLen := gso * numSeg + pkt := make([]byte, ipLen+udpLen+payLen) + + // IPv6 header + pkt[0] = 0x60 + binary.BigEndian.PutUint16(pkt[4:6], uint16(udpLen+payLen)) + pkt[6] = unix.IPPROTO_UDP + pkt[7] = 64 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + binary.BigEndian.PutUint16(pkt[40:42], 12345) + binary.BigEndian.PutUint16(pkt[42:44], 53) + + for i := 0; i < payLen; i++ { + pkt[ipLen+udpLen+i] = byte(i) + } + + hdr := virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + HdrLen: uint16(ipLen + udpLen), + GSOSize: uint16(gso), + CsumStart: uint16(ipLen), + CsumOffset: 6, + } + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != numSeg { + t.Fatalf("want %d segments, got %d", numSeg, len(out)) + } + + for i, seg := range out { + if len(seg) != ipLen+udpLen+gso { + t.Errorf("seg %d: len %d want %d", i, len(seg), ipLen+udpLen+gso) + } + pl := binary.BigEndian.Uint16(seg[4:6]) + if pl != uint16(udpLen+gso) { + t.Errorf("seg %d: payload_length=%d want %d", i, pl, udpLen+gso) + } + ul := binary.BigEndian.Uint16(seg[ipLen+4 : ipLen+6]) + if ul != uint16(udpLen+gso) { + t.Errorf("seg %d: udp len=%d want %d", i, ul, udpLen+gso) + } + psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_UDP, udpLen+gso) + if !verifyChecksum(seg[ipLen:], psum) { + t.Errorf("seg %d: bad UDP checksum", i) + } + } +} + +// TestSegmentUDPCEPropagates confirms IP-level CE marks on the seed appear on +// every segment. UDP has no transport-level CWR/ECE: the IP TOS/TC byte is +// copied verbatim into every segment by the segment-prefix copy. 
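+// The RFC 3168 codepoints live in the low two bits of that byte (00 Not-ECT,
+// 01 ECT(1), 10 ECT(0), 11 CE); the test seeds 0x03 (CE) and expects it on
+// every emitted segment.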
+func TestSegmentUDPCEPropagates(t *testing.T) { + pkt, hdr := buildUSOv4(t, 200, 100) + pkt[1] = 0x03 // CE codepoint in IP-ECN + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != 2 { + t.Fatalf("want 2 segments, got %d", len(out)) + } + for i, seg := range out { + if seg[1]&0x03 != 0x03 { + t.Errorf("seg %d: CE missing (tos=%#x)", i, seg[1]) + } + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + } +} + +// TestSegmentTCPCwrFirstSegmentOnly confirms RFC 3168 §6.1.2: when a TSO +// burst's seed has CWR set, only the first emitted segment carries CWR. +// ECE is preserved on every segment (different signal, persistent state). +func TestSegmentTCPCwrFirstSegmentOnly(t *testing.T) { + const mss = 100 + const numSeg = 3 + pkt, hdr := buildTSOv4(t, mss*numSeg, mss) + // Seed flags: CWR | ECE | ACK | PSH. + pkt[33] = 0x80 | 0x40 | 0x10 | 0x08 + + scratch := make([]byte, testSegScratchSize) + var out [][]byte + if err := segmentForTest(pkt, hdr, &out, scratch); err != nil { + t.Fatalf("segmentForTest: %v", err) + } + if len(out) != numSeg { + t.Fatalf("expected %d segments, got %d", numSeg, len(out)) + } + for i, seg := range out { + flags := seg[33] + hasCwr := flags&0x80 != 0 + hasEce := flags&0x40 != 0 + hasPsh := flags&0x08 != 0 + wantCwr := i == 0 + wantPsh := i == numSeg-1 + if hasCwr != wantCwr { + t.Errorf("seg %d: CWR=%v want %v (flags=%#x)", i, hasCwr, wantCwr, flags) + } + if !hasEce { + t.Errorf("seg %d: ECE missing (flags=%#x)", i, flags) + } + if hasPsh != wantPsh { + t.Errorf("seg %d: PSH=%v want %v (flags=%#x)", i, hasPsh, wantPsh, flags) + } + // IP and TCP checksums must still verify after the flag rewrite. + if !verifyChecksum(seg[:20], 0) { + t.Errorf("seg %d: bad IPv4 header checksum", i) + } + psum := pseudoHeaderIPv4(seg[12:16], seg[16:20], unix.IPPROTO_TCP, 20+mss) + if !verifyChecksum(seg[20:], psum) { + t.Errorf("seg %d: bad TCP checksum", i) + } + } +} + +func BenchmarkSegmentTCPv4(b *testing.B) { + sizes := []struct { + name string + payLen int + mss int + }{ + {"64KiB_MSS1460", 65000, 1460}, + {"16KiB_MSS1460", 16384, 1460}, + {"4KiB_MSS1460", 4096, 1460}, + } + for _, sz := range sizes { + b.Run(sz.name, func(b *testing.B) { + const ipLen = 20 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+sz.payLen) + pkt[0] = 0x45 + binary.BigEndian.PutUint16(pkt[2:4], uint16(ipLen+tcpLen+sz.payLen)) + binary.BigEndian.PutUint16(pkt[4:6], 0x4242) + pkt[8] = 64 + pkt[9] = unix.IPPROTO_TCP + copy(pkt[12:16], []byte{10, 0, 0, 1}) + copy(pkt[16:20], []byte{10, 0, 0, 2}) + binary.BigEndian.PutUint16(pkt[20:22], 12345) + binary.BigEndian.PutUint16(pkt[22:24], 80) + binary.BigEndian.PutUint32(pkt[24:28], 10000) + binary.BigEndian.PutUint32(pkt[28:32], 20000) + pkt[32] = 0x50 + pkt[33] = 0x18 + binary.BigEndian.PutUint16(pkt[34:36], 65535) + for i := 0; i < sz.payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + hdr := virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + HdrLen: uint16(ipLen + tcpLen), + GSOSize: uint16(sz.mss), + CsumStart: uint16(ipLen), + CsumOffset: 16, + } + + scratch := make([]byte, testSegScratchSize) + out := make([][]byte, 0, 64) + + // SegmentSuperpacket consumes its input destructively; restore + // pkt from a master copy each iteration. 
The restore mirrors the + // kernel→userspace copy that hands a fresh GSO blob to the + // segmenter in production, so it's representative cost rather + // than bench overhead. + master := append([]byte(nil), pkt...) + work := make([]byte, len(pkt)) + + b.SetBytes(int64(len(pkt))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + copy(work, master) + out = out[:0] + if err := segmentForTest(work, hdr, &out, scratch); err != nil { + b.Fatal(err) + } + } + }) + } +} + +// TestTunFileWriteVnetHdrNoAlloc verifies the IFF_VNET_HDR fast-path write is +// allocation-free. We write to /dev/null so every call succeeds synchronously. +func TestTunFileWriteVnetHdrNoAlloc(t *testing.T) { + fd, err := unix.Open("/dev/null", os.O_WRONLY, 0) + if err != nil { + t.Fatalf("open /dev/null: %v", err) + } + t.Cleanup(func() { _ = unix.Close(fd) }) + + tf := &Offload{fd: fd} + + payload := make([]byte, 1400) + // Warm up (first call may trigger one-time internal allocations elsewhere). + if _, err := tf.Write(payload); err != nil { + t.Fatalf("Write: %v", err) + } + + allocs := testing.AllocsPerRun(1000, func() { + if _, err := tf.Write(payload); err != nil { + t.Fatalf("Write: %v", err) + } + }) + if allocs != 0 { + t.Fatalf("Write allocated %.1f times per call, want 0", allocs) + } +} + +// buildTSOv6 builds a synthetic IPv6/TCP TSO superpacket with payLen bytes +// of payload, segmented at gso. Returns the packet bytes only; the +// virtio_net_hdr is the caller's responsibility. +func buildTSOv6(payLen, gso int) []byte { + const ipLen = 40 + const tcpLen = 20 + pkt := make([]byte, ipLen+tcpLen+payLen) + + pkt[0] = 0x60 // version 6 + binary.BigEndian.PutUint16(pkt[4:6], uint16(tcpLen+payLen)) + pkt[6] = unix.IPPROTO_TCP + pkt[7] = 64 + pkt[8] = 0xfe + pkt[9] = 0x80 + pkt[23] = 1 + pkt[24] = 0xfe + pkt[25] = 0x80 + pkt[39] = 2 + + binary.BigEndian.PutUint16(pkt[40:42], 12345) + binary.BigEndian.PutUint16(pkt[42:44], 80) + binary.BigEndian.PutUint32(pkt[44:48], 7) + binary.BigEndian.PutUint32(pkt[48:52], 99) + pkt[52] = 0x50 + pkt[53] = 0x10 // ACK only + binary.BigEndian.PutUint16(pkt[54:56], 65535) + + for i := 0; i < payLen; i++ { + pkt[ipLen+tcpLen+i] = byte(i) + } + return pkt +} + +// TestDecodeReadFitsMaxTSOAtDrainThreshold proves the rxBuf sizing is +// correct: when rxOff is at the maximum value the drain headroom check +// allows, decodeRead must still be able to absorb a worst-case 64KiB +// TSO superpacket without dropping the burst. With segmentation deferred +// to encrypt time, decodeRead writes only the kernel-supplied bytes into +// rxBuf, so the size requirement is just "fit one worst-case input." +// +// Regression history: in a prior layout the rx buffer doubled as the +// segmentation output, a near-threshold drain read returned "scratch too +// small", the whole 45-segment TSO burst was dropped, and the remote's TCP +// fast-retransmit collapsed cwnd. Keeping this test in the new layout +// guards against re-introducing a drain headroom shortfall. +func TestDecodeReadFitsMaxTSOAtDrainThreshold(t *testing.T) { + const ipv6HdrLen = 40 + const tcpHdrLen = 20 + const headerLen = ipv6HdrLen + tcpHdrLen + // Maximum TUN read body. The tunReadBufSize cap on readv's body iovec + // is what bounds the kernel's superpacket length. 
+ pktLen := tunReadBufSize + payLen := pktLen - headerLen + const targetSegs = 64 + gsoSize := (payLen + targetSegs - 1) / targetSegs + + pkt := buildTSOv6(payLen, gsoSize) + if len(pkt) != pktLen { + t.Fatalf("buildTSOv6 produced %d bytes, want %d", len(pkt), pktLen) + } + + o := &Offload{ + rxBuf: make([]byte, tunRxBufCap), + } + // rxOff at the maximum value the drain headroom check permits before + // it would refuse another read. Any drain-time read up to this + // threshold MUST still process correctly. + o.rxOff = tunRxBufCap - tunRxBufSize + + // Stage the body in rxBuf as if readv(2) just placed it there. + copy(o.rxBuf[o.rxOff:], pkt) + + // Encode the matching virtio_net_hdr. + hdr := virtio.Hdr{ + Flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + GSOType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + HdrLen: uint16(headerLen), + GSOSize: uint16(gsoSize), + CsumStart: uint16(ipv6HdrLen), + CsumOffset: 16, + } + hdr.Encode(o.readVnetScratch[:]) + + startRxOff := o.rxOff + if err := o.decodeRead(pktLen); err != nil { + t.Fatalf("decodeRead at drain threshold returned %v — rxBuf sizing regression: "+ + "tunRxBufSize=%d must hold one worst-case input (%d)", + err, tunRxBufSize, pktLen) + } + + if len(o.pending) != 1 { + t.Fatalf("got %d packets, want 1 superpacket entry", len(o.pending)) + } + got := o.pending[0] + if !got.GSO.IsSuperpacket() { + t.Fatalf("expected superpacket GSO metadata, got %+v", got.GSO) + } + if got.GSO.Proto != GSOProtoTCP { + t.Errorf("GSO.Proto=%d want TCP", got.GSO.Proto) + } + if got.GSO.Size != uint16(gsoSize) { + t.Errorf("GSO.Size=%d want %d", got.GSO.Size, gsoSize) + } + if got.GSO.HdrLen != uint16(headerLen) { + t.Errorf("GSO.HdrLen=%d want %d", got.GSO.HdrLen, headerLen) + } + if got.GSO.CsumStart != uint16(ipv6HdrLen) { + t.Errorf("GSO.CsumStart=%d want %d", got.GSO.CsumStart, ipv6HdrLen) + } + if len(got.Bytes) != pktLen { + t.Errorf("len(Bytes)=%d want %d", len(got.Bytes), pktLen) + } + + // rxOff advances exactly by the kernel-supplied body length — no + // segmentation output to account for any more. + if o.rxOff != startRxOff+pktLen { + t.Errorf("rxOff=%d want %d", o.rxOff, startRxOff+pktLen) + } + if o.rxOff > tunRxBufCap { + t.Fatalf("rxOff=%d overran rxBuf (cap=%d)", o.rxOff, tunRxBufCap) + } + + // Validate that segmenting the returned superpacket reproduces the + // expected per-segment IPv6 payload length and TCP checksum. 
+ wantSegs := (payLen + gsoSize - 1) / gsoSize + gotSegs := 0 + if err := SegmentSuperpacket(got, func(seg []byte) error { + defer func() { gotSegs++ }() + if len(seg) < headerLen+1 { + t.Errorf("seg %d too short: %d", gotSegs, len(seg)) + return nil + } + if seg[0]>>4 != 6 { + t.Errorf("seg %d: bad IP version %#x", gotSegs, seg[0]) + } + segPay := len(seg) - headerLen + gotPL := binary.BigEndian.Uint16(seg[4:6]) + if gotPL != uint16(tcpHdrLen+segPay) { + t.Errorf("seg %d: payload_len=%d want %d", gotSegs, gotPL, tcpHdrLen+segPay) + } + psum := pseudoHeaderIPv6(seg[8:24], seg[24:40], unix.IPPROTO_TCP, tcpHdrLen+segPay) + if !verifyChecksum(seg[ipv6HdrLen:], psum) { + t.Errorf("seg %d: bad TCP checksum", gotSegs) + } + return nil + }); err != nil { + t.Fatalf("SegmentSuperpacket: %v", err) + } + if gotSegs != wantSegs { + t.Fatalf("got %d segments, want %d", gotSegs, wantSegs) + } +} diff --git a/overlay/tio/virtio/header_linux.go b/overlay/tio/virtio/header_linux.go new file mode 100644 index 00000000..a080c9dc --- /dev/null +++ b/overlay/tio/virtio/header_linux.go @@ -0,0 +1,43 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +package virtio + +import "encoding/binary" + +// Size is the on-wire length of struct virtio_net_hdr the kernel +// prepends/expects on a TUN opened with IFF_VNET_HDR (TUNSETVNETHDRSZ +// not set). +const Size = 10 + +// Hdr is the Go view of the legacy virtio_net_hdr. +type Hdr struct { + Flags uint8 + GSOType uint8 + HdrLen uint16 + GSOSize uint16 + CsumStart uint16 + CsumOffset uint16 +} + +// Decode reads a virtio_net_hdr in host byte order (TUN default; we never +// call TUNSETVNETLE so the kernel matches our endianness). +func (h *Hdr) Decode(b []byte) { + h.Flags = b[0] + h.GSOType = b[1] + h.HdrLen = binary.NativeEndian.Uint16(b[2:4]) + h.GSOSize = binary.NativeEndian.Uint16(b[4:6]) + h.CsumStart = binary.NativeEndian.Uint16(b[6:8]) + h.CsumOffset = binary.NativeEndian.Uint16(b[8:10]) +} + +// Encode is the inverse of Decode: writes the virtio_net_hdr fields into b +// (must be at least Size bytes). Used to emit a TSO superpacket on egress. +func (h *Hdr) Encode(b []byte) { + b[0] = h.Flags + b[1] = h.GSOType + binary.NativeEndian.PutUint16(b[2:4], h.HdrLen) + binary.NativeEndian.PutUint16(b[4:6], h.GSOSize) + binary.NativeEndian.PutUint16(b[6:8], h.CsumStart) + binary.NativeEndian.PutUint16(b[8:10], h.CsumOffset) +} diff --git a/overlay/tio/virtio/segment_linux.go b/overlay/tio/virtio/segment_linux.go new file mode 100644 index 00000000..f0e90c0f --- /dev/null +++ b/overlay/tio/virtio/segment_linux.go @@ -0,0 +1,401 @@ +//go:build linux && !android && !e2e_testing +// +build linux,!android,!e2e_testing + +// Package virtio implements the pure validation, header-correction, and +// per-segment slicing logic for kernel-supplied TSO/USO superpackets on +// IFF_VNET_HDR TUN devices. It is FD-free and depends only on the byte +// layout of the virtio_net_hdr and the IP/TCP/UDP headers it describes, +// so it can be unit-tested in isolation from the tio Queue runtime. +package virtio + +import ( + "encoding/binary" + "errors" + "fmt" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/checksum" +) + +// Protocol header size bounds used to validate / cap kernel-supplied offsets. 
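As an aside for review, the helpers in this package are meant to be driven in a fixed order for each IFF_VNET_HDR read. The sketch below is illustrative only and not part of the patch: dispatchRead, raw, n, and handle are hypothetical names (one read's buffer, its byte count, and a per-segment callback), it assumes imports of golang.org/x/sys/unix and this virtio package, and it roughly mirrors the dispatch performed by the tests' segmentForTest helper, with CheckValid added up front, rather than quoting the production tio code.

func dispatchRead(raw []byte, n int, handle func(seg []byte) error) error {
	var hdr virtio.Hdr
	hdr.Decode(raw[:virtio.Size]) // every read begins with a 10-byte virtio_net_hdr
	pkt := raw[virtio.Size:n]

	if err := virtio.CheckValid(pkt, hdr); err != nil {
		return err // RSC_INFO set, or GSO type disagrees with the IP version nibble
	}
	if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_NONE {
		if hdr.Flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			// Lone packet whose L4 checksum the kernel left unfinished.
			if err := virtio.FinishChecksum(pkt, hdr); err != nil {
				return err
			}
		}
		return handle(pkt)
	}
	if err := virtio.CorrectHdrLen(pkt, &hdr); err != nil {
		return err // the kernel's HdrLen is untrusted on the FORWARD path
	}
	// TCPV4/TCPV6 superpackets; a GSO_UDP_L4 superpacket goes through
	// SegmentUDP with the same arguments.
	return virtio.SegmentTCP(pkt, hdr.HdrLen, hdr.CsumStart, hdr.GSOSize, handle)
}

The constants that follow are the header-size bounds those helpers enforce on kernel-supplied offsets.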
+const ( + ipv4HeaderMinLen = 20 // IHL=5, no options + ipv4HeaderMaxLen = 60 // IHL=15, max options + ipv6FixedLen = 40 // IPv6 base header; extensions would extend this + tcpHeaderMinLen = 20 // data-offset=5, no options + tcpHeaderMaxLen = 60 // data-offset=15, max options +) + +// Byte offsets inside an IPv4 header. +const ( + ipv4TotalLenOff = 2 + ipv4IDOff = 4 + ipv4ChecksumOff = 10 + ipv4SrcOff = 12 + ipv4AddrsEnd = 20 // end of dst address (ipv4SrcOff + 2*4) +) + +// Byte offsets inside an IPv6 header. +const ( + ipv6PayloadLenOff = 4 + ipv6SrcOff = 8 + ipv6AddrsEnd = 40 // end of dst address (ipv6SrcOff + 2*16) +) + +// Byte offsets inside a TCP header (relative to its start, i.e. csumStart). +const ( + tcpSeqOff = 4 + tcpDataOffOff = 12 // upper nibble is header len in 32-bit words + tcpFlagsOff = 13 + tcpChecksumOff = 16 +) + +// UDP header is fixed at 8 bytes: {sport, dport, length, checksum}. +const ( + udpHeaderLen = 8 + udpLengthOff = 4 + udpChecksumOff = 6 +) + +// tcpFinPshMask is cleared on every segment except the last of a TSO burst. +const tcpFinPshMask = 0x09 // FIN(0x01) | PSH(0x08) + +// tcpCwrFlag is cleared on every segment except the first. Per RFC 3168 +// §6.1.2 the CWR bit signals a one-shot transition (the sender just halved +// its window) and must appear on the first segment of a TSO burst only. +const tcpCwrFlag = 0x80 + +// CheckValid rejects packets whose virtio_net_hdr/IP combination would +// cause a downstream miscompute. The TUN should never emit RSC_INFO and +// the GSO type must agree with the IP version nibble. +func CheckValid(pkt []byte, hdr Hdr) error { + // When RSC_INFO is set the csum_start/csum_offset fields are repurposed to + // carry coalescing info rather than checksum offsets. A TUN writing via + // IFF_VNET_HDR should never emit this, but if it did we would silently + // miscompute the segment checksums — refuse the packet instead. + if hdr.Flags&unix.VIRTIO_NET_HDR_F_RSC_INFO != 0 { + return fmt.Errorf("virtio RSC_INFO flag not supported on TUN reads") + } + if len(pkt) < ipv4HeaderMinLen { + return fmt.Errorf("packet too short") + } + ipVersion := pkt[0] >> 4 + switch hdr.GSOType { + case unix.VIRTIO_NET_HDR_GSO_TCPV4: + if ipVersion != 4 { + return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) + } + case unix.VIRTIO_NET_HDR_GSO_TCPV6: + if ipVersion != 6 { + return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) + } + case unix.VIRTIO_NET_HDR_GSO_UDP_L4: + // USO carries either v4 or v6; the leading nibble disambiguates. + if !(ipVersion == 4 || ipVersion == 6) { + return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) + } + default: + if !(ipVersion == 6 || ipVersion == 4) { + return fmt.Errorf("invalid IP version %d for GSO type %d", ipVersion, hdr.GSOType) + } + } + + return nil +} + +// CorrectHdrLen rewrites hdr.HdrLen based on the actual transport header +// length read out of pkt. The kernel's hdr.HdrLen on the FORWARD path can +// be the length of the entire first packet, so we don't trust it. +func CorrectHdrLen(pkt []byte, hdr *Hdr) error { + // Thank you wireguard-go for documenting these edge-cases + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // csumStart, which is synonymous for IP header length. 
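+	// Worked example: for an IPv4/TCP superpacket with CsumStart=20 and a TCP
+	// data-offset nibble of 8, the transport header is 8*4 = 32 bytes, so
+	// HdrLen is rewritten to 20+32 = 52 regardless of what the kernel supplied.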
+ + if hdr.GSOType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.HdrLen = hdr.CsumStart + 8 + } else { + if len(pkt) <= int(hdr.CsumStart+tcpDataOffOff) { + return errors.New("packet is too short") + } + + tcpHLen := uint16(pkt[hdr.CsumStart+tcpDataOffOff] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return fmt.Errorf("tcp header len is invalid: %d", tcpHLen) + } + hdr.HdrLen = hdr.CsumStart + tcpHLen + } + + if len(pkt) < int(hdr.HdrLen) { + return fmt.Errorf("length of packet (%d) < virtioNetHdr.HdrLen (%d)", len(pkt), hdr.HdrLen) + } + + if hdr.HdrLen < hdr.CsumStart { + return fmt.Errorf("virtioNetHdr.HdrLen (%d) < virtioNetHdr.CsumStart (%d)", hdr.HdrLen, hdr.CsumStart) + } + cSumAt := int(hdr.CsumStart + hdr.CsumStart) + if cSumAt+1 >= len(pkt) { + return fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(pkt)) + } + return nil +} + +// SegmentTCP walks a TSO superpacket pkt, yielding each segment as a +// slice into pkt itself. Per-segment plaintext is laid out by sliding a +// freshly-patched copy of the L3+L4 header into pkt at offset i*gsoSize, +// where it sits immediately before that segment's payload chunk in the +// original buffer. The slide is destructive: iter i's header write overwrites +// the last hdrLen bytes of seg_{i-1}'s payload, which is dead by the time +// the next iteration begins. pkt is consumed by this call and must not be +// inspected by the caller after the final yield. +func SegmentTCP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error { + if gsoSizeU == 0 { + return fmt.Errorf("gso_size is zero") + } + if csumStartU == 0 { + return fmt.Errorf("csum_start is zero") + } + + headerLen := int(hdrLenU) + csumStart := int(csumStartU) + isV4 := pkt[0]>>4 == 4 + + tcpHdrLen := int(pkt[csumStart+tcpDataOffOff]>>4) * 4 + payLen := len(pkt) - headerLen + gsoSize := int(gsoSizeU) + numSeg := (payLen + gsoSize - 1) / gsoSize + if numSeg == 0 { + numSeg = 1 + } + + origSeq := binary.BigEndian.Uint32(pkt[csumStart+tcpSeqOff : csumStart+tcpSeqOff+4]) + origFlags := pkt[csumStart+tcpFlagsOff] + + var tmp [tcpHeaderMaxLen]byte + copy(tmp[:tcpHdrLen], pkt[csumStart:headerLen]) + tmp[tcpSeqOff], tmp[tcpSeqOff+1], tmp[tcpSeqOff+2], tmp[tcpSeqOff+3] = 0, 0, 0, 0 + tmp[tcpFlagsOff] = 0 + tmp[tcpChecksumOff], tmp[tcpChecksumOff+1] = 0, 0 + baseTcpHdrSum := uint32(checksum.Checksum(tmp[:tcpHdrLen], 0)) + + var baseProtoSum uint32 + if isV4 { + baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0)) + } else { + baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0)) + } + baseProtoSum += uint32(unix.IPPROTO_TCP) + + var origIPID uint16 + var baseIPHdrSum uint32 + if isV4 { + origIPID = binary.BigEndian.Uint16(pkt[ipv4IDOff : ipv4IDOff+2]) + ihl := int(pkt[0]&0x0f) * 4 + if ihl < ipv4HeaderMinLen || ihl > csumStart { + return fmt.Errorf("bad IPv4 IHL: %d", ihl) + } + var ipTmp [ipv4HeaderMaxLen]byte + copy(ipTmp[:ihl], pkt[:ihl]) + ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0 + ipTmp[ipv4IDOff], ipTmp[ipv4IDOff+1] = 0, 0 + ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0 + baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0)) + } + + for i := 0; i < numSeg; i++ { + segStart := i * gsoSize + segEnd := segStart + gsoSize + if segEnd > payLen { + segEnd = payLen + } + segPayLen := segEnd - segStart + segLen := headerLen + segPayLen + headerOff := i * gsoSize + + // Slide the header into place 
immediately before this segment's + // payload. Iter 0's header is already at pkt[:headerLen]; for + // i ≥ 1 we copy from there. The constant-byte fields of pkt[:headerLen] + // survive iter 0's in-place patches (only seq/flags/cksum/totalLen/id + // are touched), and iter 0's stale variable-field values are + // overwritten by the per-segment patches below. + if i > 0 { + copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen]) + } + seg := pkt[headerOff : headerOff+segLen] + + segSeq := origSeq + uint32(segStart) + segFlags := origFlags + if i != 0 { + segFlags &^= tcpCwrFlag + } + if i != numSeg-1 { + segFlags &^= tcpFinPshMask + } + totalLen := segLen + + if isV4 { + segID := origIPID + uint16(i) + binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen)) + binary.BigEndian.PutUint16(seg[ipv4IDOff:ipv4IDOff+2], segID) + ipSum := baseIPHdrSum + uint32(totalLen) + uint32(segID) + binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum)) + } else { + binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen)) + } + + binary.BigEndian.PutUint32(seg[csumStart+tcpSeqOff:csumStart+tcpSeqOff+4], segSeq) + seg[csumStart+tcpFlagsOff] = segFlags + + tcpLen := tcpHdrLen + segPayLen + // Payload bytes still live at their original offset in pkt. The + // header slide above only writes into pkt[i*G : i*G+H], which is + // the tail of seg_{i-1}'s payload (already consumed) and never + // overlaps seg_i's own payload at pkt[H+i*G : H+(i+1)*G]. + paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0)) + wide := uint64(baseTcpHdrSum) + uint64(paySum) + uint64(baseProtoSum) + wide += uint64(segSeq) + uint64(segFlags) + uint64(tcpLen) + wide = (wide & 0xffffffff) + (wide >> 32) + wide = (wide & 0xffffffff) + (wide >> 32) + binary.BigEndian.PutUint16(seg[csumStart+tcpChecksumOff:csumStart+tcpChecksumOff+2], foldComplement(uint32(wide))) + + if err := yield(seg); err != nil { + return err + } + } + + return nil +} + +// SegmentUDP walks a USO superpacket, sliding a per-segment-patched +// L3+L4 header into pkt at offset i*gsoSize and yielding pkt[i*G:i*G+segLen] +// to the caller. Per-segment patches are total_len + IPv4 csum (or IPv6 +// payload_len) plus the UDP length and checksum. pkt is consumed +// destructively; see SegmentTCP for the layout reasoning. +// +// UDP-GSO leaves the IPv4 ID identical across segments (the kernel does not +// bump it), which is why the IP-level per-segment work is limited to +// total_len + IPv4 header checksum (v4) or payload_len (v6). 
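+// Checksums are produced incrementally, the same way SegmentTCP does it: base
+// sums over the invariant UDP header bytes and the pseudo-header addresses
+// are computed once, each segment then adds only its payload sum plus the
+// per-segment UDP length (counted twice: once as the pseudo-header length and
+// once as the header's own length field), and the accumulator is folded and
+// complemented. A computed value of zero is stored as 0xffff, since zero
+// means "no checksum" in UDP.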
+func SegmentUDP(pkt []byte, hdrLenU, csumStartU, gsoSizeU uint16, yield func(seg []byte) error) error { + if gsoSizeU == 0 { + return fmt.Errorf("gso_size is zero") + } + if csumStartU == 0 { + return fmt.Errorf("csum_start is zero") + } + + isV4 := pkt[0]>>4 == 4 + headerLen := int(hdrLenU) + csumStart := int(csumStartU) + if headerLen-csumStart != udpHeaderLen { + return fmt.Errorf("udp header len mismatch: %d", headerLen-csumStart) + } + + payLen := len(pkt) - headerLen + gsoSize := int(gsoSizeU) + numSeg := (payLen + gsoSize - 1) / gsoSize + if numSeg == 0 { + numSeg = 1 + } + + var udpTmp [udpHeaderLen]byte + copy(udpTmp[:], pkt[csumStart:headerLen]) + udpTmp[udpLengthOff], udpTmp[udpLengthOff+1] = 0, 0 + udpTmp[udpChecksumOff], udpTmp[udpChecksumOff+1] = 0, 0 + baseUDPHdrSum := uint32(checksum.Checksum(udpTmp[:], 0)) + + var baseProtoSum uint32 + if isV4 { + baseProtoSum = uint32(checksum.Checksum(pkt[ipv4SrcOff:ipv4AddrsEnd], 0)) + } else { + baseProtoSum = uint32(checksum.Checksum(pkt[ipv6SrcOff:ipv6AddrsEnd], 0)) + } + baseProtoSum += uint32(unix.IPPROTO_UDP) + + var baseIPHdrSum uint32 + if isV4 { + ihl := int(pkt[0]&0x0f) * 4 + if ihl < ipv4HeaderMinLen || ihl > csumStart { + return fmt.Errorf("bad IPv4 IHL: %d", ihl) + } + var ipTmp [ipv4HeaderMaxLen]byte + copy(ipTmp[:ihl], pkt[:ihl]) + ipTmp[ipv4TotalLenOff], ipTmp[ipv4TotalLenOff+1] = 0, 0 + ipTmp[ipv4ChecksumOff], ipTmp[ipv4ChecksumOff+1] = 0, 0 + baseIPHdrSum = uint32(checksum.Checksum(ipTmp[:ihl], 0)) + } + + for i := 0; i < numSeg; i++ { + segStart := i * gsoSize + segEnd := segStart + gsoSize + if segEnd > payLen { + segEnd = payLen + } + segPayLen := segEnd - segStart + segLen := headerLen + segPayLen + headerOff := i * gsoSize + + if i > 0 { + copy(pkt[headerOff:headerOff+headerLen], pkt[:headerLen]) + } + seg := pkt[headerOff : headerOff+segLen] + + totalLen := segLen + udpLen := udpHeaderLen + segPayLen + + if isV4 { + binary.BigEndian.PutUint16(seg[ipv4TotalLenOff:ipv4TotalLenOff+2], uint16(totalLen)) + ipSum := baseIPHdrSum + uint32(totalLen) + binary.BigEndian.PutUint16(seg[ipv4ChecksumOff:ipv4ChecksumOff+2], foldComplement(ipSum)) + } else { + binary.BigEndian.PutUint16(seg[ipv6PayloadLenOff:ipv6PayloadLenOff+2], uint16(headerLen-ipv6FixedLen+segPayLen)) + } + + binary.BigEndian.PutUint16(seg[csumStart+udpLengthOff:csumStart+udpLengthOff+2], uint16(udpLen)) + + paySum := uint32(checksum.Checksum(pkt[headerLen+segStart:headerLen+segEnd], 0)) + wide := uint64(baseUDPHdrSum) + uint64(paySum) + uint64(baseProtoSum) + wide += uint64(udpLen) + uint64(udpLen) + wide = (wide & 0xffffffff) + (wide >> 32) + wide = (wide & 0xffffffff) + (wide >> 32) + csum := foldComplement(uint32(wide)) + if csum == 0 { + csum = 0xffff + } + binary.BigEndian.PutUint16(seg[csumStart+udpChecksumOff:csumStart+udpChecksumOff+2], csum) + + if err := yield(seg); err != nil { + return err + } + } + + return nil +} + +// FinishChecksum computes the L4 checksum for a non-GSO packet that the kernel +// handed us with NEEDS_CSUM set. csum_start / csum_offset point at the 16-bit +// checksum field; we zero it, fold a full sum (the field was pre-loaded with +// the pseudo-header partial sum by the kernel), and store the result. 
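+// For a plain IPv4/TCP packet that means CsumStart=20 and CsumOffset=16: the
+// kernel leaves its pseudo-header partial sum in bytes [36:38] and expects
+// the finished, complemented sum written back to the same two bytes.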
+func FinishChecksum(seg []byte, hdr Hdr) error { + cs := int(hdr.CsumStart) + co := int(hdr.CsumOffset) + if cs+co+2 > len(seg) { + return fmt.Errorf("csum offsets out of range: start=%d offset=%d len=%d", cs, co, len(seg)) + } + // The kernel stores a partial pseudo-header sum at [cs+co:]; sum over the + // L4 region starting at cs, folding the prior partial in as the seed. + partial := binary.BigEndian.Uint16(seg[cs+co : cs+co+2]) + seg[cs+co] = 0 + seg[cs+co+1] = 0 + binary.BigEndian.PutUint16(seg[cs+co:cs+co+2], ^checksum.Checksum(seg[cs:], partial)) + return nil +} + +// foldComplement folds a 32-bit one's-complement partial sum to 16 bits and +// complements it, yielding the on-wire Internet checksum value. +func foldComplement(sum uint32) uint16 { + sum = (sum & 0xffff) + (sum >> 16) + sum = (sum & 0xffff) + (sum >> 16) + return ^uint16(sum) +} diff --git a/overlay/tun_android.go b/overlay/tun_android.go index c2342556..ea2e1295 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -27,15 +27,15 @@ type tun struct { l *slog.Logger readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.rwc.Read(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index b4254b05..9ace4fc8 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -37,7 +37,7 @@ type tun struct { out []byte readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } type ifReq struct { @@ -516,12 +516,12 @@ func (t *tun) readOne(to []byte) (int, error) { return n - 4, err } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.readOne(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 12f8b883..ff86bc29 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -23,23 +23,41 @@ type disabledTun struct { rx metrics.Counter l *slog.Logger numReaders int - - batchRet [1][]byte } -func (t *disabledTun) Read() ([][]byte, error) { - r, ok := <-t.read +// disabledQueue is one tio.Queue view onto a shared disabledTun. Each queue +// owns a private batchRet so concurrent Read calls from different reader +// goroutines do not race on the returned slice. +type disabledQueue struct { + parent *disabledTun + batchRet [1]tio.Packet +} + +func (q *disabledQueue) Read() ([]tio.Packet, error) { + r, ok := <-q.parent.read if !ok { return nil, io.EOF } - t.tx.Inc(1) - if t.l.Enabled(context.Background(), slog.LevelDebug) { - t.l.Debug("Write payload", "raw", prettyPacket(r)) + q.parent.tx.Inc(1) + if q.parent.l.Enabled(context.Background(), slog.LevelDebug) { + q.parent.l.Debug("Write payload", "raw", prettyPacket(r)) } - t.batchRet[0] = r - return t.batchRet[:], nil + q.batchRet[0] = tio.Packet{Bytes: r} + return q.batchRet[:], nil +} + +// Write on a queue forwards to the underlying disabledTun. All queues share +// one ICMP-handling/log path so this is a thin pass-through. +func (q *disabledQueue) Write(b []byte) (int, error) { + return q.parent.Write(b) +} + +// Close on a queue is a no-op. The shared channel and metrics are owned by +// the disabledTun; Close on the device tears them down once for everybody. 
+func (q *disabledQueue) Close() error { + return nil } func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *slog.Logger) *disabledTun { @@ -120,7 +138,7 @@ func (t *disabledTun) NewMultiQueueReader() error { func (t *disabledTun) Readers() []tio.Queue { out := make([]tio.Queue, t.numReaders) for i := range t.numReaders { - out[i] = t + out[i] = &disabledQueue{parent: t} } return out } diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index 90beb557..71784ad7 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -104,7 +104,7 @@ type tun struct { closed atomic.Bool readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } // blockOnRead waits until the tun fd is readable or shutdown has been signaled. @@ -159,12 +159,12 @@ func (t *tun) blockOnWrite() error { return nil } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.readOne(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 44607eb2..2c332e06 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -29,15 +29,15 @@ type tun struct { l *slog.Logger readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.rwc.Read(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 19e3ceb0..c18fc38e 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,7 +25,7 @@ import ( ) type tun struct { - readers tio.Container + readers tio.QueueSet closeLock sync.Mutex Device string vpnNetworks []netip.Prefix @@ -34,6 +34,14 @@ type tun struct { TXQueueLen int deviceIndex int ioctlFd uintptr + vnetHdr bool + // routeFeatureECN, when true, sets RTAX_FEATURE_ECN on every route we + // install for the tun. The kernel then actively negotiates ECN for + // connections destined to those prefixes (equivalent to `ip route + // change ... features ecn`) regardless of net.ipv4.tcp_ecn, so flows + // across the nebula mesh use ECN even when the host default is the + // passive setting (=2). Disable via tunnels.ecn=false. + routeFeatureECN bool Routes atomic.Pointer[[]Route] routeTree atomic.Pointer[bart.Table[routing.Gateways]] @@ -72,7 +80,9 @@ type ifreqQLEN struct { } func newTunFromFd(c *config.C, l *slog.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { - t, err := newTunGeneric(c, l, deviceFd, vpnNetworks) + // We don't know what flags the caller opened this fd with and can't turn + // on IFF_VNET_HDR after TUNSETIFF, so skip offload on inherited fds. + t, err := newTunGeneric(c, l, deviceFd, false, false, vpnNetworks) if err != nil { return nil, err } @@ -117,6 +127,18 @@ func tunSetIff(fd int, name string, flags uint16) (string, error) { return strings.Trim(string(req.Name[:]), "\x00"), nil } +// tsoOffloadFlags are the TUN_F_* bits we ask the kernel to enable when a +// TSO-capable TUN is available. CSUM is required as a prerequisite for TSO. +// TSO_ECN tells the kernel we propagate ECN correctly through coalesce and +// segmentation, so it can deliver superpackets whose seed has CWR/ECE set +// or whose IP-level codepoint is CE. 
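+// The mask is applied per queue fd with ioctl(fd, TUNSETOFFLOAD, mask); the
+// kernel rejects the whole mask if any bit is unsupported, which is what the
+// USO-to-TSO fallback in newTun keys off.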
+const tsoOffloadFlags = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 | unix.TUN_F_TSO_ECN + +// usoOffloadFlags adds UDP Segmentation Offload to tsoOffloadFlags. Requires +// Linux ≥ 6.2; older kernels reject it and we fall back to TCP-only TSO via +// tsoOffloadFlags. +const usoOffloadFlags = tsoOffloadFlags | unix.TUN_F_USO4 | unix.TUN_F_USO6 + func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { baseFlags := uint16(unix.IFF_TUN | unix.IFF_NO_PI) if multiqueue { @@ -124,17 +146,51 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue } nameStr := c.GetString("tun.dev", "") + // First try to enable IFF_VNET_HDR via TUNSETIFF and negotiate TUN_F_* + // offloads via TUNSETOFFLOAD so we can receive TSO/USO superpackets. + // We try TSO+USO first, fall back to TSO-only on kernels without USO + // (Linux < 6.2), and finally give up on virtio headers entirely and + // reopen as a plain TUN if neither offload mask is accepted. fd, err := openTunDev() if err != nil { return nil, err } - name, err := tunSetIff(fd, nameStr, baseFlags) + vnetHdr := true + usoEnabled := false + name, err := tunSetIff(fd, nameStr, baseFlags|unix.IFF_VNET_HDR) if err != nil { _ = unix.Close(fd) - return nil, &NameError{Name: nameStr, Underlying: err} + vnetHdr = false + } else { + // Try TSO+USO first. On kernels without USO support (Linux < 6.2) + // the ioctl returns EINVAL; fall back to the TCP-only mask before + // giving up on VNET_HDR entirely. + if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(usoOffloadFlags)); err == nil { + usoEnabled = true + } else if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + l.Warn("Failed to enable TUN offload (TSO); proceeding without virtio headers", "error", err) + _ = unix.Close(fd) + vnetHdr = false + } } - t, err := newTunGeneric(c, l, fd, vpnNetworks) + if !vnetHdr { + fd, err = openTunDev() + if err != nil { + return nil, err + } + name, err = tunSetIff(fd, nameStr, baseFlags) + if err != nil { + _ = unix.Close(fd) + return nil, &NameError{Name: nameStr, Underlying: err} + } + } + + if vnetHdr { + l.Info("TUN offload enabled", "tso", true, "uso", usoEnabled) + } + + t, err := newTunGeneric(c, l, fd, vnetHdr, usoEnabled, vpnNetworks) if err != nil { return nil, err } @@ -145,25 +201,34 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, multiqueue } // newTunGeneric does all the stuff common to different tun initialization paths. It will close your files on error. 
-func newTunGeneric(c *config.C, l *slog.Logger, fd int, vpnNetworks []netip.Prefix) (*tun, error) { - container, err := tio.NewPollContainer() +func newTunGeneric(c *config.C, l *slog.Logger, fd int, vnetHdr, usoEnabled bool, vpnNetworks []netip.Prefix) (*tun, error) { + var qs tio.QueueSet + var err error + if vnetHdr { + qs, err = tio.NewOffloadQueueSet(usoEnabled) + } else { + qs, err = tio.NewPollQueueSet() + } + if err != nil { _ = unix.Close(fd) return nil, err } - err = container.Add(fd) + err = qs.Add(fd) if err != nil { _ = unix.Close(fd) return nil, err } t := &tun{ - readers: container, + readers: qs, closeLock: sync.Mutex{}, + vnetHdr: vnetHdr, vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0), + routeFeatureECN: c.GetBool("tunnels.ecn", true), routesFromSystem: map[netip.Prefix]routing.Gateways{}, l: l, } @@ -271,11 +336,21 @@ func (t *tun) NewMultiQueueReader() error { } flags := uint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_MULTI_QUEUE) + if t.vnetHdr { + flags |= unix.IFF_VNET_HDR + } if _, err = tunSetIff(fd, t.Device, flags); err != nil { _ = unix.Close(fd) return err } + if t.vnetHdr { + if err = ioctl(uintptr(fd), unix.TUNSETOFFLOAD, uintptr(tsoOffloadFlags)); err != nil { + _ = unix.Close(fd) + return fmt.Errorf("failed to enable offload on multiqueue tun fd: %w", err) + } + } + err = t.readers.Add(fd) if err != nil { _ = unix.Close(fd) @@ -450,6 +525,18 @@ func (t *tun) setDefaultRoute(cidr netip.Prefix) error { Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } + // Match the metric the kernel uses for its auto-installed connected + // route, so RouteReplace overwrites it in place instead of adding a + // second route at a worse metric. IPv6 connected routes are installed + // at metric 256 (IP6_RT_PRIO_KERN); IPv4 uses 0. Without this, the + // kernel route wins lookups and our MTU / AdvMSS / Features never + // apply on v6. 
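+	// Done by hand this is roughly
+	//   ip -6 route replace <cidr> dev <tun> metric 256 features ecn
+	// (IPv4 keeps the default metric of 0).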
+ if cidr.Addr().Is6() { + nr.Priority = 256 + } + if t.routeFeatureECN { + nr.Features |= unix.RTAX_FEATURE_ECN + } err := netlink.RouteReplace(&nr) if err != nil { t.l.Warn("Failed to set default route MTU, retrying", "error", err, "cidr", cidr) @@ -499,6 +586,9 @@ func (t *tun) addRoutes(logErrors bool) error { if r.Metric > 0 { nr.Priority = r.Metric } + if t.routeFeatureECN { + nr.Features |= unix.RTAX_FEATURE_ECN + } err := netlink.RouteReplace(&nr) if err != nil { diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 4a0f502f..e8678959 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -68,15 +68,15 @@ type tun struct { fd int readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.readOne(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 2369a7c2..0e754732 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -61,15 +61,15 @@ type tun struct { out []byte readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } -func (t *tun) Read() ([][]byte, error) { +func (t *tun) Read() ([]tio.Packet, error) { n, err := t.readOne(t.readBuf) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index 6240414f..898adc23 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -30,7 +30,7 @@ type TestTun struct { rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula - batchRet [1][]byte + batchRet [1]tio.Packet } func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { @@ -51,7 +51,9 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*T l: l, rxPackets: make(chan []byte, 10), TxPackets: make(chan []byte, 10), - batchRet: [1][]byte{make([]byte, udp.MTU)}, + batchRet: [1]tio.Packet{ + tio.Packet{Bytes: make([]byte, udp.MTU)}, + }, }, nil } @@ -166,13 +168,13 @@ func (t *TestTun) Close() error { return nil } -func (t *TestTun) Read() ([][]byte, error) { - t.batchRet[0] = t.batchRet[0][:udp.MTU] - n, err := t.read(t.batchRet[0]) +func (t *TestTun) Read() ([]tio.Packet, error) { + t.batchRet[0].Bytes = t.batchRet[0].Bytes[:udp.MTU] + n, err := t.read(t.batchRet[0].Bytes) if err != nil { return nil, err } - t.batchRet[0] = t.batchRet[0][:n] + t.batchRet[0].Bytes = t.batchRet[0].Bytes[:n] return t.batchRet[:], nil } diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 1aa3cb27..a5ee063c 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -47,15 +47,15 @@ type winTun struct { tun *wintun.NativeTun readBuf []byte - batchRet [1][]byte + batchRet [1]tio.Packet } -func (t *winTun) Read() ([][]byte, error) { +func (t *winTun) Read() ([]tio.Packet, error) { n, err := t.tun.Read(t.readBuf, 0) if err != nil { return nil, err } - t.batchRet[0] = t.readBuf[:n] + t.batchRet[0] = tio.Packet{Bytes: t.readBuf[:n]} return t.batchRet[:], nil } diff --git a/overlay/user.go b/overlay/user.go index be6b327b..f3cf5adb 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -39,10 +39,10 @@ type UserDevice struct { inboundWriter *io.PipeWriter readBuf []byte - batchRet [1][]byte + batchRet 
[1]tio.Packet } -func (d *UserDevice) Read() ([][]byte, error) { +func (d *UserDevice) Read() ([]tio.Packet, error) { if d.readBuf == nil { d.readBuf = make([]byte, defaultBatchBufSize) } @@ -50,7 +50,7 @@ func (d *UserDevice) Read() ([][]byte, error) { if err != nil { return nil, err } - d.batchRet[0] = d.readBuf[:n] + d.batchRet[0] = tio.Packet{Bytes: d.readBuf[:n]} return d.batchRet[:], nil } diff --git a/udp/conn.go b/udp/conn.go index 14902a76..37277054 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -11,12 +11,25 @@ const MTU = 9001 // MaxWriteBatch is the largest batch any Conn.WriteBatch implementation is // required to accept. Callers SHOULD NOT pass more than this per call; Linux // backends preallocate sendmmsg scratch sized to this value, so exceeding it -// only costs a chunked retry. +// only costs additional sendmmsg chunks within a single WriteBatch call. const MaxWriteBatch = 128 +// RxMeta carries per-packet metadata extracted from the RX path (ancillary +// data, kernel offload state, etc.) and passed to EncReader callbacks. +// Backends that do not produce a particular signal leave its zero value. +// +// OuterECN is the 2-bit IP-level ECN codepoint stamped on the carrier +// datagram (extracted from IP_TOS / IPV6_TCLASS cmsg on Linux). Zero +// means Not-ECT, which is also the value backends without ECN RX support +// supply on every packet. +type RxMeta struct { + OuterECN byte +} + type EncReader func( addr netip.AddrPort, payload []byte, + meta RxMeta, ) type Conn interface { @@ -30,11 +43,14 @@ type Conn interface { ListenOut(r EncReader, flush func()) error WriteTo(b []byte, addr netip.AddrPort) error // WriteBatch sends a contiguous batch of packets, each with its own - // destination. bufs and addrs must have the same length. Linux uses - // sendmmsg(2) for a single syscall; other backends fall back to a - // WriteTo loop. Returns on the first error; callers may observe a - // partial send if some packets went out before the error. - WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error + // destination. bufs and addrs must have the same length. outerECNs may + // be nil (treated as all-zero / Not-ECT); when non-nil it must have the + // same length as bufs, and outerECNs[i] is the 2-bit IP-level ECN + // codepoint to set on packet i's outer header. Linux uses sendmmsg(2) + // for a single syscall and attaches the value as IP_TOS / IPV6_TCLASS + // cmsg; other backends ignore it. Returns on the first error; callers + // may observe a partial send if some packets went out before the error. + WriteBatch(bufs [][]byte, addrs []netip.AddrPort, outerECNs []byte) error ReloadConfig(c *config.C) SupportsMultipleReaders() bool Close() error @@ -57,7 +73,7 @@ func (NoopConn) SupportsMultipleReaders() bool { func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { return nil } -func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort) error { +func (NoopConn) WriteBatch(_ [][]byte, _ []netip.AddrPort, _ []byte) error { return nil } func (NoopConn) ReloadConfig(_ *config.C) { diff --git a/udp/raw_sendmmsg_linux.go b/udp/raw_sendmmsg_linux.go new file mode 100644 index 00000000..ae4dcdc6 --- /dev/null +++ b/udp/raw_sendmmsg_linux.go @@ -0,0 +1,62 @@ +//go:build !android && !e2e_testing +// +build !android,!e2e_testing + +package udp + +import ( + "net" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// rawSendmmsg performs sendmmsg(2) over a syscall.RawConn without +// allocating a closure per call. 
The struct holds preallocated in/out +// scratch (chunk/sent/errno) and a method-value bound at construction so +// rawConn.Write receives a stable function pointer instead of a fresh +// closure on every send. +type rawSendmmsg struct { + msgs []rawMessage + chunk int + sent int + errno syscall.Errno + callback func(fd uintptr) bool +} + +// bind wires r.callback to r.run. Must be called once after r.msgs is set; +// subsequent send calls invoke r.callback without rebinding. +func (r *rawSendmmsg) bind() { r.callback = r.run } + +// run is the preallocated callback rawConn.Write invokes. It reads its +// input (r.chunk) and writes its outputs (r.sent, r.errno) through the +// rawSendmmsg fields so the method value does not capture per-call locals +// and therefore does not heap-allocate. +func (r *rawSendmmsg) run(fd uintptr) bool { + r1, _, errno := unix.Syscall6(unix.SYS_SENDMMSG, fd, + uintptr(unsafe.Pointer(&r.msgs[0])), uintptr(r.chunk), + 0, 0, 0, + ) + if errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK { + return false + } + r.sent = int(r1) + r.errno = errno + return true +} + +// send issues sendmmsg over rc against the first n entries of r.msgs. +// Returns the number of entries the kernel processed and any error; +// matches the original sendmmsg helper's contract. +func (r *rawSendmmsg) send(rc syscall.RawConn, n int) (int, error) { + r.chunk = n + r.sent = 0 + r.errno = 0 + if err := rc.Write(r.callback); err != nil { + return r.sent, err + } + if r.errno != 0 { + return r.sent, &net.OpError{Op: "sendmmsg", Err: r.errno} + } + return r.sent, nil +} diff --git a/udp/rx_reorder_linux.go b/udp/rx_reorder_linux.go new file mode 100644 index 00000000..7887b9f9 --- /dev/null +++ b/udp/rx_reorder_linux.go @@ -0,0 +1,86 @@ +//go:build !android && !e2e_testing +// +build !android,!e2e_testing + +package udp + +import ( + "cmp" + "net/netip" + "slices" +) + +// rxSegment is one nebula packet pulled out of a recvmmsg entry — either a +// lone datagram or one segment of a GRO superpacket. cnt is the big-endian +// uint64 message counter at bytes [8:16] of the nebula header; 0 if the +// segment is too short to contain a header. ecn is the 2-bit IP-level ECN +// codepoint stamped on the carrier (one value per slot, since GRO requires +// equal ECN across coalesced datagrams). +type rxSegment struct { + src netip.AddrPort + cnt uint64 + buf []byte + ecn byte +} + +// rxReorderBuffer accumulates one recvmmsg batch worth of segments, +// splits any GRO superpackets at gso_size boundaries, stable-sorts by +// (src, port, counter), then delivers in order. The reorder distance is +// bounded by len(buf), which the caller sizes to stay well within the +// receiver's ReplayWindow so older arrivals are not rejected as replays. +type rxReorderBuffer struct { + buf []rxSegment +} + +func newRxReorderBuffer(initialCap int) *rxReorderBuffer { + return &rxReorderBuffer{buf: make([]rxSegment, 0, initialCap)} +} + +// reset prepares the buffer for the next recvmmsg batch. +func (r *rxReorderBuffer) reset() { r.buf = r.buf[:0] } + +// addEntry expands one recvmmsg slot into rxSegments. When segSize <= 0 or +// segSize >= len(payload) the payload is appended as a single segment; +// otherwise the kernel-coalesced GRO superpacket is split at segSize +// boundaries (the kernel guarantees every segment is exactly segSize bytes +// except for the final one, which may be short). ecn applies uniformly to +// every produced segment because GRO requires equal ECN across coalesced +// datagrams. 
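Editor's note: a minimal sketch, not part of this patch, of the per-batch cycle the rxReorderBuffer comments above describe. It assumes the udp package context (rxReorderBuffer, EncReader) plus a net/netip import; the `slot` struct and `reorderOneBatch` name are hypothetical stand-ins for one parsed recvmmsg result and the glue that lives in listenOutBatch.

// slot is a hypothetical container for one parsed recvmmsg slot.
type slot struct {
	src     netip.AddrPort
	payload []byte
	segSize int  // gso_size from the UDP_GRO cmsg; 0 when not coalesced
	ecn     byte // outer IP-ECN codepoint for the whole slot
}

// reorderOneBatch runs the reset -> addEntry -> sortStable -> deliver cycle
// for one recvmmsg batch.
func reorderOneBatch(r *rxReorderBuffer, slots []slot, fn EncReader) {
	r.reset()
	for _, s := range slots {
		// GRO superpackets are split back into individual nebula packets here.
		r.addEntry(s.src, s.payload, s.segSize, s.ecn)
	}
	r.sortStable()
	r.deliver(fn) // fn receives each segment with its outer ECN in RxMeta
}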
+func (r *rxReorderBuffer) addEntry(from netip.AddrPort, payload []byte, segSize int, ecn byte) { + if segSize <= 0 || segSize >= len(payload) { + r.buf = append(r.buf, rxSegment{from, headerCounter(payload), payload, ecn}) + return + } + for off := 0; off < len(payload); off += segSize { + end := off + segSize + if end > len(payload) { + end = len(payload) + } + seg := payload[off:end] + r.buf = append(r.buf, rxSegment{from, headerCounter(seg), seg, ecn}) + } +} + +// sortStable orders the accumulated segments by (src addr, src port, +// counter). Same-source segments are reordered into counter order; +// cross-source relative order is determined by a stable address compare so +// the sort is total and predictable. +func (r *rxReorderBuffer) sortStable() { + slices.SortStableFunc(r.buf, func(a, b rxSegment) int { + if c := a.src.Addr().Compare(b.src.Addr()); c != 0 { + return c + } + if c := cmp.Compare(a.src.Port(), b.src.Port()); c != 0 { + return c + } + return cmp.Compare(a.cnt, b.cnt) + }) +} + +// deliver invokes fn once per segment in sorted order, then nils the +// per-entry buf reference so the next batch's append doesn't alias it. +func (r *rxReorderBuffer) deliver(fn EncReader) { + for k := range r.buf { + fn(r.buf[k].src, r.buf[k].buf, RxMeta{OuterECN: r.buf[k].ecn}) + r.buf[k].buf = nil + } +} diff --git a/udp/rx_reorder_linux_test.go b/udp/rx_reorder_linux_test.go new file mode 100644 index 00000000..180e2e4a --- /dev/null +++ b/udp/rx_reorder_linux_test.go @@ -0,0 +1,203 @@ +//go:build !android && !e2e_testing +// +build !android,!e2e_testing + +package udp + +import ( + "encoding/binary" + "net/netip" + "testing" +) + +// makeNebulaPkt returns a buffer whose [8:16] bytes encode the given +// counter big-endian, the rest left zero. Anything shorter than 16 bytes +// would yield counter 0; tests use this to simulate well-formed nebula +// headers (the rxReorderBuffer doesn't care about anything else). +func makeNebulaPkt(cnt uint64, payLen int) []byte { + if payLen < 16 { + payLen = 16 + } + b := make([]byte, payLen) + binary.BigEndian.PutUint64(b[8:16], cnt) + return b +} + +func srcOf(addr string, port uint16) netip.AddrPort { + return netip.AddrPortFrom(netip.MustParseAddr(addr), port) +} + +func TestRxReorderBuffer_LonePassesThrough(t *testing.T) { + r := newRxReorderBuffer(8) + pkt := makeNebulaPkt(42, 100) + r.addEntry(srcOf("1.1.1.1", 4242), pkt, 0, 0x02) + + if got := len(r.buf); got != 1 { + t.Fatalf("want 1 entry, got %d", got) + } + if r.buf[0].cnt != 42 { + t.Errorf("counter=%d want 42", r.buf[0].cnt) + } + if r.buf[0].ecn != 0x02 { + t.Errorf("ecn=%#x want 0x02", r.buf[0].ecn) + } + if len(r.buf[0].buf) != 100 { + t.Errorf("buf len=%d want 100", len(r.buf[0].buf)) + } +} + +func TestRxReorderBuffer_SegSizeGEPayloadIsLone(t *testing.T) { + // segSize >= len(payload) means the kernel did not coalesce this slot. + r := newRxReorderBuffer(8) + pkt := makeNebulaPkt(7, 50) + r.addEntry(srcOf("1.1.1.1", 1), pkt, 50, 0) + if got := len(r.buf); got != 1 { + t.Fatalf("segSize==len: want 1 entry, got %d", got) + } + r.reset() + r.addEntry(srcOf("1.1.1.1", 1), pkt, 60, 0) + if got := len(r.buf); got != 1 { + t.Fatalf("segSize>len: want 1 entry, got %d", got) + } +} + +func TestRxReorderBuffer_GROSplitExactMultiple(t *testing.T) { + // 3 segments of 80 bytes each, packed into one 240-byte GRO superpacket. 
+ const segSize = 80 + const numSeg = 3 + pkt := make([]byte, segSize*numSeg) + for i := range numSeg { + off := i * segSize + binary.BigEndian.PutUint64(pkt[off+8:off+16], uint64(100+i)) + } + + r := newRxReorderBuffer(8) + r.addEntry(srcOf("2.2.2.2", 5555), pkt, segSize, 0x03) + if got := len(r.buf); got != numSeg { + t.Fatalf("want %d segments, got %d", numSeg, got) + } + for i, seg := range r.buf { + if seg.cnt != uint64(100+i) { + t.Errorf("seg %d: cnt=%d want %d", i, seg.cnt, 100+i) + } + if len(seg.buf) != segSize { + t.Errorf("seg %d: buf len=%d want %d", i, len(seg.buf), segSize) + } + if seg.ecn != 0x03 { + t.Errorf("seg %d: ecn=%#x want 0x03 (uniform across GRO)", i, seg.ecn) + } + } +} + +func TestRxReorderBuffer_GROSplitShortFinal(t *testing.T) { + // 200-byte payload, segSize=80 → segments of 80, 80, 40. + const segSize = 80 + pkt := make([]byte, 200) + binary.BigEndian.PutUint64(pkt[8:16], 1) + binary.BigEndian.PutUint64(pkt[80+8:80+16], 2) + binary.BigEndian.PutUint64(pkt[160+8:160+16], 3) + + r := newRxReorderBuffer(8) + r.addEntry(srcOf("3.3.3.3", 1), pkt, segSize, 0) + if got := len(r.buf); got != 3 { + t.Fatalf("want 3 segments, got %d", got) + } + wantLens := []int{80, 80, 40} + for i, seg := range r.buf { + if len(seg.buf) != wantLens[i] { + t.Errorf("seg %d: len=%d want %d", i, len(seg.buf), wantLens[i]) + } + } +} + +func TestRxReorderBuffer_SortGroupsBySrcThenCounter(t *testing.T) { + r := newRxReorderBuffer(8) + a := srcOf("1.1.1.1", 1) + b := srcOf("2.2.2.2", 1) + // Insert deliberately scrambled. + r.addEntry(a, makeNebulaPkt(3, 16), 0, 0) + r.addEntry(b, makeNebulaPkt(1, 16), 0, 0) + r.addEntry(a, makeNebulaPkt(1, 16), 0, 0) + r.addEntry(b, makeNebulaPkt(2, 16), 0, 0) + r.addEntry(a, makeNebulaPkt(2, 16), 0, 0) + + r.sortStable() + + want := []struct { + src netip.AddrPort + cnt uint64 + }{ + {a, 1}, {a, 2}, {a, 3}, {b, 1}, {b, 2}, + } + if got := len(r.buf); got != len(want) { + t.Fatalf("len=%d want %d", got, len(want)) + } + for i, w := range want { + if r.buf[i].src != w.src || r.buf[i].cnt != w.cnt { + t.Errorf("idx %d: got %v/%d want %v/%d", + i, r.buf[i].src, r.buf[i].cnt, w.src, w.cnt) + } + } +} + +func TestRxReorderBuffer_SortStableAcrossPorts(t *testing.T) { + // Same source addr but different ports — must group by port. + r := newRxReorderBuffer(8) + addr := netip.MustParseAddr("4.4.4.4") + p1 := netip.AddrPortFrom(addr, 1) + p2 := netip.AddrPortFrom(addr, 2) + r.addEntry(p2, makeNebulaPkt(10, 16), 0, 0) + r.addEntry(p1, makeNebulaPkt(20, 16), 0, 0) + r.addEntry(p2, makeNebulaPkt(5, 16), 0, 0) + + r.sortStable() + + // Expect: p1/20 then p2/5 then p2/10. 
+ if r.buf[0].src.Port() != 1 || r.buf[1].src.Port() != 2 || r.buf[2].src.Port() != 2 { + t.Fatalf("port order broken: %v %v %v", + r.buf[0].src.Port(), r.buf[1].src.Port(), r.buf[2].src.Port()) + } + if r.buf[1].cnt != 5 || r.buf[2].cnt != 10 { + t.Errorf("counter order in p2: %d %d (want 5 10)", r.buf[1].cnt, r.buf[2].cnt) + } +} + +func TestRxReorderBuffer_DeliverInOrderAndNilsRefs(t *testing.T) { + r := newRxReorderBuffer(4) + a := srcOf("5.5.5.5", 1) + r.addEntry(a, makeNebulaPkt(2, 32), 0, 0x01) + r.addEntry(a, makeNebulaPkt(1, 32), 0, 0x01) + r.sortStable() + + var seenCnts []uint64 + var seenECN []byte + r.deliver(func(src netip.AddrPort, buf []byte, meta RxMeta) { + seenCnts = append(seenCnts, binary.BigEndian.Uint64(buf[8:16])) + seenECN = append(seenECN, meta.OuterECN) + }) + + if len(seenCnts) != 2 || seenCnts[0] != 1 || seenCnts[1] != 2 { + t.Errorf("delivery order broken: %v", seenCnts) + } + if seenECN[0] != 0x01 || seenECN[1] != 0x01 { + t.Errorf("ecn passed wrong: %v", seenECN) + } + for i := range r.buf { + if r.buf[i].buf != nil { + t.Errorf("buf[%d].buf not nil after deliver", i) + } + } +} + +func TestRxReorderBuffer_ResetIsReusable(t *testing.T) { + r := newRxReorderBuffer(2) + r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(1, 16), 0, 0) + r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(2, 16), 0, 0) + r.reset() + if got := len(r.buf); got != 0 { + t.Fatalf("after reset len=%d want 0", got) + } + r.addEntry(srcOf("6.6.6.6", 1), makeNebulaPkt(7, 16), 0, 0) + if r.buf[0].cnt != 7 { + t.Errorf("after reset+add: cnt=%d want 7", r.buf[0].cnt) + } +} diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 2468c6c4..e6ecea8f 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -140,7 +140,7 @@ func (u *StdConn) WriteTo(b []byte, ap netip.AddrPort) error { } } -func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { for i, b := range bufs { if err := u.WriteTo(b, addrs[i]); err != nil { return err @@ -188,7 +188,7 @@ func (u *StdConn) ListenOut(r EncReader, flush func()) error { u.l.Error("unexpected udp socket receive error", "error", err) } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) flush() } } diff --git a/udp/udp_ecn_outer_linux_test.go b/udp/udp_ecn_outer_linux_test.go new file mode 100644 index 00000000..b77b61c4 --- /dev/null +++ b/udp/udp_ecn_outer_linux_test.go @@ -0,0 +1,61 @@ +//go:build linux && !android && !e2e_testing + +package udp + +import ( + "net/netip" + "testing" +) + +// TestPlanRunBreaksOnECNChange confirms that two same-destination, same-size +// packets with different outer ECN end up in separate sendmmsg entries (the +// kernel stamps one outer codepoint per entry, so a run that straddled the +// boundary would silently lose information). 
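For readers decoding the raw ECN values used in these tests and throughout the patch, the 2-bit codepoints come straight from RFC 3168. A naming sketch (editor's addition, not code from the patch):

// RFC 3168 ECN codepoints: the low two bits of the IPv4 ToS byte or the
// IPv6 Traffic Class field.
const (
	ecnNotECT = 0x00 // not ECN-capable transport
	ecnECT1   = 0x01 // ECN-capable transport, ECT(1)
	ecnECT0   = 0x02 // ECN-capable transport, ECT(0)
	ecnCE     = 0x03 // congestion experienced
)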
+func TestPlanRunBreaksOnECNChange(t *testing.T) { + u := &StdConn{gsoSupported: true} + dst := netip.MustParseAddrPort("10.0.0.1:4242") + + bufs := [][]byte{ + make([]byte, 1200), + make([]byte, 1200), + make([]byte, 1200), + } + addrs := []netip.AddrPort{dst, dst, dst} + + t.Run("uniform_ecn_runs_together", func(t *testing.T) { + ecns := []byte{0x02, 0x02, 0x02} + runLen, segSize := u.planRun(bufs, addrs, ecns, 0, 64) + if runLen != 3 { + t.Errorf("runLen=%d want 3 (uniform ECT(0))", runLen) + } + if segSize != 1200 { + t.Errorf("segSize=%d want 1200", segSize) + } + }) + + t.Run("ecn_change_truncates_run", func(t *testing.T) { + // 0,0,3: first two run together, CE seeds a fresh entry. + ecns := []byte{0x00, 0x00, 0x03} + runLen, _ := u.planRun(bufs, addrs, ecns, 0, 64) + if runLen != 2 { + t.Errorf("runLen=%d want 2 (ECN changes at index 2)", runLen) + } + }) + + t.Run("nil_ecns_runs_full", func(t *testing.T) { + runLen, _ := u.planRun(bufs, addrs, nil, 0, 64) + if runLen != 3 { + t.Errorf("runLen=%d want 3 (nil ecns means no break)", runLen) + } + }) + + t.Run("first_ecn_is_singleton", func(t *testing.T) { + // Second packet has different ECN from the first → run halts at 1 + // (the first packet alone forms the run). + ecns := []byte{0x00, 0x03, 0x03} + runLen, _ := u.planRun(bufs, addrs, ecns, 0, 64) + if runLen != 1 { + t.Errorf("runLen=%d want 1 (different ECN immediately)", runLen) + } + }) +} diff --git a/udp/udp_generic.go b/udp/udp_generic.go index c0dacedb..0c254906 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -44,7 +44,7 @@ func (u *GenericConn) WriteTo(b []byte, addr netip.AddrPort) error { return err } -func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { +func (u *GenericConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { for i, b := range bufs { if _, err := u.UDPConn.WriteToUDPAddrPort(b, addrs[i]); err != nil { return err @@ -102,7 +102,7 @@ func (u *GenericConn) ListenOut(r EncReader, flush func()) error { continue } - r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n], RxMeta{}) flush() } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ec840426..5ae5847b 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -24,6 +24,58 @@ type StdConn struct { isV4 bool l *slog.Logger batch int + + // sendmmsg scratch. Each queue has its own StdConn, so no locking is + // needed. Sized to MaxWriteBatch at construction; WriteBatch chunks + // larger inputs. + writeMsgs []rawMessage + writeIovs []iovec + writeNames [][]byte + + // Per-entry cmsg scratch. writeCmsg is one contiguous slab of + // MaxWriteBatch * writeCmsgSpace bytes; each entry holds two cmsg + // headers (UDP_SEGMENT then IP_TOS / IPV6_TCLASS) pre-filled once in + // prepareWriteMessages. WriteBatch only rewrites the per-call data + // payloads and toggles Hdr.Control / Hdr.Controllen to point at + // whichever subset of the two cmsgs applies. + writeCmsg []byte + writeCmsgSpace int + writeCmsgSegSpace int + writeCmsgEcnSpace int + + // writeEntryEnd[e] is the bufs index *after* the last packet packed + // into mmsghdr entry e. Used to rewind `i` on partial sendmmsg success. + writeEntryEnd []int + + // rawSend wraps the sendmmsg(2) callback in a closure-free helper so + // the hot path doesn't heap-allocate a fresh closure per call. + rawSend rawSendmmsg + + // UDP GSO (sendmsg with UDP_SEGMENT cmsg) support. gsoSupported is + // probed once at socket creation. 
When true, WriteBatch packs same- + // destination consecutive packets into a single sendmmsg entry with a + // UDP_SEGMENT cmsg; otherwise each packet is its own entry. + gsoSupported bool + + // UDP GRO (recvmsg with UDP_GRO cmsg) support. groSupported is probed + // once at socket creation. When true, listenOutBatch allocates larger + // RX buffers and a per-entry cmsg slot so the kernel can coalesce + // consecutive same-flow datagrams into a single recvmmsg entry; the + // delivered cmsg carries the gso_size used to split them back apart. + groSupported bool + + // ecnRecvSupported is true when IP_RECVTOS / IPV6_RECVTCLASS was + // successfully enabled — the kernel will deliver the outer IP-ECN of + // each arriving datagram as a per-slot cmsg, and listenOutBatch passes + // the parsed value to the EncReader callback for RFC 6040 combine. + ecnRecvSupported bool + + // rxOrder is the per-batch scratch listenOutBatch uses to gather every + // segment in a recvmmsg call (after splitting GRO superpackets) and + // stable-sort by (source, message-counter) before delivery. Reordering + // fits within the receiver's replay window so briefly out-of-order + // arrivals do not get rejected as replays. + rxOrder *rxReorderBuffer } func setReusePort(network, address string, c syscall.RawConn) error { @@ -70,9 +122,196 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) } out.isV4 = af == unix.AF_INET + out.prepareWriteMessages(MaxWriteBatch) + out.rawSend.msgs = out.writeMsgs + out.rawSend.bind() + + out.prepareGSO() + // GRO delivers coalesced superpackets that need a cmsg to split back + // into segments. The single-packet RX path uses ReadFromUDPAddrPort + // and cannot see that cmsg, so only enable GRO for the batch path. + if batch > 1 { + out.prepareGRO() + } + // Best-effort: ask the kernel to deliver outer IP-ECN as ancillary data + // on every recvmmsg slot so the decap side can apply RFC 6040 combine. + // On older kernels these may not exist; failing here just means we get + // 0 (Not-ECT) on every slot, which is the same as ecn_mode=disable. + out.prepareECNRecv() + return out, nil } +// prepareWriteMessages allocates one mmsghdr/iovec/sockaddr/cmsg scratch +// slot per sendmmsg entry. The iovec slab is sized to n so all entries' +// iovecs share one allocation; per-entry fan-out is further capped at +// maxGSOSegments. Hdr.Iov / Hdr.Iovlen / Hdr.Control / Hdr.Controllen are +// wired per call since each entry can span a variable number of iovecs +// and may or may not carry a cmsg. +// +// Per-mmsghdr cmsg layout. Each entry's slot of length writeCmsgSpace holds +// up to two cmsg headers placed at fixed offsets: +// +// [0 .. writeCmsgSegSpace) UDP_SEGMENT (gso_size, uint16) +// [writeCmsgSegSpace .. writeCmsgSpace) IP_TOS or IPV6_TCLASS (int32) +// +// Both headers are pre-filled once here; per-call we only rewrite the data +// payload and toggle Hdr.Control / Hdr.Controllen to point at whichever +// subset applies (none / segment-only / ecn-only / both). 
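To make the two-cmsg slot layout described above concrete, here is a small standalone program (editor's sketch, not part of the patch) that prints the per-entry offsets the same way they are derived from unix.CmsgSpace / unix.CmsgLen:

// Editor's sketch: per-entry cmsg slot layout. Build and run on Linux.
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

func main() {
	segSpace := unix.CmsgSpace(2) // UDP_SEGMENT payload is a uint16 gso_size
	ecnSpace := unix.CmsgSpace(4) // IP_TOS / IPV6_TCLASS sent as a 4-byte int
	fmt.Printf("UDP_SEGMENT cmsg: bytes [0..%d), data at offset %d\n", segSpace, unix.CmsgLen(0))
	fmt.Printf("ECN cmsg:         bytes [%d..%d), data at offset %d\n", segSpace, segSpace+ecnSpace, segSpace+unix.CmsgLen(0))
	fmt.Printf("cmsg space per sendmmsg entry = %d bytes\n", segSpace+ecnSpace)
}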
+func (u *StdConn) prepareWriteMessages(n int) { + u.writeMsgs = make([]rawMessage, n) + u.writeIovs = make([]iovec, n) + u.writeNames = make([][]byte, n) + u.writeEntryEnd = make([]int, n) + + u.writeCmsgSegSpace = unix.CmsgSpace(2) + u.writeCmsgEcnSpace = unix.CmsgSpace(4) + u.writeCmsgSpace = u.writeCmsgSegSpace + u.writeCmsgEcnSpace + u.writeCmsg = make([]byte, n*u.writeCmsgSpace) + + ecnLevel := int32(unix.IPPROTO_IP) + ecnType := int32(unix.IP_TOS) + if !u.isV4 { + ecnLevel = unix.IPPROTO_IPV6 + ecnType = unix.IPV6_TCLASS + } + + for k := 0; k < n; k++ { + base := k * u.writeCmsgSpace + seg := (*unix.Cmsghdr)(unsafe.Pointer(&u.writeCmsg[base])) + seg.Level = unix.SOL_UDP + seg.Type = unix.UDP_SEGMENT + setCmsgLen(seg, unix.CmsgLen(2)) + + ecn := (*unix.Cmsghdr)(unsafe.Pointer(&u.writeCmsg[base+u.writeCmsgSegSpace])) + ecn.Level = ecnLevel + ecn.Type = ecnType + setCmsgLen(ecn, unix.CmsgLen(4)) + } + + for i := range u.writeMsgs { + u.writeNames[i] = make([]byte, unix.SizeofSockaddrInet6) + u.writeMsgs[i].Hdr.Name = &u.writeNames[i][0] + } +} + +// maxGSOSegments caps the per-sendmsg GSO fan-out. Linux kernels have +// historically capped UDP_MAX_SEGMENTS at 64; newer kernels raise it to 128. +// We stay one below 64 because the kernel's check is +// +// if (cork->length > cork->gso_size * UDP_MAX_SEGMENTS) return -EINVAL; +// +// and cork->length includes the 8-byte UDP header (udp_sendmsg passes +// ulen = len + sizeof(udphdr) to ip_append_data). Packing exactly 64 +// same-size segments puts cork->length at gso_size*64 + 8, which is one +// UDP-header over the bound and the kernel rejects the whole sendmmsg +// with EINVAL. 63 leaves room for the header for any segSize >= 8. +const maxGSOSegments = 63 + +// maxGSOBytes bounds the total payload per sendmsg() when UDP_SEGMENT is +// set. The kernel stitches all iovecs into a single skb whose length the +// UDP length field can represent, and also enforces sk_gso_max_size (which +// on most devices is 65536). We use 65000 to leave headroom under the +// 65535 UDP-length cap, avoiding EMSGSIZE on large TSO superpackets. +const maxGSOBytes = 65000 + +// prepareGSO probes UDP_SEGMENT support and sets u.gsoSupported on success. +// Best-effort; failure leaves it false. +func (u *StdConn) prepareGSO() { + var probeErr error + if err := u.rawConn.Control(func(fd uintptr) { + probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT, 0) + }); err != nil { + u.l.Info("udp: GSO disabled", "reason", "rawconn control failed", "error", err) + recordCapability("udp.gso.enabled", false) + return + } + if probeErr != nil { + u.l.Info("udp: GSO disabled", "reason", "kernel rejected probe", "error", probeErr) + recordCapability("udp.gso.enabled", false) + return + } + u.gsoSupported = true + u.l.Info("udp: GSO enabled") + recordCapability("udp.gso.enabled", true) +} + +// udpGROBufferSize sizes the per-entry recvmmsg buffer when UDP_GRO is on. +// The kernel stitches a run of same-flow datagrams into a single skb whose +// length is bounded by sk_gso_max_size (typically 65535); anything larger +// would be MSG_TRUNCed. We use the maximum representable UDP length so a +// full superpacket always lands intact. +const udpGROBufferSize = 65535 + +// udpGROCmsgPayload is the size of the UDP_GRO cmsg data delivered by the +// kernel: a single int (gso_size in bytes). See udp_cmsg_recv() in +// net/ipv4/udp.c. 
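A quick worked example of how the maxGSOSegments and maxGSOBytes caps defined just above combine (editor's illustration, assuming the udp package context where those constants live): for 1400-byte nebula packets the byte cap binds first, floor(65000/1400) = 46 segments per entry, while for small packets the 63-segment cap binds.

// maxSegmentsPerEntry mirrors the bound planRun effectively enforces for a
// run of equal-size segments: min(maxGSOSegments, maxGSOBytes/segSize).
func maxSegmentsPerEntry(segSize int) int {
	if segSize <= 0 {
		return 1
	}
	n := maxGSOBytes / segSize
	if n > maxGSOSegments {
		n = maxGSOSegments
	}
	if n < 1 {
		n = 1
	}
	return n
}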
+const udpGROCmsgPayload = 4 + +// prepareGRO turns on UDP_GRO so the kernel coalesces consecutive same-flow +// datagrams into one recvmmsg entry, with a cmsg carrying the gso_size used +// to split them back apart on the application side. +func (u *StdConn) prepareGRO() { + var probeErr error + if err := u.rawConn.Control(func(fd uintptr) { + probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }); err != nil { + u.l.Info("udp: GRO disabled", "reason", "rawconn control failed", "error", err) + recordCapability("udp.gro.enabled", false) + return + } + if probeErr != nil { + u.l.Info("udp: GRO disabled", "reason", "kernel rejected probe", "error", probeErr) + recordCapability("udp.gro.enabled", false) + return + } + u.groSupported = true + u.l.Info("udp: GRO enabled") + recordCapability("udp.gro.enabled", true) +} + +// prepareECNRecv turns on IP_RECVTOS / IPV6_RECVTCLASS so the outer IP-ECN +// field of each arriving datagram is delivered as ancillary data alongside +// the payload. listenOutBatch reads it via parseRecvCmsg and passes the +// codepoint through the EncReader for RFC 6040 combine on the decap side. +// Best-effort: we keep going on failure. +func (u *StdConn) prepareECNRecv() { + var probeErr error + if err := u.rawConn.Control(func(fd uintptr) { + if u.isV4 { + probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1) + } else { + probeErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1) + } + }); err != nil { + u.l.Info("udp: outer-ECN RX disabled", "reason", "rawconn control failed", "error", err) + recordCapability("udp.ecn_rx.enabled", false) + return + } + if probeErr != nil { + u.l.Info("udp: outer-ECN RX disabled", "reason", "kernel rejected probe", "error", probeErr) + recordCapability("udp.ecn_rx.enabled", false) + return + } + u.ecnRecvSupported = true + u.l.Info("udp: outer-ECN RX enabled") + recordCapability("udp.ecn_rx.enabled", true) +} + +// recordCapability registers (or updates) a boolean gauge for one of the +// kernel-feature probes. Gauges go to 1 when the feature is enabled, 0 when +// it is not — dashboards can show degraded state on partially-supported +// kernels at a glance. Calling repeatedly with the same name updates the +// existing gauge rather than registering a duplicate. +func recordCapability(name string, enabled bool) { + g := metrics.GetOrRegisterGauge(name, nil) + if enabled { + g.Update(1) + } else { + g.Update(0) + } +} + func (u *StdConn) SupportsMultipleReaders() bool { return true } @@ -183,7 +422,10 @@ func (u *StdConn) listenOutSingle(r EncReader, flush func()) error { return err } from = netip.AddrPortFrom(from.Addr().Unmap(), from.Port()) - r(from, buffer[:n]) + // listenOutSingle uses ReadFromUDPAddrPort which discards cmsgs, + // so the outer ECN field is not visible on this path. Zero RxMeta + // (Not-ECT) means RFC 6040 combine is a no-op. + r(from, buffer[:n], RxMeta{}) flush() } } @@ -194,7 +436,22 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { var operr error bufSize := MTU - msgs, buffers, names := u.PrepareRawMessages(u.batch, bufSize) + cmsgSpace := 0 + if u.groSupported { + bufSize = udpGROBufferSize + cmsgSpace = unix.CmsgSpace(udpGROCmsgPayload) + } + if u.ecnRecvSupported { + // IP_TOS arrives as 1 byte; IPV6_TCLASS arrives as a 4-byte int. + // Reserve enough for the wider of the two so the same buffer fits + // either family alongside any UDP_GRO cmsg. 
+ cmsgSpace += unix.CmsgSpace(4) + } + msgs, buffers, names, _ := u.PrepareRawMessages(u.batch, bufSize, cmsgSpace) + + if u.rxOrder == nil { + u.rxOrder = newRxReorderBuffer(u.batch * 64) + } //reader needs to capture variables from this function, since it's used as a lambda with rawConn.Read //defining it outside the loop so it gets re-used @@ -204,6 +461,11 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { } for { + if cmsgSpace > 0 { + for i := range msgs { + setMsgControllen(&msgs[i].Hdr, cmsgSpace) + } + } err := u.rawConn.Read(reader) if err != nil { return err @@ -212,6 +474,9 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { return operr } + // Phase 1: gather every segment from this recvmmsg into rxOrder, + // splitting GRO superpackets into their constituent segments. + u.rxOrder.reset() for i := 0; i < n; i++ { // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { @@ -222,14 +487,77 @@ func (u *StdConn) listenOutBatch(r EncReader, flush func()) error { from := netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])) payload := buffers[i][:msgs[i].Len] - r(from, payload) + segSize := 0 + outerECN := byte(0) + if cmsgSpace > 0 { + segSize, outerECN = parseRecvCmsg(&msgs[i].Hdr, u.groSupported, u.ecnRecvSupported, u.isV4) + } + u.rxOrder.addEntry(from, payload, segSize, outerECN) } + + // Phase 2 + 3: stable-sort by (src, port, counter), then deliver in + // order. Reorder distance is bounded by len(u.rxOrder.buf), which + // stays well within the receiver's ReplayWindow (currently 8192) so + // older arrivals are not rejected as replays. + u.rxOrder.sortStable() + u.rxOrder.deliver(r) // End-of-batch: let callers (e.g. TUN write coalescer) flush any // state they accumulated across this batch. flush() } } +// headerCounter returns the big-endian uint64 message counter at bytes +// [8:16] of a nebula packet, or 0 if the buffer is too short. +func headerCounter(buf []byte) uint64 { + if len(buf) < 16 { + return 0 + } + return binary.BigEndian.Uint64(buf[8:16]) +} + +// parseRecvCmsg walks the per-slot ancillary buffer once and extracts up to +// two values of interest in a single pass: the UDP_GRO gso_size (when +// wantGRO is true) and the outer IP-level ECN codepoint stamped on the +// carrier (when wantECN is true). Returns zeros for whichever field is not +// requested or not present. isV4 selects between IP_TOS (1-byte) and +// IPV6_TCLASS (4-byte int) cmsg payloads. +func parseRecvCmsg(hdr *msghdr, wantGRO, wantECN bool, isV4 bool) (gso int, ecn byte) { + controllen := int(hdr.Controllen) + if controllen < unix.SizeofCmsghdr || hdr.Control == nil { + return 0, 0 + } + ctrl := unsafe.Slice(hdr.Control, controllen) + off := 0 + for off+unix.SizeofCmsghdr <= len(ctrl) { + ch := (*unix.Cmsghdr)(unsafe.Pointer(&ctrl[off])) + clen := int(ch.Len) + if clen < unix.SizeofCmsghdr || off+clen > len(ctrl) { + return gso, ecn + } + dataOff := off + unix.CmsgLen(0) + switch { + case wantGRO && ch.Level == unix.SOL_UDP && ch.Type == unix.UDP_GRO: + if dataOff+udpGROCmsgPayload <= len(ctrl) { + gso = int(int32(binary.NativeEndian.Uint32(ctrl[dataOff : dataOff+udpGROCmsgPayload]))) + } + case wantECN && isV4 && ch.Level == unix.IPPROTO_IP && ch.Type == unix.IP_TOS: + // IP_TOS arrives as a single byte; only the low 2 bits are ECN. 
+ if dataOff+1 <= len(ctrl) { + ecn = ctrl[dataOff] & 0x03 + } + case wantECN && !isV4 && ch.Level == unix.IPPROTO_IPV6 && ch.Type == unix.IPV6_TCLASS: + // IPV6_TCLASS arrives as a 4-byte int; ECN is the low 2 bits. + if dataOff+4 <= len(ctrl) { + ecn = byte(binary.NativeEndian.Uint32(ctrl[dataOff:dataOff+4])) & 0x03 + } + } + // Advance by the aligned cmsg space. + off += unix.CmsgSpace(clen - unix.CmsgLen(0)) + } + return gso, ecn +} + func (u *StdConn) ListenOut(r EncReader, flush func()) error { if u.batch == 1 { return u.listenOutSingle(r, flush) @@ -243,19 +571,255 @@ func (u *StdConn) WriteTo(b []byte, ip netip.AddrPort) error { return err } -func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { +// WriteBatch sends bufs via sendmmsg(2) using the preallocated scratch on +// StdConn. Consecutive packets to the same destination with matching segment +// sizes (all but possibly the last) are coalesced into a single mmsghdr entry +// carrying a UDP_SEGMENT cmsg, so one syscall can mix runs of GSO superpackets +// with plain one-off datagrams. Without GSO support every packet is its own +// entry, matching the prior behaviour. +// +// Chunks larger than the scratch are processed across multiple syscalls. If +// sendmmsg returns an error AND zero entries went out we fall back to +// per-packet WriteTo for that chunk so the caller still gets best-effort +// delivery; on a partial-success error we just replay the remainder. +func (u *StdConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, ecns []byte) error { if len(bufs) != len(addrs) { return fmt.Errorf("WriteBatch: len(bufs)=%d != len(addrs)=%d", len(bufs), len(addrs)) } - //todo use sendmmsg - for i := 0; i < len(bufs); i++ { - if _, err := u.udpConn.WriteToUDPAddrPort(bufs[i], addrs[i]); err != nil { - return err + if ecns != nil && len(ecns) != len(bufs) { + return fmt.Errorf("WriteBatch: len(ecns)=%d != len(bufs)=%d", len(ecns), len(bufs)) + } + + // Callers deliver same-destination packets contiguously and in counter + // order, so we run the GSO planner directly without a pre-sort. A + // sorting pass measurably hurt throughput in microbenchmarks while + // providing no observed reordering benefit. + + i := 0 + for i < len(bufs) { + baseI := i + entry := 0 + iovIdx := 0 + for entry < len(u.writeMsgs) && i < len(bufs) { + iovBudget := len(u.writeIovs) - iovIdx + if iovBudget < 1 { + break + } + runLen, segSize := u.planRun(bufs, addrs, ecns, i, iovBudget) + if runLen == 0 { + break + } + + for k := 0; k < runLen; k++ { + b := bufs[i+k] + if len(b) == 0 { + u.writeIovs[iovIdx+k].Base = nil + setIovLen(&u.writeIovs[iovIdx+k], 0) + } else { + u.writeIovs[iovIdx+k].Base = &b[0] + setIovLen(&u.writeIovs[iovIdx+k], len(b)) + } + } + + nlen, err := writeSockaddr(u.writeNames[entry], addrs[i], u.isV4) + if err != nil { + return err + } + + hdr := &u.writeMsgs[entry].Hdr + hdr.Iov = &u.writeIovs[iovIdx] + setMsgIovlen(hdr, runLen) + hdr.Namelen = uint32(nlen) + + var ecn byte + if ecns != nil { + ecn = ecns[i] + } + u.writeEntryCmsg(entry, runLen, segSize, ecn) + + i += runLen + iovIdx += runLen + u.writeEntryEnd[entry] = i + entry++ } + + if entry == 0 { + return fmt.Errorf("sendmmsg: no progress") + } + + sent, serr := u.sendmmsg(entry) + if serr != nil && sent <= 0 { + // Nothing went out for this chunk; fall back to WriteTo for each + // packet that was queued this iteration. 
We only enter this path + // when sendmmsg returned an error AND zero entries succeeded — + // otherwise the partial-success advance below replays only the + // remainder, avoiding duplicates of already-sent packets. + // + // sent=-1 from sendmmsg means message 0 itself failed (partial + // success returns the count instead), so log entry 0's parameters + // — that's the entry the kernel rejected. + hdr0 := &u.writeMsgs[0].Hdr + runLen0 := u.writeEntryEnd[0] - baseI + seg0 := len(bufs[baseI]) + ecn0 := byte(0) + if ecns != nil { + ecn0 = ecns[baseI] + } + u.l.Warn("sendmmsg had problem", + "sent", sent, "err", serr, + "entries", entry, + "entry0_runLen", runLen0, + "entry0_segSize", seg0, + "entry0_iovlen", hdr0.Iovlen, + "entry0_controllen", hdr0.Controllen, + "entry0_namelen", hdr0.Namelen, + "entry0_ecn", ecn0, + "entry0_dst", addrs[baseI], + "isV4", u.isV4, + "gso", u.gsoSupported, + "gro", u.groSupported, + ) + for k := baseI; k < i; k++ { + if werr := u.WriteTo(bufs[k], addrs[k]); werr != nil { + return werr + } + } + continue + } + if sent == 0 { + return fmt.Errorf("sendmmsg made no progress") + } + // Rewind i to the end of the last successfully sent entry. For a + // full-success send this leaves i unchanged; for a partial send it + // replays the remainder on the next outer-loop iteration. + i = u.writeEntryEnd[sent-1] } return nil } +// planRun groups consecutive packets starting at `start` that can be sent as +// a single UDP GSO superpacket (one sendmmsg entry with UDP_SEGMENT cmsg). +// A run of length 1 means the entry carries no UDP_SEGMENT cmsg and the +// kernel treats it as a plain datagram. Returns the run length and the +// per-segment size (which equals len(bufs[start])). Without GSO support +// every call returns runLen=1. Outer ECN (when ecns != nil) is also a run +// boundary — the kernel stamps one outer codepoint per sendmsg entry, so +// mixing values inside a run would lose information. +func (u *StdConn) planRun(bufs [][]byte, addrs []netip.AddrPort, ecns []byte, start, iovBudget int) (int, int) { + if start >= len(bufs) || iovBudget < 1 { + return 0, 0 + } + segSize := len(bufs[start]) + if !u.gsoSupported || segSize == 0 || segSize > maxGSOBytes { + return 1, segSize + } + dst := addrs[start] + var ecn byte + if ecns != nil { + ecn = ecns[start] + } + maxLen := maxGSOSegments + if iovBudget < maxLen { + maxLen = iovBudget + } + runLen := 1 + total := segSize + for runLen < maxLen && start+runLen < len(bufs) { + nextLen := len(bufs[start+runLen]) + if nextLen == 0 || nextLen > segSize { + break + } + if addrs[start+runLen] != dst { + break + } + if ecns != nil && ecns[start+runLen] != ecn { + break + } + if total+nextLen > maxGSOBytes { + break + } + total += nextLen + runLen++ + if nextLen < segSize { + // A short packet must be the last in the run. + break + } + } + return runLen, segSize +} + +// writeEntryCmsg sets up the per-mmsghdr Hdr.Control / Hdr.Controllen for one +// entry. It writes the UDP_SEGMENT payload when runLen >= 2 and the +// IP_TOS/IPV6_TCLASS payload when ecn != 0, then points hdr.Control at the +// smallest contiguous span that covers whichever cmsg(s) actually apply. 
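As a concrete walkthrough of the run boundaries planRun enforces (editor's sketch, not part of the patch; it assumes the udp package context with fmt and net/netip imported, and constructs StdConn with gsoSupported set the same way the ECN planner test earlier in this patch does):

// planRunWalkthrough shows that a short packet may close a run but never
// sits in the middle of one, and that the following packet seeds a new entry.
func planRunWalkthrough() {
	u := &StdConn{gsoSupported: true}
	dst := netip.MustParseAddrPort("192.0.2.1:4242")
	bufs := [][]byte{
		make([]byte, 1200), // entry 0, segment 0
		make([]byte, 1200), // entry 0, segment 1
		make([]byte, 800),  // entry 0, segment 2: short, closes the run
		make([]byte, 1200), // entry 1: starts fresh
	}
	addrs := []netip.AddrPort{dst, dst, dst, dst}

	runLen, segSize := u.planRun(bufs, addrs, nil, 0, 64)
	fmt.Println(runLen, segSize) // 3 1200

	runLen, segSize = u.planRun(bufs, addrs, nil, 3, 64)
	fmt.Println(runLen, segSize) // 1 1200
}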
+func (u *StdConn) writeEntryCmsg(entry, runLen, segSize int, ecn byte) { + hdr := &u.writeMsgs[entry].Hdr + useSeg := runLen >= 2 + useEcn := ecn != 0 + base := entry * u.writeCmsgSpace + + if useSeg { + dataOff := base + unix.CmsgLen(0) + binary.NativeEndian.PutUint16(u.writeCmsg[dataOff:dataOff+2], uint16(segSize)) + } + if useEcn { + dataOff := base + u.writeCmsgSegSpace + unix.CmsgLen(0) + binary.NativeEndian.PutUint32(u.writeCmsg[dataOff:dataOff+4], uint32(ecn)) + } + + switch { + case useSeg && useEcn: + hdr.Control = &u.writeCmsg[base] + setMsgControllen(hdr, u.writeCmsgSpace) + case useSeg: + hdr.Control = &u.writeCmsg[base] + setMsgControllen(hdr, u.writeCmsgSegSpace) + case useEcn: + hdr.Control = &u.writeCmsg[base+u.writeCmsgSegSpace] + setMsgControllen(hdr, u.writeCmsgEcnSpace) + default: + hdr.Control = nil + setMsgControllen(hdr, 0) + } +} + +// sendmmsg issues sendmmsg(2) over u.rawConn against the first n entries +// of u.writeMsgs. Routes through u.rawSend so the per-call kernel callback +// stays alloc-free. +func (u *StdConn) sendmmsg(n int) (int, error) { + return u.rawSend.send(u.rawConn, n) +} + +// writeSockaddr encodes addr into buf (which must be at least +// SizeofSockaddrInet6 bytes). Returns the number of bytes used. If isV4 is +// true and addr is not a v4 (or v4-in-v6) address, returns an error. +func writeSockaddr(buf []byte, addr netip.AddrPort, isV4 bool) (int, error) { + ap := addr.Addr().Unmap() + if isV4 { + if !ap.Is4() { + return 0, ErrInvalidIPv6RemoteForSocket + } + // struct sockaddr_in: { sa_family_t(2), in_port_t(2, BE), in_addr(4), zero(8) } + // sa_family is host endian. + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + ip4 := ap.As4() + copy(buf[4:8], ip4[:]) + for j := 8; j < 16; j++ { + buf[j] = 0 + } + return unix.SizeofSockaddrInet4, nil + } + // struct sockaddr_in6: { sa_family_t(2), in_port_t(2, BE), flowinfo(4), in6_addr(16), scope_id(4) } + binary.NativeEndian.PutUint16(buf[0:2], unix.AF_INET6) + binary.BigEndian.PutUint16(buf[2:4], addr.Port()) + binary.NativeEndian.PutUint32(buf[4:8], 0) + ip6 := addr.Addr().As16() + copy(buf[8:24], ip6[:]) + binary.NativeEndian.PutUint32(buf[24:28], 0) + return unix.SizeofSockaddrInet6, nil +} + func (u *StdConn) ReloadConfig(c *config.C) { b := c.GetInt("listen.read_buffer", 0) if b > 0 { diff --git a/udp/udp_linux_32.go b/udp/udp_linux_32.go index e253784b..0f153a49 100644 --- a/udp/udp_linux_32.go +++ b/udp/udp_linux_32.go @@ -30,11 +30,16 @@ type rawMessage struct { Len uint32 } -func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) @@ -48,9 +53,14 @@ func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [] msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint32(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs } func setIovLen(v *iovec, n int) { diff --git a/udp/udp_linux_64.go b/udp/udp_linux_64.go index d18ca281..dc373538 100644 --- 
a/udp/udp_linux_64.go +++ b/udp/udp_linux_64.go @@ -33,11 +33,16 @@ type rawMessage struct { Pad0 [4]byte } -func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [][]byte) { +func (u *StdConn) PrepareRawMessages(n, bufSize, cmsgSpace int) ([]rawMessage, [][]byte, [][]byte, []byte) { msgs := make([]rawMessage, n) buffers := make([][]byte, n) names := make([][]byte, n) + var cmsgs []byte + if cmsgSpace > 0 { + cmsgs = make([]byte, n*cmsgSpace) + } + for i := range msgs { buffers[i] = make([]byte, bufSize) names[i] = make([]byte, unix.SizeofSockaddrInet6) @@ -51,9 +56,14 @@ func (u *StdConn) PrepareRawMessages(n, bufSize int) ([]rawMessage, [][]byte, [] msgs[i].Hdr.Name = &names[i][0] msgs[i].Hdr.Namelen = uint32(len(names[i])) + + if cmsgSpace > 0 { + msgs[i].Hdr.Control = &cmsgs[i*cmsgSpace] + msgs[i].Hdr.Controllen = uint64(cmsgSpace) + } } - return msgs, buffers, names + return msgs, buffers, names, cmsgs } func setIovLen(v *iovec, n int) { diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index 007384b1..a95ad3d0 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -161,7 +161,7 @@ func (u *RIOConn) ListenOut(r EncReader, flush func()) error { continue } - r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n], RxMeta{}) flush() } } @@ -317,7 +317,7 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { return winrio.SendEx(u.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { +func (u *RIOConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { for i, b := range bufs { if err := u.WriteTo(b, addrs[i]); err != nil { return err diff --git a/udp/udp_tester.go b/udp/udp_tester.go index 028015e2..6b877b71 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -157,7 +157,7 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } } -func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort) error { +func (u *TesterConn) WriteBatch(bufs [][]byte, addrs []netip.AddrPort, _ []byte) error { for i, b := range bufs { if err := u.WriteTo(b, addrs[i]); err != nil { return err @@ -172,7 +172,7 @@ func (u *TesterConn) ListenOut(r EncReader, flush func()) error { case <-u.done: return os.ErrClosed case p := <-u.RxPackets: - r(p.From, p.Data) + r(p.From, p.Data, RxMeta{}) p.Release() flush() }
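Tying the Conn interface change back to the Linux path above, a caller-side sketch (editor's illustration, not part of the patch; `sendEncrypted` is hypothetical and assumes imports of net/netip and the nebula udp package) of handing one host's packets to WriteBatch with a uniform outer ECN codepoint:

// sendEncrypted packs same-destination packets contiguously so the Linux
// backend's planRun can coalesce them into GSO entries, and supplies one
// outer ECN codepoint per packet. Non-Linux backends ignore the ECN slice.
// Callers should keep len(pkts) at or under udp.MaxWriteBatch per call.
func sendEncrypted(conn udp.Conn, dst netip.AddrPort, pkts [][]byte, outerECN byte) error {
	addrs := make([]netip.AddrPort, len(pkts))
	ecns := make([]byte, len(pkts))
	for i := range pkts {
		addrs[i] = dst
		ecns[i] = outerECN // e.g. 0x02 = ECT(0) when the inner packet is ECN-capable
	}
	return conn.WriteBatch(pkts, addrs, ecns)
}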